#!/usr/bin/python
#
# @package      hubzero-mw-service
# @file         maxwell_service
# @author       Rick Kennell <kennell@purdue.edu>
# @author       Nicholas J. Kisseberth <nkissebe@purdue.edu>
# @copyright    Copyright (c) 2005-2015 HUBzero Foundation, LLC.
# @license      http://opensource.org/licenses/MIT MIT
#
# Copyright (c) 2005-2015 HUBzero Foundation, LLC.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
# HUBzero is a registered trademark of HUBzero Foundation, LLC.
#

import os
import sys
import signal
import stat
import time
import socket
import subprocess
import string
import tempfile
import pwd
import grp
import shutil
import re

#=============================================================================
# Set up default parameter values.
#=============================================================================
allow_fuse=0          # non-zero: append the session user to the container's "fuse" group
# Account used on the notify host(s) for the ssh "notify session" call.
# NOTE(review): only set on Debian- or Red Hat-family hosts; anywhere else it
# must come from mw-service.conf, or the first notification raises NameError.
if os.path.exists("/etc/debian_version"):
  notify_account="www-data"
elif os.path.exists("/etc/redhat-release"):
  notify_account="apache"
notify_retries=1440   # rounds over all notify hosts before giving up
notify_timeout=60     # seconds (SIGALRM) allowed for each ssh notification
notify_hosts='localhost'  # comma-separated host list; split into a list below
viewer_check_period=10    # seconds between checks for a connected viewer
logfile="/var/log/mw-service/service.log"   # stdout/stderr target used by openlog()
# Targets of the per-container /etc/vz/conf/<veid>.{conf,mount,umount} links.
ovz_session_conf="hub-session-7.0-amd64.conf"
ovz_session_mount="hub-session-7.0-amd64.mount"
ovz_session_umount="hub-session-7.0-amd64.umount"
xauth_retries=3       # retries of /usr/bin/mergeauth inside the container
vzoffset = 0          # container id (veid) = display number + vzoffset
machine_number=0 # Every OpenVZ host should have a unique number.
# (machine_number is the second octet of container addresses 10.<n>.x.y.)
portbase=5000         # startvnc() runs a socat forwarder on portbase+display
network_group="network"   # members get iptables ACCEPT rules for their container
submit_server=""      # "host[:port]" written to submit-client.conf; port defaults to 830
mount_projects=False  # bind-mount project_path+<name> for each "pr-<name>" group
project_path=None

#=============================================================================
# Load the configuration and override the parameters above.
#=============================================================================
# execfile() runs the file as Python in this module's namespace, so it may
# reassign any of the names above.  A missing/unreadable file is ignored.
try:
  execfile('/etc/mw-service/mw-service.conf')
except IOError:
  pass

# Unconfigured machine_number: fall back to the last octet of this host's
# IPv4 address.
if (machine_number == 0):
  try:
    machine_number = int(socket.gethostbyname(socket.getfqdn()).split('.')[3])
  except:
    print "Unable to get IP address.  Setting machine_number to 0."
    machine_number = 0

# NOTE(review): an empty notify_hosts is left as the string '' (not []).
if notify_hosts != '':
  notify_hosts = notify_hosts.split(',')

#=============================================================================
# Set up errors to go to the log file.
#=============================================================================
def openlog():
  """Point this process's stdout and stderr at the service log file.

  Nothing happens when stdout is a terminal, so interactive runs still print
  to the console.  Any failure is ignored on purpose: setting up logging must
  never stop the service from running.
  """
  if os.isatty(1):
    return
  try:
    logfh = open(logfile, "a+")
    for stream in (sys.stdout, sys.stderr):
      os.dup2(logfh.fileno(), stream.fileno())
  except:
    pass

#=============================================================================
# Log a message.
#=============================================================================
def log(msg):
  """Write `msg` followed by a newline and flush right away.

  Uses stdout when stdout is a terminal and stderr otherwise (stderr is what
  openlog() redirects into the log file).
  """
  stream = sys.stdout if os.isatty(1) else sys.stderr
  stream.write(msg + "\n")
  stream.flush()

#=============================================================================
# Convert a list to a string.
#=============================================================================
def l2s(l):
  """Return the str() of every item in `l`, each preceded by one space.

  For example [1, 'a'] gives ' 1 a'; an empty sequence gives ''.
  """
  return "".join([" " + str(item) for item in l])

#=============================================================================
# Tell the caller that a session is finished.
#=============================================================================
def notify_command_finished(sessnum,host):
  """Tell the controller on `host` that session `sessnum` has finished.

  Runs ``ssh -i /etc/mw-service/notify.key <notify_account>@<host> notify
  session <sessnum>`` in a child attached to a pseudo-terminal, so that an
  "authenticity of host" prompt from ssh can be answered with "yes".
  Reading the child's output and waiting for it are bounded together by a
  SIGALRM of `notify_timeout` seconds.

  Returns True when the ssh child exits with status 0 and False otherwise,
  including on timeout (the child is then killed and reaped), so the
  caller's retry loop can try again.  Exits the process if ssh prints
  anything other than the host-authenticity prompt.
  """
  account = notify_account + "@" + host

  (pid,fd) = os.forkpty()
  if pid < 0:
    log("Could not fork process: %d" % pid)
    sys.exit(1)
  elif pid == 0:
    # Child: runs ssh with the pty as its terminal.
    status = os.system("""ssh -i /etc/mw-service/notify.key %s notify session %s""" % (account, sessnum))
    if status != 0:
      log("Unable to exec ssh")
      sys.exit(1)
    sys.exit(0)

  log("Notifying %s about session %s" % (account,sessnum))

  class NotifyTimeout(Exception):
    """Raised from the SIGALRM handler when notify_timeout expires."""

  def alarmHandler(sig,frame):
    raise NotifyTimeout()

  old_handler = signal.signal(signal.SIGALRM, alarmHandler)
  signal.alarm(notify_timeout)
  try:
    try:
      try:
        line = os.read(fd,1000)
      except OSError:
        # Process exited already.
        line = ''

      if line.startswith("The authenticity of host"):
        try:
          os.write(fd,"yes\n")
          os.read(fd,1000)
        except OSError:
          # ssh may already be gone; waitpid() below reports the outcome.
          pass
      elif len(line) != 0:
        log("notify_command_finished: " + line)
        sys.exit(1)

      try:
        (_, status) = os.waitpid(pid, 0)
      except OSError:
        return False
      return status == 0
    except NotifyTimeout:
      # os.kill() takes (pid, signal).  Reap the child so it does not linger
      # as a zombie while the caller keeps retrying.
      try:
        os.kill(pid, signal.SIGKILL)
      except OSError:
        pass
      try:
        os.waitpid(pid, 0)
      except OSError:
        pass
      log("Unable to notify %s about %s" % (account,sessnum))
      return False
  finally:
    # Never leave the alarm armed or our handler installed for the caller,
    # and release the pty (this runs once per retry).
    signal.alarm(0)
    if old_handler is not None:
      signal.signal(signal.SIGALRM, old_handler)
    os.close(fd)

