#!/usr/bin/python
#
# @package      hubzero-mw2-client
# @file         maxwell
# @author       Pascal Meunier <pmeunier@purdue.edu>
# @copyright    Copyright (c) 2016-2017 HUBzero Foundation, LLC.
# @license      http://opensource.org/licenses/MIT MIT
#
# Based on previous work by Richard L. Kennell and Nicholas Kisseberth
#
# Copyright (c) 2016-2017 HUBzero Foundation, LLC.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
# HUBzero is a registered trademark of HUBzero Foundation, LLC.
#

import MySQLdb
import os
import sys
import time
import pwd
import re
import stat
import subprocess
from hubzero.mw.mw_db import MW_DB
from hubzero.mw.log import log, print_n_log, user_print, dns_reverse, setup_log, log_exc, ttyprint, save_out, background, dns_resolve
from hubzero.mw.host import Host
from hubzero.mw.constants import MYSQL_CONN_ATTEMPTS, HOST_K, CONTAINER_K, VERBOSE, BKGND_RETRY_SLEEP, MASTER_LOG_FILENAME, \
  SESSION_K, APP_K, ALPHANUM, CONFIG_FILE, MW_USER, WEB_HOMEDIR, INT_REGEXP, IP_REGEXP, USER_REGEXP, \
  PATH_REGEXP, LOGSAFE_REGEXP, NAME_REGEXP, QUOTED_REGEXP, URL_REGEXP, GEOM_REGEXP
from hubzero.mw.errors import PublicError, DisplayError, PrivateError, \
  MaxwellError, InputError, ChildError, SessionError
from hubzero.mw.display import Display
from hubzero.mw.session import Session
from hubzero.mw.app import App
from hubzero.mw.fileserver import Fileserver
from hubzero.mw.support import genpasswd, check_rundir
# for zones support
from hubzero.mw.user_account import User_account
from hubzero.mw.zone import Zone

#=============================================================================
# Load default parameters...
#=============================================================================
# NOTE(review): these are module-level defaults; presumably the site
# configuration (CONFIG_FILE is imported above) overrides them -- confirm
# where that happens, it is not visible in this part of the file.
ZONE_SUPPORT = False  # when True, create_session places sessions in zones
# MySQL connection settings passed to every MW_DB(...) call in this file.
mysql_host = "undefined"
mysql_user = "undefined"
mysql_password = "undefined"
mysql_db = "undefined"
mysql_prefix = "jos"  # CMS table prefix, e.g. "<prefix>_users" in update_quota
default_vnc_timeout = 900
session_suffix = ""  # passed to Session(...) together with the session number
# specify either one of:
visualization_params = "" # old name
visualization_resource = "" # new name; create_session uses it when non-empty
# example visualization_params:
"""nanovis_server render00:2000,render01:2000,render02:2000
molvis_server render00:2020,render01:2020,render02:2020
submit_target tls://devsubmit.nanohub.org:831
submit_target tcp://devsubmit.nanohub.org:830
"""

filexfer_decoration = ""
# hub identity, written into session resource files and applet HTML
hub_name = ""
hub_url = ""
hub_homedir = ""
hub_template = ""

dns_retries = 10
DISPLAY_RETRIES = 4  # attempts made by find_display and create_display_background
homedir_retries = 10  # passed to Fileserver(...)
get_stats_retries = 3  # passed to Session.get_stats()
get_stats_delay = 3  # passed to Session.get_stats()
default_vnc_depth = 24
# match displays by geometry and depth.  Set to false to need fewer ready displays
STRICT_GEOMETRY = True

default_version = 'current'
# Site overrides; merged over defaults from hubzero.mw.constants
# (create_session does CONTAINER_K.update(CONTAINER_CONF)).
SESSION_CONF = {}
CONTAINER_CONF = {}
HOST_CONF = {}
APP_CONF = {}
submit_app_list = []  # passed to the file server when writing session resources
# when a host has fewer ready displays than this, more are created in the background
MIN_READY_DISPLAYS = 1
LOG_ID = 'mw'  # log identifier; background children append "_bckgnd"
CREATE_SESSIONS_ONCHECK = False
APPS_SERVER = None
APPS_DIR = '/apps/%s'  # format string; '%s' is filled in by the code that uses it

#=============================================================================
# Host checkup.
#=============================================================================

def check_host(hostname, confirm):
  """Toggle host status (up/down).  Before bringing up, do an SSH key check on the specified host.
  Optionally, accept a new key.
  Attempt to contact host and diagnose problems.  Note that we want to get
  back information.  So the handling of db connections is such that we expect
  the child to close the connection by calling sys.exit() if everything goes
  well.
      This is a top-level function called right after command parsing.  Because this is called
      from the web admin interface, user messages get full information (so don't raise errors except
      for timeout)

  hostname: host name as stored in the host table.
  confirm: when 'yes' and ssh reports 'Host key verification failed', ssh is
    run again with StrictHostKeyChecking disabled so the new key is accepted.

  Returns:
    0  host was unknown, was brought down, needs its ssh keys set up first
       (check_ssh_keys raised PublicError), or passed all checks;
    -1 an ssh check failed (status set to 'sshfail' or 'sshkey');
    1  check_notify_keys raised a MaxwellError.
  """

  db = MW_DB(mysql_host, mysql_user, mysql_password, mysql_db, MYSQL_CONN_ATTEMPTS)

  h = Host(hostname)
  host_status = h.get_status(db)

  # None: the host is not known to the database.
  if host_status is None:
    print_n_log("Unknown host '%s'" % hostname)
    return 0

  if host_status == 'up':
    # toggle host status.
    h.set_status(db, 'down')
    print_n_log( "Brought down host %s" % hostname)
    # we're done!
    return 0

  # Now we're trying to bring the host 'up'.
  # It must pass the following checks, which allow us a diagnostic of the ssh connection attempt.

  print_n_log( "Checking host %s" % hostname)
  try:
    h.check_ssh_keys(False)
  except PublicError, p:
    # report and ask for the check to be re-run; the host status is left unchanged
    print_n_log("%s" %p)
    print_n_log("Try again after this is done")
    return 0
  h.check_ssh_host_key()

  if ZONE_SUPPORT:
    # check if host is in local or remote zone.  If remote, then it's a remote master;  use zone class. 
    # NOTE: any MaxwellError in this branch (including a failing z.ask on a
    # real remote master) falls through to the direct ssh check below.
    try:
      # this should fail if host isn't a remote master
      z = Zone.get_zone_by_master(db, hostname)
      stdout = z.ask(db, ["check"])
      if "OK" in stdout:
        h.set_status(db, 'up')
        print_n_log("Host check on '%s' was successful." % hostname)
        return 0
      else:
        h.set_status(db, 'sshfail')
        print_n_log("Host check on '%s' returned '%s'" % (hostname, stdout)) 
        return -1
    except MaxwellError:
      pass

  # identical to the Host ssh function, except the timeout is 5
  # REVISIT: use a factory pattern for host?
  # Choose the remote command: file servers run FS_PATH, Windows hosts
  # WIN_SERVICE_PATH, all other hosts SERVICE_PATH.
  if h.is_a(db, 'fileserver'):
    path = HOST_MERGED["FS_PATH"]
  else:
    if h.is_windows:
      path = HOST_MERGED["WIN_SERVICE_PATH"]
    else:
      path = HOST_MERGED["SERVICE_PATH"]

  popen_cmd = [
    '/usr/bin/ssh',
    '-i', HOST_MERGED["KEY_PATH"],
    "-o BatchMode=yes",  # suppresses prompts requesting passwords or adding ssh key etc...
    "-o ConnectTimeout=5", # for when the target is down or really unreachable
    '%s@%s' % (HOST_MERGED["SVC_HOST_USER"], hostname),
    path,
    "check"
  ]

  if VERBOSE:
    for part in popen_cmd:
      log( "part:" + str(part))

  process = subprocess.Popen(
    popen_cmd,
    stdout= subprocess.PIPE,
    stderr= subprocess.PIPE,
    shell = False
  )

  (stdout, stderr) = process.communicate()

  if process.returncode == 0:
    # ssh worked; the remote "check" command must also print "OK"
    if "OK" in stdout:
      h.set_status(db, 'up')
      print_n_log("Host check on '%s' was successful." % hostname)
    else:
      h.set_status(db, 'sshfail')
      print_n_log("Host check on '%s' returned '%s' without error?" % (hostname, stdout))
      return -1
  else:
    print_n_log("ssh to '%s' returned '%s%s'" % (hostname, stdout, stderr))
    if ('Host key verification failed' in stderr):
      if confirm == 'yes':
        # try again, telling ssh to accept the remote ssh key
        # index 3 is right after '-i <key>', i.e. before the destination
        popen_cmd.insert(3, "-o StrictHostKeyChecking no")
        process = subprocess.Popen(
          popen_cmd,
          stdout= subprocess.PIPE,
          stderr= subprocess.PIPE,
          shell = False
        )
        (stdout, stderr) = process.communicate()
        if process.returncode == 0:
          # key accepted: status 'sshtest' (not 'up') until the key is confirmed
          h.set_status(db, 'sshtest')
          print_n_log( "You need to confirm the host key.")
          # keep on going to next step
        else:
          h.set_status(db, 'sshkey')
          print_n_log("Unable to accept host key for '%s': '%s'" % (hostname, stderr))
          return -1
      else:
        h.set_status(db, 'sshkey')
        print_n_log("Unable to confirm the host key.")
        return -1
    else:
      h.set_status(db, 'sshfail')
      print_n_log( "Unable to connect to %s." % (hostname))
      print_n_log( "You'll have to fix that.")
      return -1

  try:
    h.check_notify_keys(False) # create if necessary
    # The notify private key is used to contact us. (uses forced command)
    # Even if notify keys exist, it may not have been transfered to this specific host yet.
  except MaxwellError, scpe:
    print_n_log("%s" % scpe)
    return 1

  # create ready displays for all apps,  as needed
  # this needs more testing
  # if it is an execution host.  How do you know?
  #if not STRICT_GEOMETRY:
  #  avail = int(Display.get_ready_count(db, h, None, True))
  #  if avail < MIN_READY_DISPLAYS:
  #    # if not, create the required number.
  #    create_display_background(db, h, None) # closes db
  return 0

#=============================================================================
# VNC operations.
#=============================================================================
def find_display(db, appinfo, sessnum, zone):
  """Find an appropriate container with VNC server.  That's a "display" configured and ready to go.
  if we fork to create things in background, close db.
  db: a MW_DB object
  appinfo: an App object
  sessnum: a session number
  zone: a local Zone object (can be None)
  returns a Display object."""

  if VERBOSE:
    log("find_display: %s session %d" % (appinfo.appname, sessnum))
  disp = None
  count = 0
  while disp is None and count <  DISPLAY_RETRIES:
    # we loop at this level because get_Display can create new displays if none are available.
    # Creating can fail, increase the broken count, and so increase the usage of a host
    # So, a different host could be chosen in every loop.
    try:
      disp = Display.get_Display(db, appinfo, sessnum, host_k = HOST_MERGED, zone = zone, strict_geo = STRICT_GEOMETRY)
    except DisplayError, ex:
      print_n_log("%s" % ex)
    count += 1

  if disp is None:
    raise PublicError("Unable to find a display for '%s'. Has the app been setup?" % appinfo.appname)

  if int(Display.get_ready_count(db, disp.host, appinfo, STRICT_GEOMETRY)) < MIN_READY_DISPLAYS:
    create_display_background(db, disp.host, appinfo) # closes db

  if VERBOSE:
    log("found display %d" % disp.dispnum)

  return disp

def create_display_background(db, host, appinfo):
  """Create an appropriate display in the background (for later use). appinfo is an instance of the App class,
  not a list."""

  if host.is_windows:
    return
  if VERBOSE:
    log("create_display_background on '%s'" % host.hostname)

  try:
    if not background():
      # If we're the original parent, return.  Child continues below.
      # parent returns 0.
      return
  except ChildError, ex:
    # didn't become fully dissociated
    log("create_display_background: %s" % ex)
    os._exit(1)


  # Execute the following as a dissociated child.
  # wrap everything in try to make sure to call os._exit()
  try:
    # interactive tasks have higher priority, so be nice to them
    os.nice(20)
    setup_log(MASTER_LOG_FILENAME, LOG_ID + "_bckgnd")
    # log("create_display_background: child starting")
    # this part creates new displays
    for i in range(0, DISPLAY_RETRIES):
      # return CPU to main process for a while
      time.sleep(BKGND_RETRY_SLEEP)
      # We forked so we need our own db connection, and we don't sleep while holding a database connection because # connections is limited
      db = MW_DB(mysql_host, mysql_user, mysql_password, mysql_db)
      # an exception on the execution host could bubble up here
      try:
        left = host.uses_left(db)
        avail = int(Display.get_ready_count(db, host, appinfo, STRICT_GEOMETRY))
        if avail >= MIN_READY_DISPLAYS:
          #log("create_display_background: enough displays available (%d), exiting" % (avail))
          db.db.close()
          os._exit(0)
        #else:
          #log("create_display_background: want more than %d displays (up to %d)" % (avail, MIN_READY_DISPLAYS))
        if avail >= left:
          log("create_display_background: running maximum number of containers on '%s'" % host.hostname)
          db.db.close()
          os._exit(0)
        disp = Display.make_Display(db, host, appinfo, 0)
      except MaxwellError, ex:
        log("create_display_background: %s" % ex)
      finally:
        db.db.close()
    log("create_display_background no more retries")
    os._exit(1)
  except StandardError, ex:
    log("create_display_background: %s" % ex)
    os._exit(1)

def stop_session_background(sessnum, reason):
  """Stop a session in the background (no waiting).  The parent returns immediately.
    The child never returns.
    This is a top-level function called right after command parsing."""
  # Execute the following as a dissociated child.
  if not background():
    return 0
  # interactive tasks have higher priority, so be nice to them
  try:
    os.nice(20)
    db = MW_DB(mysql_host, mysql_user, mysql_password, mysql_db, MYSQL_CONN_ATTEMPTS)
    sess = Session(sessnum, session_suffix, SESSION_MERGED)
    # start the session shutdown but do not make display "absent" until notify command is
    # received from the execution host
    sess.stop_session(db, reason)
    # database integrity checks
    Session.integrity_check(db)
  except PrivateError, ex:
    import traceback
    trbk = sys.exc_info()[2]
    log("stop_session_background: %s\n%s" % (ex, traceback.format_tb(trbk, 10)))
    # perhaps session has already been processed
  except StandardError, ex:
    import traceback
    trbk = sys.exc_info()[2]
    log("stop_session_background: %s\n%s" % (ex, traceback.format_tb(trbk, 10)))
    os._exit(1)

  os._exit(0)

def session_exit_notify(sessnum, fromhost, reason="exited"):
  """Notification that the session has exited asynchronously.
  This is a top-level function called right after command parsing.
  sess is a number (no suffix).
  So, we need to clean up as needed (close display).

  sessnum: session number (no suffix).
  fromhost: IP address the notification came from.  For reason "exited", unless
    it starts with "127.0." (or the session host is "localhost"), it must equal
    the DNS resolution of the session's host.
  reason: passed to Session.stop_session(); only "exited" triggers the
    source check and the stop.

  Session stats are collected and processed in every case.  Raises MaxwellError
  on a source mismatch.  The database connection is closed on every path.
  Returns 0.
  """
  db = MW_DB(mysql_host, mysql_user, mysql_password, mysql_db, MYSQL_CONN_ATTEMPTS)
  try:
    # Clean up session if needed
    # this deletes from the session table, which contains only active sessions.
    # data is moved to the sessionlog table, which we'll use later
    sess = Session(sessnum, session_suffix, SESSION_MERGED)
    if reason == "exited":
      sess.get_sessionhost(db)
      if not fromhost.startswith('127.0.'):
        hostname = sess.host.hostname
        if hostname != "localhost":
          verify_ip = dns_resolve(hostname)
          if verify_ip != fromhost:
            raise MaxwellError("notify event source mismatch.  Notified %d from %s instead of %s" % (sessnum, fromhost, verify_ip))
          log("Notified %d from %s (%s)" % (sessnum, fromhost, dns_reverse(fromhost)[0:99]))
      try:
        sess.stop_session(db, reason)
      except SessionError:
        # do not exit at this point, need to process stats!
        if VERBOSE:
          log("the session has already been processed, but trying again to transfer session log files")

    sess.get_stats(db, get_stats_retries, get_stats_delay)
    sess.process_stats(db)
  finally:
    db.db.close()
  return 0

def view_applet(sessnum, user, ip, force_ro):
  """Print the HTML code to load the applet.
      This is a top-level function called right after command parsing.
      sessnum is the session number
      REVISIT: ip is not used anymore and should be deleted, but it is present in API...

  user: user asking to view the session; checked with sess.get_viewperms().
  force_ro: passed to sess.applet_html() (presumably forces a read-only view).

  Returns 0 after printing "Unable to view application." when get_viewperms
  raises SessionError; otherwise prints the applet HTML and returns None.
  The database connection is closed on every path."""
  db = MW_DB(mysql_host, mysql_user, mysql_password, mysql_db, MYSQL_CONN_ATTEMPTS)
  try:
    sess = Session(sessnum, session_suffix, SESSION_MERGED)
    try:
      sess.get_viewperms(db, user)
    except SessionError:
      user_print("Unable to view application.")
      return 0

    if SESSION_MERGED["NOVNC"]:
      # the execution host is told which view token to accept for this display
      sess.host.service("set_viewtoken", [sess.disp.dispnum, sess.viewtoken])
    sess.get_app(db, mysql_prefix)
    sess.applet_html(force_ro, hub_url, db)
  finally:
    db.db.close()

def update_quota(user, block_soft, block_hard):
  """Talk to the file server to change a user's quotas.
  This is a top-level function called right after command parsing.

  user: one user, or a comma-separated list.  A numeric entry is a user id
    and must be >= 1000.  Any other entry must be a username present in the
    "<mysql_prefix>_users" table.  The whole string is forwarded to the
    file server.
  block_soft, block_hard: block quota limits.  They must convert with int()
    and be >= 0.

  Raises InputError for quotas that are non-numeric or negative.  Exits with
  status 6 when a user entry is rejected.
  """
  try:
    block_soft = int(block_soft)
    block_hard = int(block_hard)
  except (TypeError, ValueError):
    raise InputError("Invalid quotas")

  if VERBOSE:
    log("update_quota %s, %s, %s" % (str(user), str(block_soft), str(block_hard)))

  if block_soft < 0 or block_hard < 0:
    raise InputError("Invalid quotas")

  users = user.split(",")

  db = MW_DB(mysql_host, mysql_user, mysql_password, mysql_db, MYSQL_CONN_ATTEMPTS)
  for check_user in users:
    try:
      userid = int(check_user)
    except ValueError:
      # not a number: must be an existing hub username (parameterized query)
      user_row = db.getrow(
        "SELECT username FROM " + mysql_prefix + """_users
        WHERE username=%s""", str(check_user))
      if user_row is None:
        log("Quota denied: No such HUB user (%s).\n" % (check_user))
        ttyprint("Quota change denied: No such HUB user (%s).\n" % (check_user))
        sys.exit(6)
    else:
      # low ids are reserved for system accounts
      if userid < 1000:
        log("User id must be >= 1000\n")
        ttyprint("Quota change denied: bad userid (%d).\n" % (userid))
        sys.exit(6)
  fs = Fileserver(db, user, 0, 0)
  fs.remote("update_quota", [user, "%d" % block_soft, "%d" % block_hard])

def get_quota(user):
  """Print the file server's quota report for user.

  The report is written with ttyprint so that the web server can read it
  from stdout.
  This is a top-level function called right after command parsing."""
  if VERBOSE:
    log("get_quota " + str(user))

  conn = MW_DB(mysql_host, mysql_user, mysql_password, mysql_db, MYSQL_CONN_ATTEMPTS)
  server = Fileserver(conn, user, 0, 0)
  report = server.ask_remote("get_quota", [user])
  ttyprint(report)

def erase_userhome(user):
  """Ask the file server to run "erase_userhome" for user; print its reply.

  This is a top-level function called right after command parsing."""
  if VERBOSE:
    log("erase_userhome " + str(user))

  conn = MW_DB(mysql_host, mysql_user, mysql_password, mysql_db, MYSQL_CONN_ATTEMPTS)
  server = Fileserver(conn, user, 0, 0)
  reply = server.ask_remote("erase_userhome", [user])
  ttyprint(reply)


def move_userhome(userfrom, userto):
  """Ask the file server to move userfrom's home directory to userto; print its reply.

  The file server is selected for the source user (userfrom).
  This is a top-level function called right after command parsing."""
  if VERBOSE:
    log("move_userhome %s -> %s" % (str(userfrom), str(userto)))

  db = MW_DB(mysql_host, mysql_user, mysql_password, mysql_db, MYSQL_CONN_ATTEMPTS)

  # was Fileserver(db, user, ...): 'user' is not defined here (NameError)
  fs = Fileserver(db, userfrom, 0, 0)
  ttyprint(fs.ask_remote("move_userhome", [userfrom, userto]))

def create_userhome(user):
  """Ask the file server to run "create_userhome" for user; print its reply.

  This is a top-level function called right after command parsing."""
  if VERBOSE:
    log("create_userhome " + str(user))

  conn = MW_DB(mysql_host, mysql_user, mysql_password, mysql_db, MYSQL_CONN_ATTEMPTS)
  server = Fileserver(conn, user, 0, 0)
  reply = server.ask_remote("create_userhome", [user])
  ttyprint(reply)

def remote_zone_create(db, z, user, ip, appname, sess, viewtoken):
  """Create a session in a remote zone"""
  try:
    # send user and group info
    uinfo = User_account(user)
    log("setting up %s" % user)
    z.tell(db, ["setup", user] + uinfo.groups())
    log("start user=%s ip=%s app=%s sessnum=%s viewtoken=%s sesstoken=%s" % (user, ip, appname, sess.sessnum, viewtoken, sess.sesstoken))

    # start session, receive vncpassword back
    html = z.ask(db, ["start user=%s ip=%s app=%s sessnum=%s viewtoken=%s sesstoken=%s" % (user, ip, appname, sess.sessnum, viewtoken, sess.sesstoken)])

    # grab vncpass from output
    # if JSON output, expecting ''{"encpassword": "c1a0583062623fb1", "encoding": "ZRLE", "archive":....'
    enc_loc = html.find('encpassword": "')
    if enc_loc == -1:
      # try old HTML method
      # expecting <param name="ENCPASSWORD" value="%s">
      enc_loc = html.find('ENCPASSWORD" value="')
      if enc_loc == -1:
        raise MaxwellError("Unable to find vnc password (ENCPASSWORD) from remote output")
      log(enc_loc)
      # 20 characters in 'ENCPASSWORD" value="' which is start of value
      value_loc = 20 + enc_loc
      log(value_loc)
      end_loc = html.find('">', value_loc)
      log(end_loc)
      if end_loc == -1:
        raise MaxwellError("Unable to find vnc password (quote) from remote output")
    else:
      # 15 characters in 'encpassword": "' which is start of value
      value_loc = 15 + enc_loc
      log(value_loc)
      end_loc = html.find('",', value_loc)
      log(end_loc)
      if end_loc == -1:
        raise MaxwellError("Unable to find vnc password (quote) from remote output")      
    vncpass = html[value_loc:end_loc]
    log("vnc password %s" % vncpass)
    # store viewtoken.disp.vncpass
    sess.set_viewperm(db, viewtoken, cookie, vncpass)
    log("vnc password stored")
  except StandardError, ex:
    sess.del_session(db, ex)
    log("Aborted create_session")
    raise MaxwellError("Unable to start remote session")
  if html is None:
    ttyprint("No applet HTML!")
  else:
    #  Print the html from the remote maxwell instead of our own
    ttyprint(html)

def undo_stage1(db, sess, ex):
  """Undo session creation: call db.unlock("stop_session"), then delete the
  session with ex as the reason.  Does nothing when sess is None."""
  if sess is None:
    return
  db.unlock("stop_session")
  sess.del_session(db, ex)

def undo_stage2(db, disp):
  """Mark the booked display ready again via disp.ready(db); no-op without a display."""
  if disp is None:
    return
  disp.ready(db)

def undo_stage3(db, fs, sess):
  """Ask the file server to erase the session directory (best effort).

  Does nothing when fs is None.  A failing remote call is only logged:
  hosts running an old maxwell_service may not know "erase_sessdir"."""
  if fs is None:
    return
  try:
    fs.remote("erase_sessdir", [fs.user, sess.sessname()])
  except StandardError:
    log("erase_sessdir call to maxwell_service on host '%s' failed.  Is it running an old maxwell_service script?"
        % fs.host.hostname)

def undo_stage4(db, sess, ex):
  """Intentionally does nothing.

  Rows in tables that depend on the session are also removed when the session
  itself is deleted (undo_stage1), so there is nothing separate to undo here."""

def create_session(user, ip, appname, timeout, version, appopts, params, zone, template, confbase):
  """Create a session.
  -Calculate host requirements from the requested app
  -If zone is specified, check that the requested zone supports that application
  -If zone is not specified, find one that supports the application
  -Is this a local or remote session?
  -If remote, create session database entry, call remote site, and return sessnum.
  -else, find a display and call maxwell_service.
  This is a top-level function called right after command parsing."""
  stage = 0 # undo levels

  if VERBOSE:
    log("create_session %s, %s, %s, %s, %s, %s" % (str(user),str(ip),str(appname),str(timeout),str(version), str(appopts)))
    log("create_session %s, %s, %s, %s" % (str(params), str(zone),str(template),str(confbase)))

  db = MW_DB(mysql_host, mysql_user, mysql_password, mysql_db, MYSQL_CONN_ATTEMPTS)

  app = App(db, appname, APP_MERGED, mysql_prefix)

  if ZONE_SUPPORT:
    if zone == "" or zone is None:
      # find a local zone
      # what if there is more than one?  Would need to check host provisions
      z = Zone.find(db, app)
      if z is None:
        raise PublicError("Error: no zone available for '%s'" % (appname))
      if z.zone_id == 0:
        raise PublicError("Error: zone_id is zero")
    else:
      z = Zone.get_zone_by_name(db, zone)
      if not z.supports(db, app):
        raise PublicError("Error: zone '%s' does not support '%s'" % (z.zone, appname))

    if not z.is_up(db):
      raise PublicError("Error: zone '%s' is not up." % z.zone)
  else:
    z = None

  # session cookie
  cookie = genpasswd(16, ALPHANUM)
  # no funny characters in viewtoken, it's used on the command line and in SQL queries
  viewtoken = genpasswd(32, ALPHANUM)

  stage = 1
  sess = None
  disp = None
  fs = None

  # wrap everything in a "try"
  try:
    if ZONE_SUPPORT:
      sess = Session.create_indb_zone(db, user, ip, app, session_suffix, z, SESSION_MERGED)
      if not z.is_local(db):
        remote_zone_create(db, z, user, ip, appname, sess, viewtoken)
        return 0
    else:
      sess = Session.create_indb(db, user, ip, app, timeout, session_suffix, SESSION_MERGED)
  except StandardError, ex:
    log("create_session: error at stage 1: %s" % (stage, ex))
    session_cleanup(stage, sess, disp, fs, ex)
    raise PublicError("Unable to setup session database entry.")

  stage = 2 # from now on, aborting requires delete from session
  try:
    # Find and book a suitable VNC display.
    disp = find_display(db, app, sess.sessnum, z) # closes db
  except StandardError, ex:
    session_cleanup(stage, sess, disp, fs, ex)
    raise PublicError("Unable to get a display.")
    
  stage = 3 # From now on, aborting creation requires freeing display
  try:
    db = MW_DB(mysql_host, mysql_user, mysql_password, mysql_db, MYSQL_CONN_ATTEMPTS)
    sess.set_exec(db, disp)
  except StandardError, ex:
    session_cleanup(stage, sess, disp, fs, ex)
    raise PublicError("Unable to assign display to session.")

  if user == "anonymous":
    try:
      # anonymous users have a volatile home directory
      # no file server for anonymous:  use the execution host, and call maxwell_service
      fspath = {"FS_PATH":HOST_MERGED["SERVICE_PATH"]}
      fs = Fileserver(db, user, sess, homedir_retries, overrides = fspath, host = disp.host)
      # pass the display number too
      stage = 33
      exit_code = fs.remote("anonymous", [sess.sessname(), '%s' % disp.dispnum, params])
    except StandardError, ex:
      undo_stage2(db, disp)
      undo_stage1(db, sess, ex)
      raise PublicError("Unable to create anonymous home directory")
    if exit_code != 0:
      session_cleanup(stage, sess, disp, fs, ex)
      raise PublicError("Unable to create anonymous home directory")
  else:
    try:
      fs = Fileserver(db, user, sess, homedir_retries, zone = z)
      # Create user's session directory and resource file.  Set up the user's home directory if necessary.
      fs.setup_account(db, params)
    except StandardError, ex:
      session_cleanup(stage, sess, disp, fs, ex)
      raise PublicError("Unable to setup session home.")
    stage += 1 # Fileserver will require cleanup
    
  try:
    log("Starting %d (%s) for %s on %s:%d" % (sess.sessnum, appname, user, disp.host.hostname, disp.dispnum))

    # Update the resources file.
    # check which of visualization_resource or visualization_params has been specified
    # we want to stop using "visualization_params"
    if visualization_resource == "":
      vr = visualization_params
    else:
      vr = visualization_resource
    c = CONTAINER_K
    c.update(CONTAINER_CONF)
    fs.update_resources(db, app, submit_app_list, version, vr, hub_name, hub_url, hub_template, disp.dispnum + c["FILEXFER_PORTS"], cookie)
    # no additional stage because the resources file was already created to get to stage 3

    if VERBOSE:
      log("Updated resources")
  except StandardError, ex:
    session_cleanup(stage, sess, disp, fs, ex)
    raise PublicError("Unable to create resource file.")


  try:    # vncpass argument is optional, and is grabbed from display if not specified
    sess.set_viewperm(db, viewtoken, cookie)
  except StandardError, ex:
    session_cleanup(stage, sess, disp, fs, ex)
    raise PublicError("Unable to create resource file.")

  stage += 1 # from now on, aborting requires deleting from viewperm
  try:
    # Session info and applet is displayed below...
    # CMS is reading the output and expecting the session number information
    # exactly in this format!  Do not change this line...
    ttyprint("Session is %d<br>" % sess.sessnum)

    if os.isatty(1):
      if VERBOSE:
        log("session is a tty, so didn't send applet HTML")
    else:
      sess.applet_html(False, hub_url, db)
      if VERBOSE:
        log("sent applet HTML")
  except StandardError, ex:
    session_cleanup(stage, sess, disp, fs, ex)
    raise PublicError("Unable to provide session information back to CMS.")
  
  # Get list of foreign resources that will need to be mounted but do not mount them yet
  # because we don't want to hold a database connection while doing network ops
  sess.get_foreign(User_account(user), db)

  try:
    db.db.close()
    if not background():
      return 0
  except ChildError, ex:
    log("create_session: Didn't become fully dissociated: %s" % ex)
    session_cleanup(stage, sess, disp, fs, ex)
    os._exit(1)
  except StandardError, ex:
    if VERBOSE:
      log_exc(ex)
    session_cleanup(stage, sess, disp, fs, ex)
    raise PublicError("create_session: error at stage %d: %s" % (stage, ex))

  # Execute the following as a dissociated child.
  stage += 1 # from now on, aborting requires stopping the display or setting it to the broken state
  
  # Mount foreign resources
  # this closes the database connection before performing the mount operations
  # abort after this point needs to unmount filesystems in container
  try:
    # Mount foreign resources in container
    sess.mount_foreign()
    # Start tool
    sess.invoke_command(appopts, params)
  except StandardError, ex:
    session_cleanup(stage+1, sess, disp, fs, ex)
    os._exit(1)

  if VERBOSE:
    log("create_session done: %d" % sess.sessnum)
  os._exit(0)

def session_cleanup(stage, sess, disp, fs, ex):
  try:
    # because our previous "db" connection may hold locks and we don't know which
    # close it and get a new one 
    db = MW_DB(mysql_host, mysql_user, mysql_password, mysql_db, MYSQL_CONN_ATTEMPTS)

    if disp is not None:
      # anonymous displays need to be stopped at all stages
      if stage > 30:
        # anonymous sessions have stage > 30
        disp.stop(db)
        stage -= 30
      elif stage > 5:
        # set display to broken state at stage 6, something went wrong.  Perhaps the app is broken and the container is fine,
        # but perhaps something is wrong with that container that prevents starting an app...
        disp.broken(db)
      else:
        # stage 3, 4 or 5, just free display
        disp.ready(db)

    if fs is not None and stage >= 3: # Fileserver cleanup
      try:
        fs.remote("erase_sessdir", [fs.user, sess.sessname()])
      except StandardError:
        # if we're dealing with a host using the old maxwell_service, keep going
        log("erase_sessdir call to maxwell_service on host '%s' failed.  Is it running an old maxwell_service script?" \
            % fs.host.hostname)

    if sess is not None: # stage 2 delete from session table
      # stage 5 delete dependent tables also done when session is deleted
      sess.del_session(db, ex)

    if sess is not None:
      log("Session %d cleanly aborted at stage %d" % (sess.sessnum, stage))
    else:
      log("Cleanly aborted create_session")

  except StandardError, boom:
    if VERBOSE:
      log_exc(boom)
    log("Error aborting session %d at stage %d (%s)" % (sess.sessnum, stage, boom))


def _check_orphans(db, select_sql, delete_sql, label, make_changes):
  """Count the rows returned by select_sql and, if make_changes, delete each.

  select_sql is run without parameters.  Each row it returns is passed as the
  parameter tuple of delete_sql, so delete_sql must have exactly one %s
  placeholder per selected column (MySQLdb uses %s for every type; the
  placeholder must not be quoted).
  """
  count = 0
  for row in db.getall(select_sql, ()):
    if make_changes:
      db.c.execute(delete_sql, tuple(row))
    count += 1
  user_print('Found %d bad rows in %s' % (count, label))


def check_db(make_changes=False):
  """Verify database integrity.
  verify that there aren't rows with session numbers that don't exist in session table
  displays for hosts that don't exist, etc...
  This is to investigate the usefulness of foreign key statements.

  make_changes -- True/False, or the command-line strings 'yes'/'no'.  When
  true, the offending rows are deleted instead of only being counted.
  """
  # The command line passes the literal strings 'yes' / 'no'.  Any non-empty
  # string is truthy, so 'no' must be converted explicitly or it would
  # enable deletions.
  if make_changes in ('yes', 'no'):
    make_changes = (make_changes == 'yes')

  db = MW_DB(mysql_host, mysql_user, mysql_password, mysql_db, MYSQL_CONN_ATTEMPTS)

  # Look for view permissions without sessions
  _check_orphans(db, """
    SELECT viewperm.sessnum
    FROM viewperm LEFT OUTER JOIN session ON viewperm.sessnum = session.sessnum
    WHERE session.sessnum is NULL""",
    "DELETE FROM viewperm WHERE sessnum = %s", "viewperm", make_changes)

  # Look for views without sessions
  _check_orphans(db, """
    SELECT view.sessnum
    FROM view LEFT OUTER JOIN session ON view.sessnum = session.sessnum
    WHERE session.sessnum is NULL""",
    "DELETE FROM view WHERE sessnum = %s", "view", make_changes)

  # Look for displays without hosts (identify each display by host + number,
  # not by sessnum, which is unrelated to the missing host)
  _check_orphans(db, """
    SELECT display.hostname, display.dispnum
    FROM display LEFT OUTER JOIN host ON display.hostname = host.hostname
    WHERE host.hostname is NULL""",
    "DELETE FROM display WHERE hostname = %s AND dispnum = %s", "display no hosts", make_changes)

  # Look for used displays without sessions (a dispnum alone is only unique per
  # host, so delete by host + number)
  _check_orphans(db, """
    SELECT display.hostname, display.dispnum
    FROM display LEFT OUTER JOIN session ON display.sessnum = session.sessnum
    WHERE display.status = 'used' and session.sessnum is NULL""",
    "DELETE FROM display WHERE hostname = %s AND dispnum = %s", "display no sessions", make_changes)

  # Look for sessions with empty usernames
  if make_changes:
    db.c.execute("""DELETE FROM session WHERE username = '' """)
    # DB-API rowcount: number of rows affected by the DELETE just executed
    nrows = db.c.rowcount
    if nrows > 0:
      user_print("cleaned up %d rows in session table with empty user names" % nrows)
  else:
    rows = db.getall("""SELECT count(*) FROM session WHERE username = ''""", ())
    user_print("there are %s sessions with empty usernames" % rows[0][0])

  # Look for fileperm without sessions
  _check_orphans(db, """
    SELECT fileperm.sessnum
    FROM fileperm LEFT OUTER JOIN session ON fileperm.sessnum = session.sessnum
    WHERE session.sessnum is NULL""",
    "DELETE FROM fileperm WHERE sessnum = %s", "fileperm no sessions", make_changes)

def _retry_session_stats(rows):
  """Re-run session_exit_notify() for the sessnum in column 0 of every row.

  A MaxwellError or OSError for one session is ignored so that the remaining
  sessions are still processed; 0.2 s is slept after each session.
  """
  for row in rows:
    sessnum = row[0]
    if VERBOSE:
      log("trying again to process stats for:" + str(sessnum))
    try:
      session_exit_notify(sessnum, '127.0.0.1', "renotify")
    except (MaxwellError, OSError):
      # e.g., No such file or directory: '/var/log/mw/sessions/5.out
      # getaddrinfo exec018: Name or service not known
      # SSH key for host 'exec018' unknown and can't be added to ~/.ssh/known_hosts
      pass
    time.sleep(0.2)


def renotify():
  """Process session logs again (<sessnum>.err)

  Two passes over sessionlog entries started within the last year, skipping
  exechost 'submit', empty exechosts and exechosts matching 'win%':
    1. sessions that have no joblog row;
    2. sessions whose recorded cputime is 0 (queried after pass 1 finishes).

  To do:  add check to see if execution host is up
  also, don't contact submit host!
  """
  db = MW_DB(mysql_host, mysql_user, mysql_password, mysql_db, MYSQL_CONN_ATTEMPTS)

  _retry_session_stats(db.getall("""
    SELECT s.sessnum FROM sessionlog s LEFT OUTER JOIN joblog j ON s.sessnum = j.sessnum
    WHERE j.sessnum IS NULL AND s.start > TIMESTAMPADD(year, -1, NOW())
    AND exechost != 'submit' AND exechost != '' AND exechost not like %s""", 'win%'))

  _retry_session_stats(db.getall("""SELECT sessnum FROM sessionlog WHERE cputime = 0
    AND start > TIMESTAMPADD(year, -1, NOW())
    AND exechost != 'submit' AND exechost != '' AND exechost not like %s""", 'win%'))

def shutdown():
  """To be called when a hub is shutting down, or after a reboot following a crash"""
  db = MW_DB(mysql_host, mysql_user, mysql_password, mysql_db, MYSQL_CONN_ATTEMPTS)
  rows = db.getall("""SELECT sessnum, exechost FROM session""", ())
  last_exechost = 0
  for row in rows:
    sessnum = row[0]
    exechost = row[1]
    if VERBOSE:
      log("shutting down session: %d on %s" % (sessnum, exechost))
    # it takes too long to wait on submit --local, do the background shutdown in parallel
    stop_session_background(sessnum, 'shutdown')
    if exechost != last_exechost:
      # don't sleep when shutting down sessions on different execution hosts
      last_exechost = exechost
      continue
    time.sleep(10)
  # wait for all displays to have stopped and log files to have been processed
  attempt = 1
  while attempt < 2 and len(db.getall("""SELECT dispnum FROM display where status = 'stopping'""", ())) > 0:
    log("waiting on all displays to have stopped")
    time.sleep(60)
    attempt += 1
  renotify()
  print "displays left in stopping state:"
  print db.getall("""SELECT dispnum FROM display where status = 'stopping'""", ())
  # Windows execution hosts have "persistent" displays
  db.c.execute("""DELETE FROM display where hostname not like %s""", ('win%'))
  db.c.execute("""DELETE FROM viewperm""", ())
  db.c.execute("""DELETE FROM view""", ())

def purge(hostname = None):
  """Shut down all idle containers and start the appropriate number of new ones.

  Handy when templates or session setups change and the changes must be pushed out.
  hostname -- passed to Display.purge(); an empty string (what the command
  line supplies by default) is normalized to None.
  """
  target_host = None if hostname == "" else hostname
  db = MW_DB(mysql_host, mysql_user, mysql_password, mysql_db, MYSQL_CONN_ATTEMPTS)
  Display.purge(db, target_host)

def bork():
  """Attempt to shut down broken containers and restart them (delegates to Display.bork)."""
  connection = MW_DB(mysql_host, mysql_user, mysql_password, mysql_db, MYSQL_CONN_ATTEMPTS)
  Display.bork(connection)

def verify():
  """Check that containers recorded as running really are (delegates to Display.verify)."""
  connection = MW_DB(mysql_host, mysql_user, mysql_password, mysql_db, MYSQL_CONN_ATTEMPTS)
  Display.verify(connection)

def screenshot(username):
  """Ask maxwell_service to take a screenshot of every session username can view.

  Sessions are found through viewperm, so shared sessions are included.  For a
  session owned by someone else, a second screenshot is requested under the
  owner's name.  Work happens only when background() returns true
  (presumably the detached child -- confirm against hubzero.mw.log).
  """
  if not background():
    return
  db = MW_DB(mysql_host, mysql_user, mysql_password, mysql_db, MYSQL_CONN_ATTEMPTS)
  rows = db.getall("""SELECT exechost, session.sessnum, dispnum, username FROM session JOIN viewperm USING (sessnum) WHERE viewuser = %s""", username)
  # close the database connection before trying to contact file server
  # to avoid consuming too many when things go wrong
  db.db.close()
  sessdir_format = '%d' + session_suffix
  for exechost, sessnum, dispnum, owner in rows:
    host = Host(exechost, {})
    sessdir = sessdir_format % sessnum
    display = '%d' % dispnum
    host.service("screenshot", [username, sessdir, display])
    if username != owner:
      host.service("screenshot", ['%s' % owner, sessdir, display])

def resize(sessnum, geom):
  """Resize a tool session.

  Looks up the display (hostname, dispnum) attached to sessnum and asks that
  host's maxwell_service to resize it to geom.

  Raises MaxwellError if no display row exists for the session.
  """
  db = MW_DB(mysql_host, mysql_user, mysql_password, mysql_db, MYSQL_CONN_ATTEMPTS)
  row = db.getrow("""SELECT hostname, dispnum FROM display WHERE sessnum = %s""", sessnum)
  # release the database connection before contacting the remote host
  db.db.close()
  if row is None:
    # without this check, row[0] raises an opaque TypeError
    raise MaxwellError("unable to get display for session %s" % sessnum)
  h = Host(row[0], {})
  h.service("resize", ['%s' % row[1], '%s' % geom])

def _facl_toolname(db, toolid):
  """Return the toolname for toolid, raising MaxwellError when none is found."""
  toolname = db.getsingle("""SELECT toolname FROM """+ mysql_prefix + """_tool WHERE id = %s""", toolid)
  if toolname is None:
    # otherwise the ACL would be applied to APPS_DIR % None
    raise MaxwellError("unable to find tool with id %s" % toolid)
  return toolname


def _apply_tool_facl(db, toolname, facl):
  """Apply facl to the tool's APPS_DIR directory.

  Uses APPS_SERVER over ssh when configured (a separate /apps fileserver,
  e.g. nanohub), otherwise the session fileserver's maxwell_service.
  """
  path = APPS_DIR % toolname
  if APPS_SERVER is not None:
    # if using a different fileserver for /apps and /home (e.g., nanohub)
    apps_server = Host(APPS_SERVER)
    apps_server.ssh([HOST_MERGED["FS_PATH"], 'setfacl', facl, path], None)
  else:
    fs = Fileserver(db, None, 0, 0)
    fs.remote('setfacl', [facl, path])


def setfacl(toolid, db=None):
  """construct tool facl based on development group and allowed groups in database.

  Others get no access; each group listed in <prefix>_tool_groups for the
  tool with role < 2 gets r-x.  Raises MaxwellError for an unknown toolid.
  """
  if db is None:
    db = MW_DB(mysql_host, mysql_user, mysql_password, mysql_db, MYSQL_CONN_ATTEMPTS)
  toolname = _facl_toolname(db, toolid)
  groups = db.getall("""SELECT cn FROM """+ mysql_prefix + """_tool_groups WHERE toolid = %s AND role < 2""", toolid)
  entries = ["user::rwx", "group::rwx", "other::---", "mask::rwx"]
  # each row is a one-column (cn,) tuple
  entries.extend("group:%s:r-x" % row[0] for row in groups)
  _apply_tool_facl(db, toolname, ",".join(entries))


def unsetfacl(toolid, db=None):
  """remove group restrictions for a tool.

  Restores r-x access for others.  Raises MaxwellError for an unknown toolid.
  """
  if db is None:
    db = MW_DB(mysql_host, mysql_user, mysql_password, mysql_db, MYSQL_CONN_ATTEMPTS)
  toolname = _facl_toolname(db, toolid)
  _apply_tool_facl(db, toolname, "user::rwx,group::rwx,other::r-x,mask::rwx")

def _reset_tool_facl(toolid, toolaccess, db):
  """Apply the facl policy for one tool: "@OPEN" unsets, "@GROUP" sets; other values are ignored."""
  if toolaccess == "@OPEN":
    unsetfacl(toolid, db)
  elif toolaccess == "@GROUP":
    setfacl(toolid, db)


def resetfacl():
  """go through all the tools and set or unset facls depending on "toolaccess" value

  Published, non-dev tool instances are processed first.  Tools that only have
  a dev instance ("..._dev") are processed afterwards; their toolids are
  skipped when already handled by the first pass.
  """
  db = MW_DB(mysql_host, mysql_user, mysql_password, mysql_db, MYSQL_CONN_ATTEMPTS)
  base = """SELECT toolid, toolaccess from """ + mysql_prefix + APP_MERGED["TOOL_TABLE"] + \
      """ WHERE revision IS NOT NULL AND revision !=0 AND (state= 1 OR state = 3) AND instance """
  # In LIKE, '_' matches any single character, so '%_dev' would also match
  # e.g. "xdev".  Escape it (with '|', independent of backslash SQL modes) so
  # only instances literally ending in "_dev" count as dev versions.
  dev_pattern = '%|_dev'

  # first process tools, excluding the dev versions of tools
  seen = set()
  for toolid, toolaccess in db.getall(base + "not like %s ESCAPE '|'", dev_pattern):
    seen.add(toolid)
    _reset_tool_facl(toolid, toolaccess, db)

  # process tools defined only with dev versions
  for toolid, toolaccess in db.getall(base + "like %s ESCAPE '|'", dev_pattern):
    if toolid not in seen:
      _reset_tool_facl(toolid, toolaccess, db)

def _grant_viewperms(db, sess, users):
  """Create a fresh viewtoken/cookie viewperm on sess for each user.

  Returns the "user,viewtoken,cookie" strings in the order of users.
  sess.user is left set to the last user processed.
  """
  granted = []
  for user in users:
    token = genpasswd(32, ALPHANUM)
    cookie = genpasswd(16, ALPHANUM)
    sess.user = user
    sess.set_viewperm(db, token, cookie)
    granted.append("%s,%s,%s" % (user, token, cookie))
  return granted


def share_session(sessnum, mode, users):
  """This command is only valid if using mirrors.  Tell the mirror to manage shared sessions
  for each user, generate a viewtoken and session cookie.

  mode is "unshare", "readonly", or anything else (read-write).  With
  ZONE_SUPPORT, a session whose zone is remote has the command relayed to the
  zone master via tell().  For sharing, the local display is used otherwise;
  MaxwellError is raised if it cannot be found.
  """
  db = MW_DB(mysql_host, mysql_user, mysql_password, mysql_db, MYSQL_CONN_ATTEMPTS)
  sess = Session(sessnum, session_suffix, SESSION_MERGED)

  if mode == "unshare":
    for user in users:
      sess.unshare(db, user)
    if ZONE_SUPPORT:
      zone = sess.get_zone(db)
      # pass command to zone master
      if zone.is_remote(db):
        zone.tell(db, ["unshare_session", sessnum] + users)
    return

  if mode == "readonly":
    sess.readonly = 'Yes'

  # if zone is local or there are no zones, find and pass display
  # otherwise pass vncpass
  if ZONE_SUPPORT:
    zone = sess.get_zone(db)
    if zone.is_remote(db):
      # grab vncpass from viewperms for session owner, then hand the
      # zone master one "user,viewtoken,cookie" entry per user
      sess.get_vncpass(db)
      zone.tell(db, ["share_session", sessnum, mode] + _grant_viewperms(db, sess, users))
      return

  sess.disp = Display.from_sessnum(db, sessnum)
  if sess.disp is None:
    raise MaxwellError("unable to get display for session %d" % sessnum)
  _grant_viewperms(db, sess, users)

#=============================================================================
#  Support for command and argument parsing and validation.
#=============================================================================

def process_input(names, defaults, validations):
  """ Obtain and validate inputs, from either a=b arguments or through position
      Overwrite defaults for provided arguments, raise exception if mandatory argument is missing

      names       -- ordered list of input names; also the positional order.
      defaults    -- dict of default values, updated in place with the inputs
                     read from sys.argv[2:].  A default of None marks the input
                     as mandatory.
      validations -- dict of regular expressions (applied with re.match) that
                     every supplied value must satisfy.

      A name=value argument is split on its FIRST '=' only, so values may
      themselves contain '=' (e.g. URL query strings).  An argument with no
      '=', or with several '=' and an unknown leading name, is positional.

      Raises InputError for too many arguments, an unrecognized name=value
      option, a value failing validation, or a missing mandatory input.
  """
  max_arg = len(names)
  # command is argument #2 at index 1; inputs start at index 2
  if len(sys.argv) > max_arg + 2:
    raise InputError("Too many arguments: %s" % " ".join(sys.argv))

  for count, arg in enumerate(sys.argv[2:]):
    key, sep, value = arg.partition('=')
    if sep and key in names:
      # argument is name=value
      k = key
      defaults[k] = value
    elif sep and '=' not in value:
      # exactly one '=' but the name is not one we accept
      raise InputError("unrecognized option: %s" % key)
    else:
      # simple positional argument
      k = names[count]
      defaults[k] = arg
    # Input validation
    if re.match(validations[k], defaults[k]) is None:
      raise InputError("Input validation failed for argument '%s'" % k)

  # check that we got all required arguments
  for k, v in defaults.items():
    if v is None:
      raise InputError("Missing required input '%s'" % k)

def my_call(my_func, names, values):
  """Invoke the module-level function named my_func, then exit with its result.

  values -- dict holding the arguments for the call
  names  -- the keys of values, in the positional order the function expects
  The call's wall-clock duration is logged.  This never returns normally:
  sys.exit() is called with the function's return value.
  """
  ordered_args = [values[name] for name in names]
  began = time.time()
  result = globals()[my_func](*ordered_args)
  finished = time.time()
  log("command run time: %f" % (finished - began))
  sys.exit(result) # children never return here, they call os._exit()


#=============================================================================
#=============================================================================
# Main program...
#
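# We recognize the following commands [with optional arguments]; arguments may
# be given positionally or as name=value:
#   start <user> <ip> <app> [<timeout> <version> <appopts> <params> <zone> <template> <confbase>]
#   stop <sessionnum> [<reason>]
#   view <sessionnum> <user> <ip> [<readonly>]
#
#   check <hostname> <confirm>
#   check_db  yes|no
#   notify session <sessionnum>
#   renotify (look for sessions that need to have stats processed)
#   shutdown (stop all sessions when the hub is shutting down or after a crash)
#   purge [<hostname>] (stop all ready displays because configuration changed)
#   bork (attempt to shut down broken containers and restart them)
#   verify (check that running containers really are running)
#   update_quota <user> <block_soft> <block_hard>
#   get_quota <user>
#   screenshot <user>
#   setfacl <toolid> | unsetfacl <toolid> | resetfacl
#   create_userhome <user> | erase_userhome <user>
#   move_userhome <userfrom> <userto>
#   resize <sessionnum> <geom>
#   restore_firewall
#   share_session <sessionnum> readonly|readwrite|unshare <user> [<user>...]
#   version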
#
#=============================================================================
#=============================================================================

#=============================================================================
# Configuration and Safety
#=============================================================================

# We always run here:
# (check_rundir is defined elsewhere in this file; it must run before
# anything else below.)
check_rundir()

# Load the configuration and override the default variables.
# First check that it is safe to do so: the configuration file is executed as
# Python code (execfile below), so anyone able to write it could run code as
# this user.
# os.lstat does not follow symlinks, so for a symlinked CONFIG_FILE the mode
# checked is the link's own mode.
try:
  mode = os.lstat(CONFIG_FILE)[stat.ST_MODE]
except OSError:
  print "The configuration file is not readable."
  sys.exit(1)

# Refuse a world-writable configuration file.
# NOTE(review): group-writability and file ownership are not checked here.
if mode & stat.S_IWOTH:
  print "configuration file is writable by others; exiting.\n"
  sys.exit(1)

# check that user is correct
# login is the name of the effective user; it is compared with MW_USER below,
# after the configuration has been loaded.
login =  pwd.getpwuid(os.geteuid())[0]

# Execute the configuration in this module's global namespace; it is expected
# to (re)define the mysql_* settings, MW_USER, SESSION_CONF, HOST_CONF,
# APP_CONF, ZONE_SUPPORT, LOG_ID, etc. used below.
try:
  execfile(CONFIG_FILE)
except IOError:
  print "Unable to read configuration file, exiting."
  print "The configuration file '%s' needs to exist and be readable by user '%s'" % (CONFIG_FILE, login)
  sys.exit(1)

# presumably the mysql_* names default to 'undefined' earlier in this file
# (not visible here); bail out if the configuration did not set them.
if mysql_host == 'undefined' or mysql_user == 'undefined' or mysql_password == 'undefined' or mysql_db == 'undefined':
  print "Database parameters undefined in configuration file %s, bailing out" % CONFIG_FILE
  sys.exit(1)

# check MW_USER set in configuration
if login != MW_USER:
  print "maxwell: access denied to %s. Must be run as %s (see %s)" % (login, MW_USER, CONFIG_FILE)
  sys.exit(1)

# Merge configured session settings over the built-in defaults.
# NOTE(review): this aliases and mutates the imported SESSION_K dict in place
# rather than copying it; other modules that read SESSION_K see the merged values.
SESSION_MERGED = SESSION_K
SESSION_MERGED.update(SESSION_CONF)

# we do this to avoid having to define ZONE_SUPPORT twice in the configuration:
SESSION_MERGED.update({"ZONE_SUPPORT": ZONE_SUPPORT})

# check that the session log directory (SESSION_MERGED["LOG_PATH"]) exists and
# is a safe directory
# NOTE(review): the messages below say "Session lock file" but the path checked
# is LOG_PATH.
try:
  lock_stat = os.lstat(SESSION_MERGED["LOG_PATH"])
except OSError:
  # does not exist
  print "directory %s does not exist" % SESSION_MERGED["LOG_PATH"]
  sys.exit(1)

# check that we are the owner of this dir and that others can't write
if lock_stat[stat.ST_MODE] & stat.S_IWOTH:
  print "Session lock file has incorrect permissions"
  sys.exit(1)

usr_id = lock_stat[stat.ST_UID]
if usr_id != os.geteuid():
  print "Session lock file has incorrect owner: %s" % usr_id
  sys.exit(1)

save_out() # this script needs ttyprint functionality
setup_log(MASTER_LOG_FILENAME, LOG_ID)

# Provide a HOME for child processes when the caller's environment lacks one.
try:
  os.environ["HOME"]
except KeyError:
  os.environ["HOME"]=WEB_HOMEDIR

# Merge configured host and application settings over the built-in defaults
# (like SESSION_MERGED, these alias and mutate HOST_K / APP_K in place).
HOST_MERGED = HOST_K
HOST_MERGED.update(HOST_CONF)
APP_MERGED = APP_K
APP_MERGED.update(APP_CONF)

#=============================================================================
# Input parsing
#=============================================================================

try:
  # Verify that we have at least a command
  if len(sys.argv) < 2:
    raise InputError("No command given")

  if VERBOSE:
    log("received command '%s'" % " ".join(sys.argv))

  if sys.argv[1] == "start":
    order = ['user', 'ip', 'app', 'timeout', 'version', 'appopts', 'params', 'zone', 'template', 'confbase']
    defaults = {
      'user':None,
      'ip':None,
      'app':None,
      # timeout is set to the value specified by the app no matter what the input is.
      'timeout':0,
      'version':default_version,
      'appopts':"",
      'params':"",
      'zone':"",
      'template':"",
      'confbase':""
    }
    validations = {
      'user':USER_REGEXP,
      'ip':IP_REGEXP,
      'app':PATH_REGEXP,
      'timeout':INT_REGEXP,
      'version':NAME_REGEXP,
      'appopts':QUOTED_REGEXP,
      'params':URL_REGEXP,
      'zone':NAME_REGEXP,
      'template':NAME_REGEXP,
      'confbase':NAME_REGEXP
    }
    process_input(order, defaults, validations)
    if defaults['version'] == 'default':
      defaults['version'] = default_version
    defaults['timeout'] = int(defaults['timeout'])

    my_call("create_session", order, defaults)

  elif sys.argv[1] == "stop":
    order = ['sessnum', 'reason']
    defaults = {
      'sessnum':None,
      'reason':'user'
    }
    validations = {
      'sessnum':INT_REGEXP,
      'reason' :NAME_REGEXP
    }
    process_input(order, defaults, validations)
    defaults['sessnum'] = int(defaults['sessnum'])
    my_call("stop_session_background", order, defaults)

  elif sys.argv[1] == "view":
    #   view <session_number> <user> <ip> [<readonly>]
    order = ['sess', 'user', 'ip', 'readonly']
    defaults = {
      'sess':None,
      'user':None,
      'ip':None,
      'readonly':0
    }
    validations = {
      'sess': INT_REGEXP, # expecting number
      'user': USER_REGEXP,
      'ip'  : IP_REGEXP,
      'readonly':INT_REGEXP
    }
    process_input(order, defaults, validations)

    defaults['sess'] = int(defaults['sess'])
    my_call("view_applet", order, defaults)

  elif sys.argv[1] == "check":
    order = ['hostname', 'confirm']
    defaults = {
      'hostname':None,
      'confirm': None
    }
    validations = {
      'hostname':PATH_REGEXP,
      'confirm':'\Ayes|no\Z'
    }
    process_input(order, defaults, validations)
    my_call("check_host", order, defaults)

  elif sys.argv[1] == "check_db":
    order = ['make_changes']
    defaults = {
      'make_changes':False
    }
    validations = {
      'make_changes':'\Ayes|no\Z'
    }
    process_input(order, defaults, validations)
    my_call("check_db", order, defaults)

  elif sys.argv[1] == "renotify":
    order = []
    defaults = {
    }
    validations = {
    }
    process_input(order, defaults, validations)
    my_call("renotify", order, defaults)

  elif sys.argv[1] == "shutdown":
    order = []
    defaults = {
    }
    validations = {
    }
    process_input(order, defaults, validations)
    my_call("shutdown", order, defaults)

  elif sys.argv[1] == "purge":
    order = ['hostname']
    defaults = {
      'hostname':"",
    }
    validations = {
      'hostname':PATH_REGEXP,
    }
    process_input(order, defaults, validations)
    my_call("purge", order, defaults)

  elif sys.argv[1] == "bork":
    order = []
    defaults = {
    }
    validations = {
    }
    process_input(order, defaults, validations)
    my_call("bork", order, defaults)

  elif sys.argv[1] == "verify":
    order = []
    defaults = {
    }
    validations = {
    }
    process_input(order, defaults, validations)
    my_call("verify", order, defaults)

  elif sys.argv[1] == "notify":
    """ The notify command is called from exechosts, who have been given an ssh key
      which we refer to as a notify key.  The public key is stored here in the authorized keys
      ssh file.  The private key is given (scped) to the execution host when executing the command
      check <hostname> <confirm>
      Because of this, the reverse DNS check ought to be safe (assuming our DNS servers are OK)

      Also, there is a forced command attached that specifies the "notify" command.  So all uses of
      this key should end up here.

    """
    fromhost = ""
    try:
      conn = os.environ["SSH_CONNECTION"]
      fromhost = conn.split()[0]
    except (KeyError, AttributeError, TypeError):
      pass

    try:
      # Need to read from SSH_ORIGINAL_COMMAND instead of argv, due to the forced command
      ssh_cmd = os.environ["SSH_ORIGINAL_COMMAND"].split()
      if len(ssh_cmd) != 3:
        raise InputError("Incorrect number of arguments, expecting 'notify session <sessionnum>'")
      if ssh_cmd[0] != "notify" or ssh_cmd[1] != "session":
        raise InputError("Bad notify command sent from %s" % (fromhost))
      else:
        sessname = ssh_cmd[2]
        sessnumber = int(sessname)
    except TypeError:
      raise InputError("Syntax error")
    except KeyError:
      # no key "SSH_ORIGINAL_COMMAND";  attempt to use sys.argv
      # this must be a manual invocation for debugging
      if len(sys.argv) != 4:
        raise InputError("Incorrect number of arguments")
      if sys.argv[2] != "session":
        raise InputError("Bad notify command")
      sessnumber = int(sys.argv[3])
      sys.exit(session_exit_notify(sessnumber, fromhost))
    log("Notified %d from %s (%s)" % (sessnumber, fromhost, dns_reverse(fromhost)[0:99]))
    sys.exit(session_exit_notify(sessnumber, fromhost))

  elif sys.argv[1] == "update_quota":
    order = ['user', 'block_soft', 'block_hard']
    defaults = {
      'user':None,
      'block_soft':None,
      'block_hard':None
    }
    validations = {
      # accept comma-separated list of users
      'user': r'\A[0-9a-zA-Z]+[_0-9a-zA-Z\.\,]*\Z',
      'block_soft': INT_REGEXP,
      'block_hard':INT_REGEXP
    }
    process_input(order, defaults, validations)
    my_call("update_quota", order, defaults)

  elif sys.argv[1] == "get_quota":
    order = ['user']
    defaults = {
      'user':None
    }
    validations = {
      'user': USER_REGEXP
    }
    process_input(order, defaults, validations)
    my_call("get_quota", order, defaults)

  elif sys.argv[1] == "screenshot":
    order = ['user']
    defaults = {
      'user':None
    }
    validations = {
      'user': USER_REGEXP
    }
    process_input(order, defaults, validations)
    my_call("screenshot", order, defaults)

  elif sys.argv[1] == "setfacl":
    order = ['toolid']
    defaults = {
      'toolid':None
    }
    validations = {
      'toolid': INT_REGEXP
    }
    process_input(order, defaults, validations)
    my_call("setfacl", order, defaults)

  elif sys.argv[1] == "unsetfacl":
    order = ['toolid']
    defaults = {
      'toolid':None
    }
    validations = {
      'toolid': INT_REGEXP
    }
    process_input(order, defaults, validations)
    my_call("unsetfacl", order, defaults)

  elif sys.argv[1] == "resetfacl":
    my_call("resetfacl", [], {})

  elif sys.argv[1] == "create_userhome":
    order = ['user']
    defaults = {
      'user':None
    }
    validations = {
      'user': USER_REGEXP
    }
    process_input(order, defaults, validations)
    my_call("create_userhome", order, defaults)

  elif sys.argv[1] == "erase_userhome":
    order = ['user']
    defaults = {
      'user':None
    }
    validations = {
      'user': USER_REGEXP
    }
    process_input(order, defaults, validations)
    my_call("erase_userhome", order, defaults)

  elif sys.argv[1] == "resize":    
    order = ['sessnum', 'geom']
    defaults = {
      'sessnum':None,
      'geom':None
    }
    validations = {
      'sessnum':INT_REGEXP,
      'geom' :GEOM_REGEXP
    }
    process_input(order, defaults, validations)
    my_call("resize", order, defaults)

  elif sys.argv[1] == "move_userhome":
    order = ['userfrom', 'userto']
    defaults = {
      'userfrom':None,
      'userto':None
    }
    validations = {
      'userfrom': USER_REGEXP,
      'userto': USER_REGEXP
    }
    process_input(order, defaults, validations)
    my_call("move_userhome", order, defaults)

  elif sys.argv[1] == "restore_firewall":
    order = []
    defaults = {
    }
    validations = {
    }
    process_input(order, defaults, validations)
    my_call("restore_firewall", order, defaults)

  elif sys.argv[1] == "share_session":
    # expected share_session sessnum mode user1 user2 user3
    if len(sys.argv) < 4:
      raise InputError("Incomplete command: '%s'" % " ".join(sys.argv))
 
    # session number 
    sessnum = int(sys.argv[2])

    # mode
    if sys.argv[3] != "readonly" and sys.argv[3] != "readwrite" and sys.argv[3] != "unshare":
      raise InputError("Bad mode '%s', needs to be one of readonly, readwrite or unshare" % sys.argv[i])
    mode = sys.argv[3]

    # unlimited users into an array
    prog = re.compile(USER_REGEXP)
    try:
      users = map(lambda x:prog.match(x).group(0), sys.argv[4: len(sys.argv)])
    except AttributeError:
      raise InputError("Invalid command: '%s'" % " ".join(sys.argv))

    share_session(sessnum, mode, users)

  elif sys.argv[1] == "version":
    ttyprint("October 2016")

  else:
    m = re.match(LOGSAFE_REGEXP, sys.argv[1])
    if m is None:
      raise InputError("Unknown command that is unsafe to log")
    # limit size of log entry to 99
    raise InputError("Unknown command: '%s'" % m.group()[0:99])

# attempted conversion of alpha chars to int results in ValueError
except ValueError, e:
  print_n_log("Integer input expected for this command: '%s'" % sys.argv[1])
  if VERBOSE:
    log_exc(e)
  sys.exit(1)

except InputError, e:
  print_n_log("%s" % e)
  if VERBOSE:
    log_exc(e)
  sys.exit(1)

except PublicError, e:
  # display everything to user
  print_n_log("%s" % e)
  if VERBOSE:
    log_exc(e)
  sys.exit(3)

except MaxwellError, e:
  # prevent raw error from being displayed to user
  # see CWE-209 	"Information Exposure Through an Error Message"
  log("%s" % e)
  if VERBOSE:
    log_exc(e)
  user_print("A serious internal error occurred. Exiting.")
  sys.exit(2)

except MySQLdb.MySQLError, e:
  log("MySQL Error: %s" % e)
  if VERBOSE:
    log_exc(e)
  user_print("Database error logged, exiting.")
  sys.exit(4)

except Exception, e:
  # prevent raw error from being displayed to user
  # see CWE-209 "Information Exposure Through an Error Message"

  if VERBOSE:
    log_exc(e)
  log("Fatal error %s" % e)
  user_print("Fatal error logged, exiting.")
  sys.exit(5)

