#!/usr/bin/python
# @package      hubzero-sss-ldap
# @file         hzsss
# @author       Nicholas J. Kisseberth <nkissebe@purdue.edu>
# @copyright    Copyright (c) 2012-2018 HUBzero Foundation, LLC.
# @license      http://opensource.org/licenses/MIT MIT
#
# Copyright (c) 2012-2018 HUBzero Foundation, LLC.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
# HUBzero is a registered trademark of HUBzero Foundation, LLC.
#
import argparse
import base64
import ConfigParser
import hashlib
import hubzero.data.db
import hubzero.utilities.misc
import hubzero.config.hubzerositeconfig
import os
import re
import subprocess
import sys
import traceback
from distutils.version import LooseVersion, StrictVersion

def serviceInit(service, action = 'start'):
    """Run `/sbin/service SERVICE ACTION`, adapting the service name to this host.

    Some daemons are installed under different names depending on the
    distribution (apache2/httpd, mysql/mysqld, cron/crond). If the requested
    name has no SysV init script or systemd unit but its alternate name does,
    the alternate name is used instead.

    Keyword arguments:
    service -- the service name to act on
    action  -- the action passed to /sbin/service (default 'start')

    Returns:
    the exit status reported by the command, or -1 if running it raised
    """

    # Each name maps to the alternate name the same daemon may be installed as.
    alternates = {
        'apache2': 'httpd',
        'httpd': 'apache2',
        'mysql': 'mysqld',
        'mysqld': 'mysql',
        'cron': 'crond',
        'crond': 'cron',
    }

    def installed(name):
        # A service counts as installed if it has a SysV init script or a systemd unit.
        return (os.path.exists('/etc/init.d/' + name) or
                os.path.exists('/usr/lib/systemd/system/' + name + '.service'))

    alternate = alternates.get(service)
    if alternate is not None and not installed(service) and installed(alternate):
        service = alternate

    command = '/sbin/service ' + service + " " + action
    try:
        print(command)
        rc, procStdOut, procStdErr = hubzero.utilities.misc.exShellCommand(['/sbin/service', service, action])
    except Exception:
        # Report and return an error code; unlike a bare except this does not
        # swallow KeyboardInterrupt/SystemExit.
        print(command + " failed")
        return -1

    if rc:
        print(procStdErr)

    return rc


def replace(path, regexp, replace = None):
    """Apply a regular-expression substitution to every line of a file.

    Each line is substituted on its own (without its newline terminator), so
    '^' and '$' anchor to line boundaries. The file is rewritten only when at
    least one line actually changes.

    Keyword arguments:
    path    -- the file to modify; a missing file is silently skipped
    regexp  -- the pattern to substitute (re.sub syntax)
    replace -- the replacement text (backreferences such as '\\1' allowed);
               a false value (None, False, '') deletes the matched text

    Returns:
    0 - No change required (or the file does not exist)
    1 - Changes made
    """

    if not replace:
        replace = ''

    if not os.path.exists(path):
        return 0

    # Compile once instead of on every line.
    pattern = re.compile(regexp)

    with open(path, 'r+') as f:
        original = f.read()
        updated = '\n'.join(pattern.sub(replace, text) for text in original.split('\n'))

        if updated == original:
            return 0

        f.seek(0)
        f.write(updated)
        f.truncate()
        return 1


def lineinfile(path, line, regexp = None, insertbefore = None, insertafter = None, create = False, ifexists = False, state = True):
    """search a file for a line, and ensure that it is present or absent.

    Lines are compared without their newline terminator. A file that ended
    with a newline (or was empty) ends with one after editing, and a line
    appended to it goes directly after the last line, with no blank line in
    between.

    Keyword arguments:
    path            -- the file to modify
    line            -- the line to insert/replace into the file
    regexp          -- the regular expression to look for in the file (re.match, anchored at line start)
    insertbefore    -- the line will be inserted before the last match of specified regular expression ('BOF' = beginning of file)
    insertafter     -- the line will be inserted after the last match of specified regular expression ('EOF' = end of file)
    create          -- if true the file will be created if it does not already exist
    state           -- whether the line should be there (true) or not (false)
    ifexists        -- whether to run only if file exists

    With state true: if regexp matches, the last matching line is replaced by
    `line`, or the file is left unchanged when insertbefore/insertafter is
    also given. Otherwise `line` is inserted at the requested position,
    falling back to the end of the file. Without regexp, nothing is done if
    `line` is already present.
    With state false: the last line matching regexp (or, failing that, the
    last line equal to `line`) is removed.

    Returns:
    0 - No change required
    1 - Changes made

    Raises:
    ValueError -- insertbefore together with insertafter, or either with state false
    IOError    -- path does not exist and neither create nor ifexists is set
    """

    if insertbefore is not None and insertafter is not None:
        raise ValueError('invalid argument combination: insertbefore and insertafter')

    if (insertbefore is not None or insertafter is not None) and not state:
        raise ValueError('invalid argument combination: insertbefore or insertafter with state == false')

    if not os.path.exists(path):
        if create:
            with open(path, 'w+') as f:
                if line is not None and state:
                    f.write(line + "\n")
            return 1
        if ifexists:
            return 0
        # Otherwise fall through: opening the missing file below raises.

    with open(path, 'r+') as f:
        ftxt = f.read()

        # Split into real lines. The empty string produced after a final "\n"
        # (or by an empty file) is not a line: drop it, and remember to put
        # the terminator back when writing.
        terminate = ftxt.endswith("\n") or ftxt == ""
        txt = ftxt.split("\n")
        if terminate:
            txt.pop()

        regexp_match = -1
        insertafter_match = -1
        insertbefore_match = -1
        line_match = -1

        # Record the LAST index matching each criterion.
        for i, current in enumerate(txt):
            if regexp is None and current == line and state:
                return 0

            if current == line:
                line_match = i
            if regexp is not None and re.match(regexp, current):
                regexp_match = i
            if insertafter is not None and insertafter != 'EOF' and re.match(insertafter, current):
                insertafter_match = i
            if insertbefore is not None and insertbefore != 'BOF' and re.match(insertbefore, current):
                insertbefore_match = i

        if state:
            if regexp_match != -1 and insertafter is None and insertbefore is None:
                txt[regexp_match] = line
            elif regexp_match != -1:
                # regexp already matches and an insert position was requested:
                # treat the line as present.
                return 0
            elif insertbefore == 'BOF':
                txt.insert(0, line)
            elif insertafter == 'EOF':
                txt.append(line)
            elif insertafter_match != -1:
                txt.insert(insertafter_match + 1, line)
            elif insertbefore_match != -1:
                txt.insert(insertbefore_match, line)
            else:
                # No anchor given, or the given anchor matched nothing.
                txt.append(line)
        else:
            if regexp_match != -1:
                txt.pop(regexp_match)
            elif line_match != -1:
                txt.pop(line_match)

        newtxt = "\n".join(txt)
        if txt and terminate:
            newtxt += "\n"

        if newtxt == ftxt:
            return 0

        f.seek(0)
        f.write(newtxt)
        f.truncate()
        return 1

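# @TODO: add command-line options for the LDAP settings that are currently
# read from /etc/hubzero.conf and /etc/hubzero.secrets, e.g.:
# --ldap-uri=ldaps://myldaphost.org:636
# --ldap-suffix=suffix
# --ldap-admin-dn=
# --ldap-admin-pw=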

def _init_sss(args):
    """Configure this host's LDAP name-service client(s) from the hubzero config.

    Reads [ldap] basedn, uri, searchuserdn and adminuserdn from
    /etc/hubzero.conf, and LDAP-ADMINPW / LDAP-SEARCHPW from
    /etc/hubzero.secrets. It then:

    * if /etc/sssd/ exists, writes sssd.conf, adds pam_sss to the pam.d "-ac"
      files, points nsswitch.conf at sss and restarts sssd;
    * if /etc/nslcd.conf exists, updates nslcd.conf, points nsswitch.conf at
      ldap and restarts nslcd and nscd.

    Keyword arguments:
    args -- the parsed argparse namespace (not used)

    Returns:
    0 - success
    1 - /etc/hubzero.secrets does not exist
    2 - /etc/hubzero.conf does not exist
    3 - [ldap] basedn or uri is not set in /etc/hubzero.conf
    """

    hubzeroSecretsFilename = "/etc/hubzero.secrets"

    if not os.path.exists(hubzeroSecretsFilename):
        print("ERROR: /etc/hubzero.secrets does not exist")
        return 1

    if not os.path.exists('/etc/hubzero.conf'):
        print("ERROR: /etc/hubzero.conf does not exist")
        return 2

    secretsConfig = ConfigParser.ConfigParser()
    secretsConfig.optionxform = str  # keep option names case-sensitive
    with open(hubzeroSecretsFilename, "r") as f:
        secretsConfig.readfp(f)

    adminUserPW = secretsConfig.get("DEFAULT", "LDAP-ADMINPW")
    searchUserPW = secretsConfig.get("DEFAULT", "LDAP-SEARCHPW")

    getOption = hubzero.config.hubzerositeconfig.getHubzeroConfigOption
    suffix = getOption('ldap', 'basedn')
    uri = getOption('ldap', 'uri')
    searchUserDN = getOption('ldap', 'searchuserdn')
    adminUserDN = getOption('ldap', 'adminuserdn')

    # Both values are used to build config lines below. Fail with a clear
    # message instead of a TypeError if either is missing or empty
    # (NOTE(review): assumes getHubzeroConfigOption returns a false value for
    # unset options — confirm).
    if not suffix or not uri:
        print("ERROR: [ldap] basedn and uri must be set in /etc/hubzero.conf")
        return 3

    if os.path.exists('/etc/sssd/'):
        # LDAP-ADMINPW is the password of the configured admin DN, so bind as
        # that DN. Fall back to the previously hard-coded cn=admin,<suffix>
        # when adminuserdn is unset.
        bindDN = adminUserDN if adminUserDN else "cn=admin," + suffix
        _configure_sssd(suffix, uri, bindDN, adminUserPW)

    if os.path.exists('/etc/nslcd.conf'):
        _configure_nslcd(suffix, uri, searchUserDN, searchUserPW)

    return 0


def _configure_sssd(suffix, uri, bindDN, bindPW):
    """Write sssd.conf, wire pam_sss into pam.d and nsswitch.conf, then restart sssd."""

    sssdConfFilename = "/etc/sssd/sssd.conf"
    sssdConfConfig = ConfigParser.ConfigParser()
    sssdConfConfig.optionxform = str

    if os.path.exists(sssdConfFilename):
        with open(sssdConfFilename, "r") as f:
            sssdConfConfig.readfp(f)

    # (section, [(option, value), ...]) applied in this order. Existing
    # options that are not listed here are left untouched.
    # NOTE(review): 'domains' lists "ldap" while the section is
    # [domain/LDAP]; confirm sssd matches these names case-insensitively.
    settings = [
        ("domain/default", [
            ("autofs_provider", 'ldap'),
            ("cache_credentials", 'True'),
            ("id_provider", 'ldap'),
            ("auth_provider", 'ldap'),
            ("chpass_provider", 'ldap'),
            ("ldap_tls_cacertdir", '/etc/openldap/cacerts'),
        ]),
        ("sssd", [
            ("services", 'nss, pam, autofs'),
            ("domains", 'default, ldap'),
            ("config_file_version", '2'),
        ]),
        ("nss", [
            ("homedir_substring", '/home'),
            ("filter_groups", 'root'),
            ("filter_users", 'root'),
        ]),
        ("domain/LDAP", [
            ("enumerate", 'true'),
            ("tls_reqcert", 'never'),
            ("access_provider", 'ldap'),
            ("ldap_user_search_base", "ou=users," + suffix),
            ("ldap_group_search_base", "ou=groups," + suffix),
            ("ldap_default_bind_dn", bindDN),
            ("ldap_default_authtok_type", 'password'),
            ("ldap_default_authtok", bindPW),
            ("ldap_id_use_start_tls", 'False'),
            ("ldap_auth_disable_tls_never_use_in_production", 'true'),
            ("id_provider", 'ldap'),
            ("ldap_uri", uri),
            ("ldap_access_filter", 'host=web'),
        ]),
    ]

    for section, options in settings:
        if not sssdConfConfig.has_section(section):
            sssdConfConfig.add_section(section)
        for option, value in options:
            sssdConfConfig.set(section, option, value)

    # The file holds the bind password, so keep it owner-only. O_TRUNC is
    # required: without it, rewriting an existing longer file would leave
    # stale bytes after the new content.
    fd = os.open(sssdConfFilename, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    os.fchmod(fd, 0o600)  # the O_CREAT mode only applies to newly created files
    with os.fdopen(fd, "w") as f:
        sssdConfConfig.write(f)

    # These aren't all strictly necessary, but are included for completeness
    # of being equivalent to the authconfig call being replaced. They are
    # pattern-based and fragile; a much more "intelligent" pam.d editor would
    # be needed to improve on this. pam.d files that do not exist are skipped.
    #
    # @TODO: is broken_shadow needed?
    pamd = '/etc/pam.d/'

    for name in ('password-auth-ac', 'system-auth-ac'):
        path = pamd + name
        replace(path=path, regexp=r'^\s*auth\s+(required|requisite)\s+.*pam_unix\.so.*$',
                replace='auth        sufficient    pam_unix.so nullok try_first_pass')
        lineinfile(path=path, insertafter=r'^\s*auth\s+.*pam_unix\.so.*$',
                   line='auth        sufficient    pam_sss.so use_first_pass', ifexists=True)
        replace(path=path, regexp=r'(^\s*auth\s*.*)(success\s*=\s*1)(.*pam_unix\.so\s+.*$)',
                replace='\\1success=2\\3')

    for name in ('fingerprint-auth-ac', 'password-auth-ac', 'smartcard-auth-ac', 'system-auth-ac'):
        lineinfile(path=pamd + name, regexp=r'^\s*account\s+.*pam_unix\.so.*$',
                   line='account     required      pam_unix.so broken_shadow', ifexists=True)

    for name in ('smartcard-auth-ac', 'password-auth-ac', 'system-auth-ac', 'fingerprint-auth-ac'):
        lineinfile(path=pamd + name, insertafter=r'^\s*account\s+sufficient\s+pam_succeed',
                   line='account     [default=bad success=ok user_unknown=ignore] pam_sss.so', ifexists=True)

    for name in ('password-auth-ac', 'system-auth-ac'):
        lineinfile(path=pamd + name, insertafter=r'^\s*password\s+.*pam_unix\.so.*$',
                   line='password    sufficient    pam_sss.so use_authtok', ifexists=True)

    for name in ('fingerprint-auth-ac', 'password-auth-ac', 'smartcard-auth-ac', 'system-auth-ac'):
        lineinfile(path=pamd + name, insertafter=r'^\s*session\s+.*pam_unix\.so.*$',
                   line='session     optional      pam_sss.so', ifexists=True)

    for db in ('passwd', 'shadow', 'group', 'services', 'netgroup', 'automount'):
        lineinfile(path='/etc/nsswitch.conf', regexp=r'\s*' + db + r':\s*.*$', line=db + ': files sss')

    serviceInit('sssd', 'restart')


def _configure_nslcd(suffix, uri, searchUserDN, searchUserPW):
    """Update nslcd.conf, point nsswitch.conf at ldap, then restart nslcd and nscd."""

    nslcdConfFilename = '/etc/nslcd.conf'

    print("writing out /etc/nslcd.conf")

    # Earlier versions of this script wrote the bind password as
    # "binddn <password>". Remove that line so it cannot be taken for the
    # bind DN; no-op if absent.
    lineinfile(path=nslcdConfFilename, line='binddn ' + searchUserPW, state=False)

    # (regexp matching the existing setting, replacement line). If a regexp
    # matches nothing, the line is appended to the file.
    settings = [
        (r'^\s*uri\s+.*$', 'uri ' + uri),
        # Global "base <dn>" only. The first token after "base" is an RDN
        # (contains '='), so this works for any suffix, not just dc=...
        # ones, and never matches the per-map "base passwd ..." lines.
        (r'^\s*base\s+[^\s=]+=.*$', 'base ' + suffix),
        (r'^\s*#?\s*binddn\s+.*$', 'binddn ' + searchUserDN),
        (r'^\s*#?\s*bindpw\s+.*$', 'bindpw ' + searchUserPW),
        (r'^\s*#?\s*tls_cacertfile\s+.*$', 'tls_cacertfile /etc/ssl/certs/ca-bundle.crt'),
        (r'^\s*bind_timelimit\s+.*$', 'bind_timelimit 1'),
        (r'^\s*timelimit\s+.*$', 'timelimit 5'),
        (r'^\s*reconnect_sleeptime\s+.*$', 'reconnect_sleeptime 0'),
        (r'^\s*base\s+passwd\s+.*$', 'base passwd ' + 'ou=users,' + suffix),
        (r'^\s*base\s+shadow\s+.*$', 'base shadow ' + 'ou=users,' + suffix),
        (r'^\s*base\s+group\s+.*$', 'base group ' + 'ou=groups,' + suffix),
        (r'^\s*scope\s+group\s+.*$', 'scope group one'),
        (r'^\s*scope\s+shadow\s+.*$', 'scope shadow one'),
        (r'^\s*filter\s+shadow\s+.*$', 'filter shadow (&(objectClass=posixAccount)(host=web))'),
        (r'^\s*pam_authz_search\s+.*$', 'pam_authz_search (&(objectClass=posixAccount)(uid=$username)(host=web))'),
    ]
    for regexp, line in settings:
        lineinfile(path=nslcdConfFilename, regexp=regexp, line=line)

    hubzero.utilities.misc.exShellCommand(['chmod', '0640', nslcdConfFilename])
    # "owner:group" is the POSIX chown syntax; "owner.group" is obsolete.
    hubzero.utilities.misc.exShellCommand(['chown', 'root:nslcd', nslcdConfFilename])

    print("writing out /etc/nsswitch.conf")
    for db in ('passwd', 'shadow', 'group'):
        lineinfile(path='/etc/nsswitch.conf', regexp=r'\s*' + db + r':\s*.*$', line=db + ': compat ldap')

    serviceInit('nslcd', 'stop')
    serviceInit('nslcd', 'start')
    serviceInit('nscd', 'stop')
    serviceInit('nscd', 'start')
    

if __name__ == '__main__':
    # Command-line entry point: `hzsss init` configures sssd/nslcd for LDAP.
    parser = argparse.ArgumentParser(prog="hzsss")

    subparsers = parser.add_subparsers(dest='command')
    # Python 3's argparse treats subcommands as optional by default, which
    # would leave args.func unset; require one explicitly.
    subparsers.required = True

    parser_init = subparsers.add_parser('init', help='initialize nslcd/sss for ldap')
    parser_init.set_defaults(func=_init_sss)

    args = parser.parse_args()

    # sys.exit rather than the site-module exit() helper, which is intended
    # for the interactive interpreter and is absent under `python -S`.
    sys.exit(args.func(args))