#=============================================================================
# Check whether a port is in use.
#=============================================================================
def check_port(port):
  """Return 1 if an ESTABLISHED TCP connection has local port `port`, else 0.

  sed pulls the fourth column (the local address) out of ``netstat -tn``
  and grep matches ':<port>' at the end of it.
  """
  pipeline = "netstat -tn | grep ESTABLISHED | sed 's,^\([^ ]*[ ]*[^ ]*[ ]*[^ ]*[ ]*\)\([^ ]*\).*$,\\2,' | grep -q ':%d$'" % port
  found = os.system(pipeline) == 0
  return 1 if found else 0

#=============================================================================
# Move the session directory to a new name to indicate it's expired.
#=============================================================================
def expire_session_dir(user,id):
  """Rename ~user/data/sessions/<id> to ~user/data/sessions/<id>-expired.

  The ``mv`` runs as `user` (via ``su``) in a forked child with a minimal
  environment.  The parent returns without waiting for it.  The account
  lookup is attempted up to 99 times; if it never succeeds, nothing is
  renamed.
  """
  for _ in range(1,100):
    try:
      homedir = pwd.getpwnam(user)[5]
    except KeyError:
      log("Expire_session_dir: can't find account information for '%s'.\n" % user)
      continue

    olddir = homedir + "/data/sessions/%s" % id
    newdir = homedir + "/data/sessions/%s-expired" % id

    delcmd = """su %s -c 'mv %s %s'""" % (user,olddir,newdir)
    log("Executing %s\n" % delcmd)

    pid = os.fork()
    if pid < 0:
      log("expire_session_dir: unable to fork\n")
      sys.exit(1)
    elif pid == 0:
      # Child:
      # Note: make sure that this environment reset occurs only within
      # the child process.  The parent still needs to know a few environment
      # variables and we're effectively clearing them all here.
      os.environ={}
      env={
        "HOME":homedir,
        "LOGNAME":user,
        "PATH":"/bin:/usr/bin:/usr/bin/X11:/sbin:/usr/sbin",
        "USER":user
      }
      try:
        os.execve("/bin/sh", ["/bin/sh",'-c',delcmd], env)
      except OSError:
        pass
      # execve() only comes back by raising.  The child must never unwind
      # into the caller's code, so leave immediately without cleanup.
      log("Unable to execve.\n")
      os._exit(1)
    # Parent: the rename completes asynchronously; it is not waited for.
    return

#=============================================================================
# Get the list of hosts to try to notify when the command exits.
#=============================================================================
def get_notify_hosts():
  """Return the list of hosts to notify when a session's command exits.

  Uses the configured `notify_hosts` when it is non-empty.  Otherwise it
  falls back to the first field of $SSH_CONNECTION (the ssh client's
  address).  Exits with status 1 if that variable is missing or empty.
  """
  # Truthiness rather than "!= []": an empty configuration value is left as
  # the string '' by the module-level parsing, and that must fall back too.
  if notify_hosts:
    return notify_hosts

  try:
    conn=os.environ["SSH_CONNECTION"]
    host=conn.split()[0]
    return [ host ]
  except (KeyError, IndexError):
    log("notify_command_finished: can't find SSH_CONNECTION")
    sys.exit(1)

#=============================================================================
# Start an X application
#=============================================================================
def startxapp(user,id,timeout,command,disp):
    """Run `command` for session `id` on display `disp` as `user`.

    An optional "stream " or "notify " prefix on `command` selects the
    protocol and is stripped off; with no prefix, "stream" is assumed.

    stream: run the command via invoke_unix_command() and sys.exit() with
    its return value.

    notify: fork (the original process calls sys.exit(0)), start a new
    session, redirect stdio to /dev/null and
    /var/log/mw-service/open-sessions/<sessnum>.out / .err, fork again,
    run the command, point stdout/stderr at `logfile`, rename the session
    directory with expire_session_dir(), then try each host from
    get_notify_hosts() until one notification succeeds, for up to
    notify_retries rounds.
    """
    #
    # Decide whether this should be 'notify' protocol or 'stream' protocol.
    # If not explicitly specified, assume it's the original 'stream' version.
    stream=True
    notify=False
    if command[0:7] == "stream ":
      command = command[7:]
      stream=True
    elif command[0:7] == "notify ":
      command = command[7:]
      notify=True
      stream=False

    # For stream protocol, wait for the command to exit or timeout.
    if stream:
      status = invoke_unix_command(user,id,timeout,command,disp)
      sys.exit(status)

    # Everything below is for the notify protocol.

    # Get the session number by removing the optional suffix
    # (one trailing non-digit character, e.g. "1234a" -> "1234").
    # NOTE(review): an empty `id` raises IndexError here.
    sessnum='x'
    if id[-1] >= '0' and id[-1] <= '9':
      sessnum=id
    else:
      sessnum=id[0:-1]

    # First fork(): the original process exits here; the child continues.
    try:
      pid = os.fork()
      if pid > 0:
        sys.exit(0)
    except OSError, e:
      log("fork #1 failed: %d (%s)" % (e.errno, e.strerror))
      sys.exit(1)

    # This is the child.  Dissociate from the parent.
    # NOTE(review): .out/.err are opened without O_TRUNC or O_APPEND, so an
    # existing file is overwritten from offset 0 but not truncated.
    os.setsid()
    input=os.open("/dev/null", os.O_RDONLY)
    output=os.open("""/var/log/mw-service/open-sessions/%s.out""" % sessnum, os.O_WRONLY|os.O_CREAT)
    error=os.open("""/var/log/mw-service/open-sessions/%s.err""" % sessnum, os.O_WRONLY|os.O_CREAT)
    os.dup2(input,0)
    os.dup2(output,1)
    os.dup2(error,2)
    # Close every other inherited descriptor (3..1023).
    for i in range(3,1024):
      try:
        os.close(i)
      except OSError:
        pass

    # Second fork()
    try:
      pid = os.fork()
      if pid > 0:
        sys.exit(0)
    except OSError, e:
      log("fork #2 failed: %d (%s)" % (e.errno, e.strerror))
      sys.exit(1)

    # At this point, we're the child of a child and dissociated from
    # the original session.  Invoke the command, wait for it to exit,
    # and then notify the remote caller.
    invoke_unix_command(user,id,timeout,command,disp)
    # From here on, stdout and stderr both point at the service log file.
    output=os.open(logfile, os.O_WRONLY|os.O_CREAT|os.O_APPEND)
    os.dup2(output,1)
    os.dup2(output,2)
    expire_session_dir(user,id)
    # Retry rounds: one attempt per host, 1 s pause after each failure.
    for i in range(0,notify_retries):
      for host in get_notify_hosts():
        status = notify_command_finished(sessnum,host)
        if status:
          return
        time.sleep(1)
    log("Unable to notify controller about session %s" % sessnum)

