# @package      hubzero-mw2-common
# @file         host.py
# @author       Pascal Meunier <pmeunier@purdue.edu>
# @copyright    Copyright (c) 2016-2017 HUBzero Foundation, LLC.
# @license      http://opensource.org/licenses/MIT MIT
#
# Based on previous work by Richard L. Kennell and Nicholas Kisseberth
#
# Copyright (c) 2016-2017 HUBzero Foundation, LLC.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
# HUBzero is a registered trademark of HUBzero Foundation, LLC.
#

"""
Support for communicating remotely using SSH.
"""
import os
import subprocess
import pwd
from log import log, user_print
from support import get_dirlock

from errors import PublicError, PrivateError
from constants import HOST_K, MW_USER

# When True, the SSH helpers in this module log extra diagnostics
# (known_hosts lookups, the commands being run, and their output).
DEBUG = False
# NOTE(review): not referenced anywhere in the visible part of this module;
# confirm whether other modules still import it before relying on or removing it.
HOSTDISPLAYS_MAX = False

class BasicHost:
  """ Basic host functionality: SSH.  For the benefit of service (e.g., execution) hosts
  that don't create keys or deal with the "host" SQL table.

  Attributes:
    hostname: name of the remote host; used on the ssh/scp command lines and searched
      for in the known_hosts file.
    remote_user: login used on the remote host.
    key_path: path of the private key given to ssh/scp with '-i'.
  """
  def __init__(self, hostname, remote_user, key_path):
    """Remember how to reach the remote host.

    Raises PublicError if key_path is None.
    """
    self.hostname = hostname
    self.remote_user = remote_user # MW_USER
    if key_path is None:
      raise PublicError("SSH key for '%s @ %s' not specified" % (remote_user, hostname))
    self.key_path = key_path

  #=============================================================================
  # SSH Operations Support
  #=============================================================================
  @staticmethod
  def knownhosts_path():
    """Return the path of the known_hosts file to use for the current effective user.

    This could be run as root when running on an execution host
    or as the Apache user on a web server.
    """
    login = pwd.getpwuid(os.geteuid())[0]
    if login == 'root':
      # make the key available to everyone instead of putting it only in /root/.ssh/known_hosts
      return '/etc/ssh/ssh_known_hosts'
    # workaround for issue when called by expire-session daemon
    # running as www-data but somehow expanduser returns /root/.ssh/known_hosts !
    tmppath = os.path.expanduser('~/.ssh/known_hosts')
    if tmppath == '/root/.ssh/known_hosts':
      return '/var/www/.ssh/known_hosts'
    return tmppath

  def strict_ssh_host_key(self):
    """Return 0 if we have the remote host's key, a non-zero value otherwise.

    The lookup is delegated to grep, run without a shell so that the host name
    can't be interpreted as shell syntax.  A non-zero value is also returned when
    grep can't be run or the known_hosts file can't be read.
    NOTE: the host name is still used as an unanchored regular expression, as before.
    """
    knownhosts_path = BasicHost.knownhosts_path()
    if DEBUG:
      log("""DEBUG: Checking key for host %s, knownhosts path is %s""" % (self.hostname, knownhosts_path))

    # grep for key; the host name must be followed by a space or a comma
    pattern = '%s[ ,]' % self.hostname
    try:
      with open(os.devnull) as devnull:
        # '-e' keeps a pattern starting with '-' from being taken for an option
        return subprocess.call(['grep', '-q', '-e', pattern, knownhosts_path], stdin=devnull)
    except (IOError, OSError):
      # grep could not be started: same outcome as "key not found"
      return 127

  def check_ssh_host_key(self):
    """Make sure that we know the remote host's key. If not, add it to known_hosts.

    The error 'conalloc: fdno 262 too high' is fixed by closing excess open file
    descriptors leaked by Apache.  The maxfdno is hard-coded in ssh-keyscan.
    The same command runs fine in a shell.

    Raises PublicError if ssh-keyscan exits with a non-zero status.
    """
    # Check if key is known
    if self.strict_ssh_host_key() == 0:
      return

    # Add key
    log("""Can't find key for %s.  Attempting to add to known_hosts""" % self.hostname)
    # Mark excess open file descriptors leaked by Apache as close-on-exec, to make sure
    # we can call ssh-keyscan.  See ticket 6546 on hubzero.org
    import fcntl
    for fd in range(3, 256):
      try:
        fcntl.fcntl(fd, fcntl.F_SETFD, fcntl.FD_CLOEXEC)
      except (IOError, OSError):
        # this descriptor is not open: nothing to mark
        pass
    p = subprocess.Popen(['ssh-keyscan', '-t', 'rsa', self.hostname], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    (stdout, stderr) = p.communicate()
    status = p.returncode
    if status != 0:
      log("ssh-keyscan for '%s' returned code %d, stderr '%s'" % (self.hostname, status, stderr))
      if DEBUG:
        os.system("""ls -al ~/.ssh/known_hosts; cd; pwd""")
        os.system("""whoami""")
      raise PublicError("SSH key for host '%s' unknown and can't be added to ~/.ssh/known_hosts" % (self.hostname))
    # Binary append: the output of ssh-keyscan is written as-is, and the file is
    # closed (and therefore flushed) before ssh is started.
    with open(BasicHost.knownhosts_path(), 'ab') as kh:
      kh.write(stdout)

  def _run_ssh(self, comm, feed):
    """Run the command list comm on the remote host with ssh and wait for it to end.

    comm: list of the remote command and its arguments.
    feed: data written to the stdin of ssh, or None (an empty string means None).

    Return (popen_cmd, return code, stdout, stderr).
    """
    if feed == "":
      feed = None

    popen_cmd = [
      '/usr/bin/ssh',
      '-i', self.key_path,
      "-o BatchMode=yes",  # suppresses prompts requesting passwords or adding ssh key etc...
      "-o ServerAliveInterval=60", # probe the server after 60 seconds without data from it
      "-o ConnectTimeout=60", # for when the target is down or really unreachable
      '%s@%s' % (self.remote_user, self.hostname)
    ]
    popen_cmd.extend(comm)

    process = subprocess.Popen(
      popen_cmd,
      stdin=subprocess.PIPE, # prevents blocking if remote command expects input and we have none
      stdout=subprocess.PIPE,
      stderr=subprocess.PIPE,
      shell=False
    )
    (stdout, stderr) = process.communicate(feed)
    return (popen_cmd, process.returncode, stdout, stderr)

  def ssh(self, comm, feed=None):
    """Use ssh to run a command on a remote system. This assumes that self.key_path.pub
    is in the authorized keys of the remote user self.remote_user on the execution host.

    comm: a list or a string, the command to be executed.  Use a list for security, instead of
      space-separated arguments in a string.
    feed: optional data to send to the standard input of the remote command.

    Return the exit code of ssh.  The command, its output and its exit code are logged
    when the exit code is not 0 (or always, if DEBUG is set).
    """
    self.check_ssh_host_key()

    if not isinstance(comm, list):
      comm = [comm]

    if DEBUG:
      log("ssh command: %s" % comm)

    (popen_cmd, rcode, stdout, stderr) = self._run_ssh(comm, feed)
    if rcode != 0 or DEBUG:
      log("host.py: '%s' stdout '%s' stderr '%s' and returned code %d" % (" ".join(popen_cmd), stdout, stderr, rcode))
    return rcode

  def ask_ssh(self, comm, feed=None):
    """Use ssh to run a command on a remote system. This assumes that self.key_path.pub
    is in the authorized keys of the remote user self.remote_user on the execution host.

    comm: a list or a string, the command to be executed.  Use a list for security, instead of
      space-separated arguments in a string.
    feed: optional data to send to the standard input of the remote command.

    If successful, return stdout.  Raises PrivateError if ssh exits with a non-zero code.
    """
    self.check_ssh_host_key()

    if not isinstance(comm, list):
      comm = [comm]

    if DEBUG:
      for part in comm:
        log("part:" + str(part))

    (popen_cmd, rcode, stdout, stderr) = self._run_ssh(comm, feed)
    if rcode != 0:
      raise PrivateError("ask-ssh: ssh with '%s':%s:%s returned %d" % (self.hostname, stdout, stderr, rcode))
    return stdout

  def scp(self, src, dest):
    """Use scp to copy one or more files from src to dest.

    Raises PrivateError if scp exits with a non-zero code.
    """
    # preflight
    self.check_ssh_host_key()

    popen_cmd = [
      '/usr/bin/scp',
      '-i', self.key_path,
      src,
      dest
    ]

    # stdout is captured too, so that the error message below can actually report it
    process = subprocess.Popen(popen_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=False)
    (stdout, stderr) = process.communicate()
    if process.returncode != 0:
      raise PrivateError("Unable to scp from '%s' to '%s' (code:%d because %s; %s) using key '%s'"\
       % (src, dest, process.returncode, stdout, stderr, self.key_path))

  def scp_format(self, remote_file = None):
    """Return the scp location 'user@host:' of this host, followed by remote_file if given."""
    if remote_file is None:
      return '%s@%s:' % (self.remote_user, self.hostname)
    return '%s@%s:%s' % (self.remote_user, self.hostname, remote_file)

class Host(BasicHost):
  """Functionality to manage the hosts running the "maxwell
  service" script, for the benefit of the maxwell script running on a master host.  For example,
  creating notify keys and transferring them to execution hosts.

  Attributes (in addition to those of BasicHost):
    K: this host's settings, i.e. HOST_K with the constructor's overrides applied.
    is_windows: always False here (see __init__).
    service_path: value of K["SERVICE_PATH"], the first word of the remote commands
      run by service() and ask_service().
  """

  def __init__(self, hostname, overrides=None):
    """Setup the host from HOST_K, with the entries of overrides taking precedence.

    overrides: optional dictionary of settings replacing those of HOST_K for this host.
    Raises PublicError if the resulting K["KEY_PATH"] is None.
    """
    # Work on a private copy, so that the overrides of one host are neither written
    # into the shared HOST_K nor inherited by the hosts created afterwards.
    self.K = dict(HOST_K)
    if overrides:
      self.K.update(overrides)
    BasicHost.__init__(self, hostname, self.K["SVC_HOST_USER"], self.K["KEY_PATH"])
    # Decouple database ops from filesystem ops: don't require a database connection
    # just to setup the host.  Windows hosts (HOST_K["WINDOWS_SUPPORT"], hosts for which
    # is_a(db, 'windowshost') returns something, K["WIN_SERVICE_PATH"]) are therefore
    # not detected here: make a separate class for windows or set is_windows afterwards.
    self.is_windows = False
    self.service_path = self.K["SERVICE_PATH"]

  def get_status(self, db):
    """Return the status recorded for this host in the host table."""
    return db.getsingle("""
      SELECT status FROM host
      WHERE hostname=%s""",
      [self.hostname])

  def set_status(self, db, new_status):
    """Record new_status for this host in the host table, commit and log the change."""
    db.c.execute("""
      UPDATE host
      SET status=%s
      WHERE hostname=%s""",
      [new_status, self.hostname])
    db.commit()
    log('set host status to %s' % new_status)

  def count_uses(self, db):
    """Recalculate count for given host, every time a display is added or removed.
      Display can be in status used, broken (a host with a
      lot of broken displays may not be reliable), or starting.
      status='absent' (after ending a session) is not to be counted."""
    db.c.execute("""
      UPDATE host SET uses=(
        SELECT COUNT(*) FROM display
        WHERE display.hostname=%s
          AND (display.status = 'used' OR display.status = 'starting' OR display.status = 'broken'))
      WHERE hostname=%s""" ,
      [self.hostname, self.hostname])
    db.commit()

  def uses_left(self, db):
    """Return max_uses-uses for this host; 999 stands for max_uses when it isn't positive."""
    return db.getsingle("""
      SELECT if(max_uses>0, max_uses-uses, 999-uses) FROM host
      WHERE hostname=%s""",
      [self.hostname])

  def max_uses(self, db):
    """Return max_uses for this host, or 999 when it isn't positive."""
    return db.getsingle("""
      SELECT if(max_uses>0, max_uses, 999) FROM host
      WHERE hostname=%s""",
      [self.hostname])

  def get_first_display(self, db):
    """Return the minimum display number to use for this host, or 1 if not set or set to 0"""
    first_display = db.getsingle("""SELECT first_display FROM host WHERE hostname=%s""",
      [self.hostname])
    if first_display is None:
      return 1
    v = int(first_display)
    if v == 0:
      return 1
    return v


  @staticmethod
  def get_host(db, prov_req, host_k=None, zone=None):
    """Find the least loaded host that is available for display creation.

    db: database connection.
    prov_req: integer bit mask of the provisions that the host must have.
    host_k: optional settings overrides given to the Host that is returned.
    zone: if not None, only the hosts having this zone's zone_id are considered.

    Return a Host, or None if there is no eligible host or if the least loaded one
    has no uses left.  The selected row is read with FOR UPDATE.
    This is a static method.
    """

    # Find the least loaded eligible host.
    # note #uses != load
    # note that library can't be used to escape prov_req because it will add quotes!
    # (%d formatting only accepts numbers, so these values can't carry SQL text)
    max_uses = db.getsingle('select MAX(max_uses) as max from host', [])
    if max_uses is None or max_uses == 0:
      order_by = 'ORDER BY uses ASC'
    else:
      # calculate % load because max_uses varies from host to host
      # division by zero results in NULL, which is sorted before zero
      order_by = 'ORDER BY uses/max_uses ASC'

    if zone is None:
      zone_clause = ""
    else:
      zone_clause = " AND zone_id = %d" % zone.zone_id

    log("get_host: SELECT hostname, max_uses, uses FROM host WHERE provisions & %d = %d AND status='up'%s %s LIMIT 1 FOR UPDATE" % (prov_req, prov_req, zone_clause, order_by))
    row = db.getrow(
      """SELECT hostname, max_uses, uses
         FROM host
         WHERE provisions & %d = %d
         AND status='up'%s
         %s LIMIT 1 FOR UPDATE""" % (prov_req, prov_req, zone_clause, order_by), [])

    if row is None:
      return None
    uses = int(row[2])
    if row[1] is None:
      max_uses = 0
    else:
      max_uses = int(row[1])
    # max_uses == 0 means that the number of uses isn't limited
    if max_uses != 0 and uses >= max_uses:
      return None
    return Host(row[0], host_k)


  def create_notify_keys(self, notify_key, quiet):
    """Create SSH notify keys.  This private/public key pair is transferred to execution hosts
    and is used for all sessions on that host (it's the same for all hosts too).  There is a race condition
    because we're using a constant path.  Another instance of maxwell could overwrite the key
    pair.  A checks -> no keys;  B does same;  A generates pair but doesn't write it yet, or writes only
    one of the two.  B generates pair and overwrites A's.  A finishes writing.  Final result: pair
    is broken because each key came from two different processes and belongs to a different pair.

    notify_key: path of the private key to create; the public key is notify_key + '.pub'.
    quiet: passed to create_ssh_keys.

    The public key is appended to $HOME/.ssh/authorized_keys with a forced command.
    Raises PublicError if the key pair can't be created.
    """
    Host.create_ssh_keys(quiet, notify_key)
    # adding new public key to list of authorized keys
    authkeys = os.environ["HOME"] + "/.ssh/authorized_keys"
    pubkey = os.open(notify_key + '.pub', os.O_RDONLY)
    try:
      line = os.read(pubkey, 10000)
    finally:
      os.close(pubkey)
    # Adding forced command to authorized key.  os.read returns bytes, so the prefix is
    # converted when the interpreter's str type isn't bytes.
    prefix = "COMMAND=\"%s notify\" " % self.K["MAXWELL_PATH"]
    if not isinstance(prefix, bytes):
      prefix = prefix.encode('utf-8')
    # the mode only matters when the file is created: keep it private to its owner
    akfile = os.open(authkeys, os.O_WRONLY | os.O_CREAT | os.O_APPEND, 0o600)
    try:
      os.write(akfile, prefix + line)
    finally:
      os.close(akfile)

  @staticmethod
  def create_ssh_keys(quiet, path):
    """Create an RSA SSH key pair without passphrase at path (and path + '.pub').

    quiet: if true, ssh-keygen is run with -q and nothing is printed for the user.
    Raises PublicError if ssh-keygen can't be run or fails.
    """
    # No shell is involved, so path is passed to ssh-keygen exactly as given.
    keygen = ['ssh-keygen', '-t', 'rsa', '-f', path, '-N', '']
    if quiet:
      keygen.append('-q')
    else:
      user_print("Creating ssh keys")
    try:
      process = subprocess.Popen(keygen, stdin=subprocess.PIPE)
      # answer 'y' in case ssh-keygen asks whether to overwrite an existing key
      process.communicate(b'y\n')
      status = process.returncode
    except OSError:
      # ssh-keygen could not be started
      status = 127
    if status != 0:
      # we use PublicError because we need results displayed on an admin interface.
      raise PublicError( "Unable to create ssh keys at '%s'." % path)


  def service(self, service_cmd, args, feed=None):
    """Call the maxwell service script on an execution host.  Preserve return value.

    Callers don't need to know details of how to call maxwell service on a host.  Also don't need to do
    preflight checks or remember them.
    args is an [].  Give an empty [] if there are none.

    feed is data to feed to remote service.  For example, Fileserver.update_resources does that.

    Return the exit code of ssh."""

    self.check_ssh_keys(True)
    return self.ssh([self.service_path, service_cmd] + args, feed)

  def ask_service(self, service_cmd, args, feed=None):
    """Call the maxwell service script on an execution host and return its standard output.

    Callers don't need to know details of how to call maxwell service on a host.  Also don't need to do
    preflight checks or remember them.
    args is an [].  Give an empty [] if there are none.

    feed is data to feed to remote service.  For example, Fileserver.update_resources does that.

    Raises PrivateError if ssh exits with a non-zero code."""

    self.check_ssh_keys(True)
    return self.ask_ssh([self.service_path, service_cmd] + args, feed)

  def old_check_notify_keys(self, quiet):
    """ Obsolete code kept for reference
    Check that the SSH notify key exists. If it doesn't, get the directory lock and check again.
    If it still doesn't, create them.  The double check is to avoid the performance penalty of locking.
    It is going to be a very rare event that the keys don't exist.  The directory locking can sleep
    for seconds in case of contention, so that would not scale well.
    Use lock to prevent race condition in creating keys.
    This is identical to check_ssh_keys except that notify keys need to be transfered to exec hosts,
    and need to have a forced command setup.
    try:
      fd = os.open(self.K["NOTIFY_KEY"], os.O_RDONLY)
      os.close(fd)
    except OSError:
      # This time get the lock before checking
      lock_path = self.K["LOCK_PATH"] + "/notify_key_lock"
      get_dirlock(lock_path)
      try:
        fd = os.open(self.K["NOTIFY_KEY"], os.O_RDONLY)
        os.close(fd)
      except OSError:
        self.create_notify_keys(quiet)
        # Copying notify private key to host, to be used to contact us. (uses forced command)
        self.scp(self.K["NOTIFY_KEY"], self.scp_format(self.K["NOTIFY_KEY"]))
      # release lock now that the keys have been created and setup fully
      os.rmdir(lock_path)"""

  @staticmethod
  def _key_readable(path):
    """Return True if the file at path can be opened for reading."""
    try:
      os.close(os.open(path, os.O_RDONLY))
    except OSError:
      return False
    return True

  def check_notify_keys(self, quiet):
    """
    Check that the SSH notify key exists in the mw-www directory. If it doesn't exist, generate
    the notify key locally in the same location as K["KEY_PATH"] and transfer it to this host.
    This check can't be done with the notify key in mw-service because we're www-data and it is owned by root
    it's a consequence of the separation of /etc/mw into mw-www and mw-service.
    To do it, get the directory lock and check again.
    If it still doesn't exist, create it.  The double check is to avoid the performance penalty of locking.
    It is going to be a very rare event that the keys don't exist.  The directory locking can sleep
    for seconds in case of contention, so that would not scale well.
    Use lock to prevent race condition in creating keys.
    This is identical to check_ssh_keys except that notify keys need to be transferred to exec hosts,
    and need to have a forced command setup.

    Raises PrivateError if the path of the notify key can't be derived from K["KEY_PATH"].
    """
    maxwell_key = self.K["KEY_PATH"]
    notify_key = maxwell_key.replace("maxwell", "notify")
    if notify_key == maxwell_key:
      raise PrivateError("Can't create a new SSH notify key")
    if Host._key_readable(notify_key):
      return
    # This time get the lock before checking
    lock_path = self.K["LOCK_PATH"] + "/notify_key_lock"
    get_dirlock(lock_path)
    try:
      if not Host._key_readable(notify_key):
        self.create_notify_keys(notify_key, quiet)
        # Copying notify private key to host, to be used to contact us. (uses forced command)
        self.scp(notify_key, self.scp_format(self.K["NOTIFY_KEY"]))
    finally:
      # release the lock whether or not the keys were created and setup fully;
      # otherwise a failure would leave the lock directory behind
      os.rmdir(lock_path)

  def check_ssh_keys(self, quiet):
    """Check that the Maxwell SSH key pair exists.  Prevent race condition in the check and creation of keys.
    For example, KEY_PATH = '/etc/mw/maxwell.key'.  The public key (KEY_PATH + '.pub') needs to be
    installed on the file server.  This can't be done automatically, so if the keys didn't exist,
    maxwell will fail after creating the ssh keys.
    A directory lock (get_dirlock) serializes the creation of the keys.

    Return None if the private key can be read.  Otherwise raises PublicError, either to
    ask for the public key to be installed or because the keys couldn't be created.
    """
    if Host._key_readable(self.key_path):
      return
    # This time get the lock before checking
    lock_path = self.K["LOCK_PATH"] + "/ssh_key_lock"
    get_dirlock(lock_path)
    try:
      if not Host._key_readable(self.key_path):
        Host.create_ssh_keys(quiet, self.key_path)
    finally:
      # don't leave the lock directory behind if the keys can't be created
      os.rmdir(lock_path)
    raise PublicError("SSH keys created, you will need to transfer the public key '%s' to the host '%s' and put it in /.ssh/authorized_keys for user '%s'" % \
                      (self.key_path + '.pub', self.hostname, self.remote_user))

  def is_a(self, db, host_type_name):
    """Return the provisions of this host if they include the host type named
    host_type_name, and whatever db.getsingle returns for no match otherwise."""
    provisions = db.getsingle("""
      SELECT provisions FROM host JOIN hosttype WHERE host.provisions &
      hosttype.value AND hosttype.name =%s AND host.hostname=%s
      """, [host_type_name, self.hostname])
    return provisions

  def meets(self, db, hostreq):
    """Return True if the provisions of this host include all the bits of the integer hostreq."""
    cmd = """SELECT provisions FROM host
      WHERE provisions & %d = %d""" % (hostreq, hostreq)
    return db.getsingle(cmd + """ AND host.hostname=%s""", [self.hostname]) is not None