#=============================================================================
# Print statistics for an OpenVZ VPS.
#=============================================================================
def printvzstats(veid):
  """Report CPU usage of OpenVZ container `veid` from /proc/vz/vestat.

  Every line is echoed (as a list of fields) to stdout.  For each line with
  at least five fields a "Checking ..." note goes to stderr.  On the line
  whose first field equals `veid`, writes "user <f1/1000>" and
  "sys <f3/1000>" to stderr, then stops scanning.  Lines that do not parse
  are skipped.
  """
  vestat = open("/proc/vz/vestat")
  for row in iter(vestat.readline, ""):
    fields = row.split()
    print(fields)
    if len(fields) >= 5:
      sys.stderr.write("Checking /proc/vz/vestat veid = %s\n" % fields[0])
      try:
        if int(fields[0]) == veid:
          # The elapsed ("real") time is left for the middleware host to
          # compute: containers are pre-created, so uptime is not runtime.
          sys.stderr.write("user %f\n" % (int(fields[1])/1000.0))
          sys.stderr.write("sys %f\n" % (int(fields[3])/1000.0))
          sys.stderr.flush()
          break
      except:
        pass
  vestat.close()

def root_mount(vz_root_path, point, perm):
  """Mount a directory to make it available to OpenVZ containers.  By mounting under root,
  containers can't modify the original.

  Bind-mounts host directory `point` at vz_root_path + point, with `perm`
  passed to ``mount -o`` (e.g. 'rw,noatime').  An existing mount-point
  directory is removed (it must be empty) and recreated first.  ``-n``
  keeps the mount out of /etc/mtab.

  Raises OSError if the mount point cannot be prepared or if /bin/mount
  fails.  After a failed mount the mount point is removed again.
  """
  mntpt = vz_root_path + point
  if os.path.isdir(mntpt):
    try:
      os.rmdir(mntpt)
    except OSError as exc:
      log("exception:'%s'\n" % exc)
      raise
  os.mkdir(mntpt)
  log("Created %s" % mntpt)

  # -n: Mount without writing in /etc/mtab.
  args = ["/bin/mount", "-n", "--bind", point, '-o', perm, mntpt]
  p = subprocess.Popen(args, stderr=subprocess.PIPE, stdout=subprocess.PIPE)
  out, err = p.communicate()
  if p.returncode != 0:
    try:
      os.rmdir(mntpt)
    except OSError:
      pass
    # There is no active exception to re-raise here, so report the failure
    # explicitly (callers log str(e)).
    raise OSError("mount of %s on %s failed (status %d): %s"
                  % (point, mntpt, p.returncode, err.strip()))

#=============================================================================
# Start the command running and watch the view time.
#=============================================================================
def invoke_unix_command(user,id,timeout,command,disp):
  """Run `command` as `user` in the OpenVZ container for display `disp`.

  The work is split across a fork:

  * The child prepares the container (passwd/shadow/group entries, an
    'apps' account and sudoers rule for members of group 'apps', optional
    fuse membership and project mounts, the submit-client listenURIs line,
    iptables ACCEPT rules for members of `network_group`, X authority via
    /usr/bin/mergeauth).  It then runs the command with
    ``time -p vzctl exec2 <veid> "su <user> ..."``, prints container CPU
    stats, stops the container, and exits with the command's exit code.
  * The parent stores the child's pid in /var/run/mw-service/pid.<disp>
    and waits for it.  When `timeout` is positive and `disp` is non-zero,
    it wakes every `viewer_check_period` seconds and uses check_port() to
    see whether a viewer is connected.  Connected time is accumulated, and
    the session is stopped once it has been idle for more than `timeout`
    seconds.  In this mode "view <seconds>" is written to stderr.
    "Exit_Status: <wait status>" is written to stderr after the child is
    reaped.

  Returns 1 if waitpid() fails while watching for viewers, otherwise 0.
  """
  log("Starting command '%s' for '%s' with timeout '%s'" % (command,user,timeout))
  port = 0
  if disp != 0:
    # Same port that startvnc() gives this display's socat forwarder.
    port = portbase + disp
    log("Checking port %d" % port)

  tick = viewer_check_period

  pid=os.fork()
  if pid < 0:
    print("Unable to fork job for '%s'" % command)
    sys.exit(1)

  if pid == 0:
    # ---- Child: set up the container and run the command. ----
    try:
      info = pwd.getpwnam(user)
      uid = info[2]
      gid = info[3]
      homedir = info[5]
      shell = info[6]
    except KeyError:
      print("Unable to find account information for '%s'." % user)
      sys.exit(1)

    sessdir = homedir + "/data/sessions/%s" % id
    os.environ={}
    env={
      "DISPLAY":":%d" % disp,
      "HOME":homedir,
      "LOGNAME":user,
      "PATH":"/bin:/usr/bin:/usr/bin/X11:/sbin:/usr/sbin",
      "SESSION":id,
      "SESSIONDIR":sessdir,
      "TIMEOUT":str(timeout),
      "USER":user
    }

    if os.path.isfile(sessdir + "/parameters.hz"):
      env["TOOL_PARAMETERS"] = sessdir + "/parameters.hz"

    veid = disp + vzoffset
    veaddr="10.%d.%d.%d" % (machine_number, veid/100, veid % 100)

    # env["SHELL"]=shell

    os.system("""echo "%s:!:%d:%d::%s:%s" >> /var/lib/vz/root/%d/etc/passwd"""
                    % (user,uid,gid,homedir,shell,veid))
    os.system("""echo "%s:!:13281:0:99999:7:::" >> /var/lib/vz/root/%d/etc/shadow"""
                    % (user,veid))

    # Get the user's supplementary group names, comma-separated.
    f=os.popen("id %s |sed -e 's/^.*groups=//' -e 's/[0-9]*[(]\([^)]*\)[)]/\\1/g'" % user)
    groups=f.read()
    groups=groups.strip()
    f.close()

    log("ADDING THE FOLLOWING GROUPS: %s" % groups)
    # Build it in to the /etc/group file in the VEID...
    allow_network=0
    for g in groups.split(','):
      try:
        gname,gpw,ggid,mem=grp.getgrnam(g)
        os.system("""echo "%s:x:%d:%s" >> /var/lib/vz/root/%d/etc/group""" % (gname,ggid,user,veid))
        log("ADDED GROUP: %s" % gname)
        if g == network_group:
          allow_network = 1
        if g == 'apps':
          log("ADD 'apps' USER AND 'apps' GROUP SUDOERS RULE")
          try:
            appsinfo = pwd.getpwnam('apps')
          except KeyError:
            log("Unable to find account information for 'apps'.")
            sys.exit(1)
          try:
            os.system("""echo "%%apps\t\tALL=NOPASSWD:/bin/su - apps" >> /var/lib/vz/root/%d/etc/sudoers""" % (veid))
            os.system("""echo "%s:!:%d:%d::%s:%s" >> /var/lib/vz/root/%d/etc/passwd""" % ('apps',appsinfo[2],appsinfo[3],appsinfo[5],appsinfo[6],veid))
            os.system("""echo "%s:!:13281:0:99999:7:::" >> /var/lib/vz/root/%d/etc/shadow""" % ('apps',veid))
            os.system("""mkdir -m 0740 /var/lib/vz/root/%d/var/apps""" % (veid))
            os.system("""chown apps:apps /var/lib/vz/root/%d/var/apps""" % (veid))
          except:
            log("Unable to create 'apps' user or groups or sudoers rule.")
            sys.exit(1)
      # NOTE(review): this bare except also catches the SystemExit from the
      # sys.exit(1) calls above, so those failures only log and continue.
      except:
        log("Unable to create all groups user is a member of.")

    if allow_fuse:
      os.system(r'/bin/sed -i "/^fuse:/s/\(.*\)/\1,%s/;s/:,/:/" /var/lib/vz/root/%d/etc/group' % (user,veid))

    if mount_projects:
      vz_root_path = "/var/lib/vz/root/%d" % (veid)

      for g in groups.split(','):
        try:
          gname,gpw,ggid,mem=grp.getgrnam(g)
          if gname[0:3] == "pr-":
            source_mount = project_path + gname[3:]
            if not os.path.exists(source_mount):
              continue
            if not os.path.exists(vz_root_path + source_mount):
              args = ['/bin/mkdir', '-m', '0700', '-p', vz_root_path + source_mount]
              p = subprocess.Popen(args, stderr=subprocess.PIPE, stdout=subprocess.PIPE)
              p.communicate()
              if p.returncode != 0:
                sys.exit(1)
            root_mount(vz_root_path, source_mount, 'rw,noatime')
        except Exception as e:
          log("Unable to mount all projects user is a member of. %s" % e)

    # submit_server is "host[:port]"; the port defaults to 830.
    parts = submit_server.rsplit(':',1)

    if len(parts) == 1:
        parts.append('830')

    # @TODO this needs to be expanded and made more robust:
    #     implement in python
    #     read file
    #     populate with defaults if empty
    #     conditionally add section header
    #     conditionally add/replace listenURIs line
    #     write file
    os.system(r'/bin/sed -i "s/^\s*listenURIs.*$//g" /var/lib/vz/root/%d/etc/submit/submit-client.conf' % (veid))
    os.system(r'echo -e "\nlistenURIs = tcp://%s:%s\n" >> /var/lib/vz/root/%d/etc/submit/submit-client.conf' % (parts[0],parts[1],veid))

    # Network access: drop any previous rules, re-add only for network_group.
    os.system("/sbin/iptables -t nat -D PREROUTING -i venet0 -s %s -j ACCEPT" % veaddr)
    os.system("/sbin/iptables -D FORWARD -i venet0 -s %s -j ACCEPT" % veaddr)
    if allow_network:
      os.system("/sbin/iptables -t nat -I PREROUTING -i venet0 -s %s -j ACCEPT" % veaddr)
      os.system("/sbin/iptables -I FORWARD -i venet0 -s %s -j ACCEPT" % veaddr)

    count=0
    while os.system("""vzctl exec2 %d su %s -s /bin/sh -c /usr/bin/mergeauth""" % (veid,user))!=0:
      time.sleep(0.5)
      count=count+1
      if count > xauth_retries:
        log("Unable to extract xauth cookie for %d" % veid)
        sys.exit(1)

    env["DISPLAY"]="%s:0.0" % veaddr

    envcmd = command
    for k in env:
      envcmd = """%s="%s" %s""" % (k, env[k], envcmd)
    envcmd = "cd; %s" % envcmd
    #
    # In the command below, use "time" to get the runtime of the command.
    # The user and sys cputimes will be inaccurate, but will be overridden
    # by printvzstats().
    #
    everything="""time -p vzctl exec2 %d "su %s -s /bin/sh -c '%s'" """ % (veid,user,envcmd)
    os.write(2,"command is %s\n" % everything)
    status = os.system(everything)
    os.write(2,"command finished\n")

    printvzstats(veid)

    # When we're done with the command, we're done with this VE.  Stop it.
    stopvnc(disp)
    os.system("/sbin/iptables -t nat -D PREROUTING -i venet0 -s %s -j ACCEPT" % veaddr)
    os.system("/sbin/iptables -D FORWARD -i venet0 -s %s -j ACCEPT" % veaddr)
    # os.system() returns a wait status with the exit code in the high byte;
    # passing it to sys.exit() unchanged truncates e.g. 256 to 0.
    if os.WIFEXITED(status):
      status = os.WEXITSTATUS(status)
    sys.exit(status)

  # ---- Parent: record the child's pid and wait for it. ----
  try:
    vncdir="/var/run/mw-service/"
    try:
      os.lstat(vncdir)
    except OSError:
      try:
        os.umask(0o077)
        os.makedirs(vncdir)
      except OSError:
        print("Unable to create %s." % vncdir)
        sys.exit(1)

    pidfile=open("/var/run/mw-service/pid.%d" % disp, "w")
    pidfile.write("%d" % pid)
    pidfile.close()
  except IOError:
    print("Unable to open pidfile for session.  Running without one.")

  if timeout > 0 and port > 0:
    idle=0
    viewsum=0.0

    class ViewerCheck(Exception):
      """Raised by the SIGALRM handler to interrupt waitpid() every tick."""

    def alarmHandler(sig, frame):
      raise ViewerCheck()
    signal.signal(signal.SIGALRM, alarmHandler)
    signal.alarm(tick)

    try:
      while 1:
        try:
          (_, status)=os.waitpid(pid,0)
          sys.stderr.write("view %f\n" % viewsum)
          sys.stderr.write("Exit_Status: %d\n" % status)
          return 0
        except OSError:
          return 1
        except ViewerCheck:
          if check_port(port) > 0:
            idle = 0
            viewsum += tick
          else:
            idle += tick
            if idle > timeout:
              # Stats are keyed by container id (disp + vzoffset), exactly
              # as the child uses above.
              printvzstats(disp + vzoffset)
              stopvnc(disp)
              sys.stderr.write("Timeout_Value: %d\n" % timeout)
              sys.stderr.write("view %f\n" % viewsum)
          signal.alarm(tick)
    finally:
      # Do not leave an alarm pending to fire in the caller after we return.
      signal.alarm(0)

  else:
    try:
      print("Waiting for %d" % pid)
      (_, status) = os.waitpid(pid,0)
      sys.stderr.write("Exit_Status: %d\n" % status)
    except OSError:
      pass
    return 0

#=============================================================================
# Drop privileges and set up the user's home and session directories
#=============================================================================

def relinquish(uid, gid):
  """Relinquish elevated privileges;  adopt user identity for file operations

  Sets the real and effective gid, then the real and effective uid, and
  verifies that uid 0 cannot be regained.  Exits with status 1 if either
  change fails or if switching back to root still works.

  NOTE(review): supplementary groups are not reset here, so any inherited
  from the calling (root) process remain in effect.
  """
  try:
    # gid first: once the uid is dropped the gid could no longer be changed.
    os.setregid(gid,gid)
    os.setreuid(uid,uid)
  except OSError:
    log("unable to change uid and gid to %s:%s" % (uid, gid))
    sys.exit(1)

  still_root = True
  try:
    os.setreuid(0,0)
  except OSError:
    still_root = False
  if still_root:
    log("Was able to revert to root!")
    sys.exit(1)

def create_userhome(user):
  """Make sure `user`'s home directory exists.

  A missing home directory is created with mode 0700 (subject to the umask)
  and chowned to the account's uid/gid, which needs root.  An existing path
  is left alone.  Exits with status 1 if the account is unknown.
  """
  try:
    entry = pwd.getpwnam(user)
  except KeyError:
    print("Unable to find account information for '%s'." % user)
    sys.exit(1)

  owner_uid, owner_gid, home = entry[2], entry[3], entry[5]
  try:
    os.stat(home)
    return
  except OSError:
    pass

  # Home directory is missing: create it with root privileges.
  os.mkdir(home, 0o700)
  os.chown(home, owner_uid, owner_gid)

def setup_dir(user,session,params=''):
  """Create ~user/data/sessions/<session> and its resources file.

  As root, sets umask 077.  If the home directory is missing, it is created
  (mode 0700, owned by the user) and /usr/bin/set_quotas is run for the
  user.  The process then drops to the user's uid/gid via relinquish() and
  creates data/, data/sessions/ and the session directory.  It writes
  "resources" (session id and results directory) and, when `params` is
  non-empty, writes the URL-unquoted `params` to "parameters.hz".

  Exits with status 1 if the account is unknown, the session directory
  already exists, or any directory/file cannot be created (including
  `params` that are not valid UTF-8 after unquoting).
  """
  try:
    info = pwd.getpwnam(user)
    uid = info[2]
    gid = info[3]
    homedir = info[5]
  except KeyError:
    print("Unable to find account information for '%s'." % user)
    sys.exit(1)

  os.umask(0o077) # only user has permissions

  try:
    os.stat(homedir)
  except OSError:
    # need to create home directory using root privileges
    os.mkdir(homedir, 0o700)
    os.chown(homedir,uid,gid)
    os.system("/usr/bin/set_quotas %s" % user)

  relinquish(uid, gid)

  datadir = homedir + "/data"
  sessions = datadir + "/sessions"
  sessdir = sessions + "/%s" % session
  resources = sessdir + "/resources"
  try:
    for parent in (datadir, sessions):
      try:
        os.stat(parent)
      except OSError:
        os.mkdir(parent)

    session_existed = True
    try:
      os.stat(sessdir)
    except OSError:
      os.mkdir(sessdir)
      session_existed = False
    if session_existed:
      log("session directory already existed (%s, %s)" % (user, session))
      sys.exit(1)

    with open(resources,"w") as rfile:
      rfile.write("sessionid %s\n" % session)
      rfile.write("results_directory %s/data/results/%s\n" % (homedir,session))

    if params is not None and params != "":
      import urllib2
      # Decoding validates the text as UTF-8 before any file is created.
      value = urllib2.unquote(params).decode("utf8")
      with open(sessdir + "/parameters.hz", "w") as pfile:
        # Encode explicitly: str() of non-ASCII unicode raises under Python 2.
        pfile.write(value.encode("utf8"))

  # Under Python 2, open()/write() failures are IOError, not OSError.
  except (IOError, OSError, UnicodeError):
    log("Unable to setup user '%s' in session %s" % (user, session))
    sys.exit(1)

def update_quota(user, block_soft, block_hard):
  """
    Change the quota for a user.
    This happens on a fileserver.

    Runs ``/usr/sbin/setquota -a <user> <soft> <hard> 0 0``.  The block
    limits may be ints or numeric strings.  Raises ValueError if either
    limit is negative or not an integer.  A missing setquota binary is
    logged, not raised.
  """
  soft = int(block_soft)
  hard = int(block_hard)
  # Validate after int(): under Python 2 any str compares greater than any
  # int, so a string such as "-5" would otherwise slip through.
  if soft < 0 or hard < 0:
    raise ValueError("Invalid quotas")

  # An argument list, not a shell string: `user` is never shell-parsed.
  try:
    subprocess.call(["/usr/sbin/setquota", "-a", str(user), str(soft), str(hard), "0", "0"])
  except OSError as exc:
    log("update_quota: unable to run setquota: %s" % exc)

#=============================================================================
# Update the resource file for the session.
#=============================================================================
def update_resources(user,session):
  """Append everything on stdin to ~user/data/sessions/<session>/resources.

  Sets umask 077 and drops to the user's uid/gid (relinquish()) before
  touching the file.  Exits with status 1 if the account is unknown or the
  file cannot be opened or written.
  """
  try:
    info = pwd.getpwnam(user)
    uid = info[2]
    gid = info[3]
    homedir = info[5]
  except KeyError:
    print("Unable to find account information for '%s'." % user)
    sys.exit(1)

  resources = homedir + "/data/sessions/%s/resources" % session
  os.umask(0o077)
  relinquish(uid, gid)

  try:
    with open(resources,"a+") as rfile:
      for line in iter(sys.stdin.readline, ""):
        rfile.write(line)
  # Under Python 2, open()/write() failures are IOError, not OSError.
  except (IOError, OSError):
    print("Unable to append to resource file.")
    sys.exit(1)

#=============================================================================
# Kill a tree of processes in a depth-first manner.
#=============================================================================
def killtree(pid):
  """Kill `pid` and its descendants, deepest children first.

  Builds a parent -> children map from the PPid:/Uid: lines of every
  /proc/<n>/status.  A process with a 0 in any Uid: field (or whose status
  cannot be fully read) is left out of the map, so it and anything reachable
  only through it is never signalled.  `pid` itself is always signalled.
  Each process gets SIGHUP, then SIGTERM, then SIGKILL, 0.5 s apart, until
  its /proc entry disappears.
  """
  proclist = {}
  for entry in os.listdir("/proc"):
    try:
      ipid = int(entry)
    except ValueError:
      continue

    try:
      fh = open("/proc/" + entry + "/status")
    except IOError:
      continue

    ppid = 0
    try:
      while 1:
        try:
          line = fh.readline()
        except IOError:
          break
        if line == "":
          # EOF before a Uid: line; we cannot tell whether it is root-owned.
          ppid = 0
          break
        arr = line.split()
        if not arr:
          continue
        if arr[0] == "PPid:":
          ppid = int(arr[1])
          if ppid == 0:
            break
        if arr[0] == "Uid:":
          if any(int(x) == 0 for x in arr[1:]):
            ppid = 0
          break
    finally:
      fh.close()

    if ppid != 0:
      proclist.setdefault(ppid, []).append(ipid)

  # Signals escalate HUP(1) -> TERM(15) -> KILL(9); -1 means give up.
  nextsig = {-1:-1, 0:1, 1:15, 15:9, 9:-1}

  def dokill(proclist, pid):
    """Recursively kill the children of `pid`, then `pid` itself."""
    for child in proclist.get(pid, []):
      dokill(proclist, child)
    sig = 0
    while 1:
      try:
        fd = os.open("/proc/%d/status" % pid, os.O_RDONLY)
      except OSError:
        break
      os.close(fd)
      sig = nextsig[sig]
      if sig == -1:
        return
      print("Killing %d with %d" % (pid, sig))
      try:
        os.kill(pid, sig)
      except OSError:
        # Gone (or not signallable) since the check above; nothing to do.
        break
      time.sleep(0.5)

  dokill(proclist, pid)

#=============================================================================
# Kill all processes belonging to a particular user.
#=============================================================================
def killall(user):
  """SIGKILL every process whose Uid: fields in /proc/<pid>/status all equal
  `user`'s uid, then exit with status 0.

  Exits with status 1 if the account is unknown.  Processes that vanish
  while being scanned or killed are skipped.
  """
  try:
    uid = pwd.getpwnam(user)[2]
  except KeyError:
    print("Unable to find account information for '%s'." % user)
    sys.exit(1)

  for entry in os.listdir("/proc"):
    try:
      ipid = int(entry)
    except ValueError:
      continue

    try:
      fh = open("/proc/" + entry + "/status")
    except IOError:
      continue

    try:
      while 1:
        try:
          line = fh.readline()
        except IOError:
          break
        if line == "":
          # EOF without a Uid: line (e.g. the process exited mid-read).
          break
        arr = line.split()
        if arr and arr[0] == "Uid:":
          if all(int(x) == uid for x in arr[1:]):
            try:
              os.kill(ipid, 9)
            except OSError:
              # Already gone (or not signallable); keep going.
              pass
          break
    finally:
      fh.close()

  sys.exit(0)

#=============================================================================
# Start a VNC server...
#=============================================================================
def startvnc(disp,geom,depth):
  """Create and start the OpenVZ container and Xvnc server for display `disp`.

  Steps:
  * Read an 8-character VNC passphrase from stdin into
    /var/run/mw-service/pass.<disp>.
  * Clear leftovers of a previous container with the same id.
  * Start stunnel on port 4000+disp (forwarding to <container>:5000) unless
    something already listens there, plus a socat forwarder on
    portbase+disp.
  * Build the private-area skeleton and link the OpenVZ conf/mount/umount
    scripts.
  * Start the container and, inside it, Xvnc with geometry `geom`.
  * Assign the container its 10.<machine_number>.x.y address, bring up
    loopback, and start a socat forwarder for port 9000+disp.

  `depth` is accepted but not used.  The function ends the process:
  sys.exit(0) on success, sys.exit(1) on fatal errors.
  """
  vncdir="/var/run/mw-service/"
  try:
    os.lstat(vncdir)
  except OSError:
    try:
      os.umask(0o077)
      os.makedirs(vncdir)
    except OSError:
      print("Unable to create %s." % vncdir)
      sys.exit(1)

  try:
    print("Reading passphrase:")
    passwd=sys.stdin.read(8)
    passfile=open(vncdir + "pass.%d" % disp, "w")
    passfile.write(passwd)
    passfile.close()
  # Under Python 2, open()/write() failures are IOError, not OSError.
  except (IOError, OSError):
    print("Failed to create password files.")
    sys.exit(1)

  veid = disp + vzoffset
  veaddr="10.%d.%d.%d" % (machine_number, veid/100, veid % 100)

  #
  # Sanitize the environment just in case.
  #
  os.system("VEID=%d /etc/vz/conf/%s" % (veid,ovz_session_umount))

  try:
    os.lstat("/var/lib/vz/root/%d" % veid)
    root_exists = True
  except OSError:
    root_exists = False
  if root_exists:
    log("/var/lib/vz/root/%d already exists.  Ditching it." % veid)
    os.system("VEID=%d /etc/vz/conf/%s" % (veid,ovz_session_umount))
    os.rmdir("/var/lib/vz/root/%d" % veid)
    try:
      os.lstat("/var/lib/vz/root/%d" % veid)
      log("/var/lib/vz/root/%d still exists.  Giving up." % veid)
      sys.exit(1)
    except OSError:
      pass

  try:
    os.lstat("/var/lib/vz/private/%d" % veid)
    private_exists = True
  except OSError:
    private_exists = False
  if private_exists:
    log("/var/lib/vz/private/%d already exists.  Ditching it." % veid)
    os.system("VEID=%d /etc/vz/conf/%s" % (veid,ovz_session_umount))
    try:
      os.lstat("/var/lib/vz/private/%d" % veid)
      log("/var/lib/vz/private/%d still exists.  Giving up." % veid)
      sys.exit(1)
    except OSError:
      pass

  if os.system("netstat -tnl | grep -q ' 0.0.0.0:%d '" % (disp+4000)) == 0:
    status = 0
    log("stunnel already running")
  else:
    if os.path.exists("/usr/bin/stunnel4"):
        stunnelCmd = "/usr/bin/stunnel4"
    else:
        stunnelCmd = "stunnel"

    # Automatically detect whether ipv6 is enabled on loopback interface
    # some systems will default to this for localhost so we need to listen
    # on both ipv4 and ipv6

    proc = subprocess.Popen(["/sbin/sysctl", "-e",  "-n",  "net.ipv6.conf.lo.disable_ipv6"], stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    out, err = proc.communicate()

    if "0" in out:
        prefix = ":::"
    else:
        prefix = ""

    # stunnel4 uses a configuration file, build a string and pass the file to the process (stunnel3 used command line args)
    # see if FIPS is enabled on stunnel (grep hack)
    proc = subprocess.Popen([stunnelCmd, "-version"], stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    out, err = proc.communicate()

    if "fips" in out.lower():
        fips_pretext = "fips=no\n"
    else:
        fips_pretext = ""

    stunnelConfigFile = """%s[hzvnc]
cert = /etc/mw-service/xvnc.pem
accept = %s%d
connect = %s:%d
""" % (fips_pretext, prefix, 4000+disp, veaddr, 5000)

    log("stunnel -d %s%d -r %s:%d -p /etc/mw-service/xvnc.pem" % (prefix, 4000+disp, veaddr, 5000))

    f1 = tempfile.NamedTemporaryFile(delete=True)
    f1.write(stunnelConfigFile)
    f1.flush()
    status = subprocess.call([stunnelCmd, f1.name])
    f1.close()

    if status != 0:
        log("Can't start stunnel. (already running?)")

  # Start a forwarder to make the display look external.
  # This is only for backward-compatibility.
  os.system("socat tcp4-listen:%d,fork,reuseaddr,linger=0 tcp4:%s:5000 > /dev/null 2>&1 < /dev/null &" % (portbase+disp,veaddr))

  os.umask(0)

  try:
    os.makedirs("/var/lib/vz/root/%d" % veid, mode=0o755)
  except Exception:
    os.umask(0o077)
    log("Unable to create /var/lib/vz/root/%d" % veid)
    stopvnc(disp)
    sys.exit(1)

  try:
    os.makedirs("/var/lib/vz/private/%d" % veid, mode=0o755)
    os.symlink(".root/bin", "/var/lib/vz/private/%d/bin" % veid)
    os.mkdir("/var/lib/vz/private/%d/home" % veid, 0o755)
    os.symlink(".root/lib", "/var/lib/vz/private/%d/lib" % veid)
    os.symlink(".root/emul", "/var/lib/vz/private/%d/emul" % veid)
    os.symlink(".root/lib64", "/var/lib/vz/private/%d/lib64" % veid)
    os.symlink(".root/lib32", "/var/lib/vz/private/%d/lib32" % veid)
    os.mkdir("/var/lib/vz/private/%d/mnt" % veid, 0o755)
    os.mkdir("/var/lib/vz/private/%d/opt" % veid, 0o755)
    os.mkdir("/var/lib/vz/private/%d/proc" % veid, 0o755)
    os.mkdir("/var/lib/vz/private/%d/.root" % veid, 0o755)
    os.symlink(".root/sbin", "/var/lib/vz/private/%d/sbin" % veid)
    os.mkdir("/var/lib/vz/private/%d/sys" % veid, 0o755)
  except Exception:
    os.umask(0o077)
    log("Unable to create template in /var/lib/vz/private/%d" % veid)
    stopvnc(disp)
    sys.exit(1)

  os.umask(0o077)

  #
  # Get rid of these links if they already exist.
  #
  for suffix in ("conf", "mount", "umount"):
    try:
      os.unlink("/etc/vz/conf/%d.%s" % (veid, suffix))
    except OSError:
      pass

  #
  # Recreate the links.
  #
  try:
    os.symlink(ovz_session_conf, "/etc/vz/conf/%d.conf" % veid)
    os.symlink(ovz_session_mount, "/etc/vz/conf/%d.mount" % veid)
    os.symlink(ovz_session_umount, "/etc/vz/conf/%d.umount" % veid)
  except Exception:
    log("Unable to create OpenVZ symlink")
    stopvnc(disp)
    sys.exit(1)

  lockdir = "/var/lib/vz/lock/mount.%d.lock" % veid
  os.makedirs(lockdir)
  status = os.system("vzctl start %d" % veid)
  if status == 0:
    os.mkdir("/var/lib/vz/root/%d/usr" % veid)
    os.system("/bin/mount -n --bind /var/lib/vz/root/%d/.root/usr /var/lib/vz/root/%d/usr" % (veid, veid))
    # Wait up to 50 s for the lock directory created above to go away.
    # NOTE(review): presumably removed by the container's mount script once
    # it has finished — confirm.
    for attempt in range(0,50):
      try:
        os.stat(lockdir)
      except OSError:
        log("TIME: %d seconds." % attempt)
        break
      time.sleep(1)
    # The passphrase file is named after the display number (see above).
    status = os.system("cat %s/pass.%d | vzctl exec2 %d /usr/bin/startxvnc %s 0 %s" % (vncdir, disp, veid, veaddr, geom))
    if status != 0:
      log("Unable to start internal Xvnc server")
  else:
    log("Bad status for vzctl start %d: %d" % (veid,status))
    stopvnc(disp)
    sys.exit(1)

  status = os.system("vzctl set %d --ipadd %s" % (veid,veaddr))
  if status != 0:
    log("Bad status for vzctl set %d --ipadd %s: %d" % (veid,veaddr,status))
    stopvnc(disp)
    sys.exit(1)

  status = os.system("vzctl exec2 %d /sbin/ifup lo --force" % (veid))
  if status != 0:
    log("Unable to bring up loopback network interface")

  # Start a forwarder for filexfer.  We never kill it.
  # If we can't start one, that means there's already one running.
  port = disp + 9000
  os.system("socat tcp4-listen:%d,fork,reuseaddr,linger=0 tcp4:%s:%d > /dev/null 2>&1 &" % (port,veaddr,port))
  sys.exit(0)

#=============================================================================
# Stop a VNC server and kill the processes for that display.
#=============================================================================
def stopvnc(disp):
  """Stop the OpenVZ container backing display `disp` and clean up after it.

  The container id is disp + vzoffset.  Processes inside the container are
  signalled with HUP, INT, TERM and finally KILL (half a second apart) until
  at most one process remains.  Then the container is halted and stopped, its
  /var/lib/vz/root/<veid> directory is removed, and the
  /etc/vz/conf/<veid>.{conf,mount,umount} links are deleted.  Every step is
  best-effort: failures are logged or ignored, never raised.
  """
  veid = disp + vzoffset

  def vzproccount(veid):
    # Return the process count that /proc/vz/veinfo reports for `veid` (third
    # field of its 4-field line), or 0 if the file or the entry is missing.
    try:
      f = open("/proc/vz/veinfo")
    except (IOError, OSError):
      log("vzproccount: can't open veinfo.")
      return 0
    try:
      for line in f:
        arr = line.split()
        if len(arr) != 4 or arr[0] != str(veid):
          continue
        try:
          return int(arr[2])
        except ValueError:
          return 0
    finally:
      # Close promptly: this is polled repeatedly by the kill loop below.
      f.close()
    log("End of file.")
    return 0

  # While anything besides a single process (presumably init) is running,
  # escalate through HUP, INT, TERM and KILL, pausing after each signal.
  for sig in (1, 2, 15, 9):
    count = vzproccount(veid)
    if count <= 1:
      break
    log("Killing %d processes in veid %d with signal %d" % (count, veid, sig))
    os.system("vzctl exec %d kill -%d -1" % (veid, sig))
    time.sleep(0.5)

  # Stop the VNC server and the VPS.
  os.system("vzctl exec %d halt -nf" % veid)
  os.system("vzctl stop %d" % veid)

  # A failed rmdir (e.g. still mounted or non-empty) must not prevent the
  # config links below from being removed.
  rootdir = "/var/lib/vz/root/%d" % veid
  if os.path.isdir(rootdir):
    try:
      os.rmdir(rootdir)
    except OSError as e:
      log("Unable to remove %s: %s" % (rootdir, e))

  # Remove the per-container config links; missing links are fine.
  for suffix in ("conf", "mount", "umount"):
    try:
      os.unlink("/etc/vz/conf/%d.%s" % (veid, suffix))
    except OSError:
      pass

#=============================================================================
# Clean up sessions that are believed to be running after the system reboots.
#=============================================================================
def notify_restarted(host):
  """Close out sessions that were still recorded as open after a reboot.

  Every "<num>.err" file in /var/log/mw-service/open-sessions is treated as a
  session believed to be running.  Files without an "Exit_Status:" line get a
  synthetic "System crashed" / "Exit_Status: 65534" trailer appended.  Then the
  notify hosts are tried in turn (1 s apart) until one accepts the
  finished-command notification for that session.

  host: exported as SSH_CONNECTION for the helpers called below.

  Side effect: changes the process working directory to the open-sessions
  directory.
  """
  os.environ["SSH_CONNECTION"] = "%s" % host
  os.chdir("/var/log/mw-service/open-sessions")
  for filename in os.listdir("."):
    if not filename.endswith(".err"):
      continue
    try:
      num = int(filename.split(".")[0])
    except ValueError:
      continue

    log("Cleaning up session %d" % num)
    fp = open(filename, 'r')
    try:
      has_exit_status = any(line.startswith("Exit_Status:") for line in fp)
    finally:
      fp.close()

    if not has_exit_status:
      log("Adding exit status to '%s'" % filename)
      fp = open(filename, 'a')
      try:
        fp.write("\nSystem crashed\nExit_Status: 65534")
      finally:
        fp.close()

    # Stop trying further hosts once one accepted the notification, then go on
    # to the next session.  Returning here instead would abandon every
    # remaining session.  Assumes notify_command_finished returns true on
    # success -- TODO confirm against its definition.
    for notify_host in get_notify_hosts():
      if notify_command_finished(num, notify_host):
        break
      time.sleep(1)

def screenshot(user, session_id, display):
  """Capture a screenshot of a session's X display for the app UI.

  user: account owning the session; the image is written under its home
        directory at data/sessions/<session_id>/screenshot.png.
  session_id: session identifier used to build the destination path.
  display: display number (string or int); the container id is
           display + vzoffset and the X server is at 10.<machine>.<veid/100>.<veid%100>:0.

  Failures are logged, not raised.
  """
  try:
    homedir = pwd.getpwnam(user).pw_dir
  except KeyError:
    log("screenshot: can't find account information for '%s'\n." % user)
    return

  destination = "%s/data/sessions/%s/screenshot.png" % (homedir, session_id)
  veid = int(display) + vzoffset
  veaddr = "10.%d.%d.%d" % (machine_number, veid // 100, veid % 100)
  xdisplay = "%s:0" % veaddr

  # Execute inside the session's container, addressed by veid like the rest
  # of this script does (display alone is only correct when vzoffset == 0).
  # NOTE(review): user and session_id are interpolated unescaped into a shell
  # command run as root; callers must pass trusted values.
  status = os.system("""/usr/sbin/vzctl exec2 %d "su %s -s /bin/sh -c 'cd; export DISPLAY=%s; /usr/bin/screenshot %s'\"""" % (veid, user, xdisplay, destination))

  if status != 0:
    log("screenshot command failed")

#=============================================================================
#=============================================================================
# Main program...
#
# We recognize the following commands:
#  check
#  startvnc <dispnum> <geometry> <depth>
#  stopvnc <dispnum>
#  screenshot <user> <sessionid> <dispnum>
#  startxapp <user> <sessionid> <timeout> <dispnum> <command>...
#  setup_dir <user> <sessionid> [<params>]
#  create_userhome <user>
#  update_resources <user> <sessionid>
#  killtree <pid>
#  killall <user>
#  purgeoutputs <sessionid>
#  notifyrestarted <masterhostname>
#  update_quota <user>[,<user>...] <block_soft> <block_hard>
#=============================================================================
#=============================================================================

if len(sys.argv) == 2 and sys.argv[1] == "check":
  print "OK"
  sys.exit(0)

openlog()

if (machine_number == 0):
  log("machine_number not set and unable to derive one from IP address.")
  sys.exit(1)

uid = os.geteuid()
login =  pwd.getpwuid(uid)[0]
if uid != 0:
  print "maxwell_service: access denied to %s. Must be run with an effective user id of 0 (root)." % login
  os._exit(100)

if len(sys.argv) < 2:
  log("Incomplete command: %s" % l2s(sys.argv))
  sys.exit(1)

elif sys.argv[1] == "startvnc":
  if len(sys.argv) < 5:
    log("Incomplete command: %s" % l2s(sys.argv))
    sys.exit(1)
  disp = int(sys.argv[2])
  geom = sys.argv[3]
  depth = int(sys.argv[4])
  startvnc(disp,geom,depth)

elif sys.argv[1] == "stopvnc":
  if len(sys.argv) != 3:
    log("Incomplete command: %s" % l2s(sys.argv))
    sys.exit(1)
  disp = int(sys.argv[2])
  stopvnc(disp)
  sys.exit(0)

elif sys.argv[1] == "screenshot":
  if len(sys.argv) < 5:
    log("Incomplete command: %s" % l2s(sys.argv))
    sys.exit(1)
  screenshot(sys.argv[2],sys.argv[3],sys.argv[4])

elif sys.argv[1] == "startxapp":
  if len(sys.argv) < 7:
    log("Incomplete command: %s" % l2s(sys.argv))
    sys.exit(1)
  user = sys.argv[2]
  sessionid = sys.argv[3]
  timeout = int(sys.argv[4])
  disp = int(sys.argv[5])
  command = sys.argv[6]
  for i in range(7,len(sys.argv)):
    command += " " + sys.argv[i]
  startxapp(user,sessionid,timeout,command,disp)

elif sys.argv[1] == "setup_dir":
  if len(sys.argv) < 4 or len(sys.argv) > 5:
    log("Incomplete command: %s" % l2s(sys.argv))
    sys.exit(1)
  user = sys.argv[2]
  session = sys.argv[3]
  params = ""
  if len(sys.argv) == 5:
    params = sys.argv[4]
  setup_dir(user,session,params)

elif sys.argv[1] == "create_userhome":
  if len(sys.argv) != 3:
    log("Incomplete command: %s" % l2s(sys.argv))
    sys.exit(1)
  user = sys.argv[2]
  create_userhome(user)

elif sys.argv[1] == "update_resources":
  if len(sys.argv) != 4:
    log("Incomplete command: %s" % l2s(sys.argv))
    sys.exit(1)
  user = sys.argv[2]
  sess = sys.argv[3]
  update_resources(user,sess)

elif sys.argv[1] == "killtree":
  if len(sys.argv) != 3:
    log("Incomplete command: %s" % l2s(sys.argv))
    sys.exit(1)
  pid = int(sys.argv[2])
  killtree(pid)

elif sys.argv[1] == "killall":
  if len(sys.argv) != 3:
    log("Incomplete command: %s" % l2s(sys.argv))
    sys.exit(1)
  user = sys.argv[2]
  killall(user)

elif sys.argv[1] == "purgeoutputs":
  if (len(sys.argv)) != 3:
    log("Incomplete command: %s" % l2s(sys.argv))
    sys.exit(1)
  sess=int(sys.argv[2])
  try:
    os.unlink("/var/log/mw-service/open-sessions/%d.out" % sess)
  except:
    pass
  try:
    os.unlink("/var/log/mw-service/open-sessions/%d.err" % sess)
  except:
    pass
  sys.exit(0)

elif sys.argv[1] == "notifyrestarted":
  if (len(sys.argv)) != 3:
    log("Incomplete command: %s" % l2s(sys.argv))
    sys.exit(1)
  notify_restarted(sys.argv[2])
  sys.exit(0)

elif sys.argv[1] == "update_quota":
  if (len(sys.argv)) != 5:
    log("Incomplete command: %s" % l2s(sys.argv))
    sys.exit(1)

  inputs = {}
  inputs["block_soft"] = int(sys.argv[3])
  inputs["block_hard"] = int(sys.argv[4])
  users = sys.argv[2].split(",")
  prog = re.compile(r'\A[0-9a-zA-Z]+[_0-9a-zA-Z\.]*\Z')
  try:
    validated_users = map(lambda x:prog.match(x).group(0), users)
  except AttributeError:
    raise InputError("Invalid user name passed to update_quota: '%s'" % "' '".join(users))

  for user in validated_users:
    inputs["user"] = user
    update_quota(**inputs)

  sys.exit(0)

else:
  print "Unknown command: '%s'." % l2s(sys.argv)
  sys.exit(1)
