#!/usr/bin/env python
#
# @package      hubzero-submit-server
# @file         submit-server.py
# @author       Rick Kennell <kennell@purdue.edu>
# @copyright    Copyright (c) 2004-2011 Purdue University. All rights reserved.
# @license      http://www.gnu.org/licenses/lgpl-3.0.html LGPLv3
#
# Copyright (c) 2004-2011 Purdue University
# All rights reserved.
#
# This file is part of: The HUBzero(R) Platform for Scientific Collaboration
#
# The HUBzero(R) Platform for Scientific Collaboration (HUBzero) is free
# software: you can redistribute it and/or modify it under the terms of
# the GNU Lesser General Public License as published by the Free Software
# Foundation, either version 3 of the License, or (at your option) any
# later version.
#
# HUBzero is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# HUBzero is a registered trademark of Purdue University.
#

from OpenSSL import SSL
import sys
import os
import select
import socket
import MySQLdb
import ldap
import time
from scgi import passfd
import resource
import signal
import pwd
import exceptions
import getopt
import math

#=============================================================================
# Set default values.
#=============================================================================
# Process/daemon options.
should_fork = 1
debug_flag = 0  # when true, debug() messages are written by log()
foreground_flag = 0  # when true, openlog() keeps logging on stderr
help_flag = 0  # set by parseArgs() for -h/--help
pidfile = "/var/run/submit-server.pid"  # written by daemonize()

# Candidate configuration directories; configdir is the one in use
# (invokeChild() runs <configdir>/distributor).  NOTE(review): the code that
# picks configdir from configdirs is not in this excerpt.
configdirs = [ "/etc/submit", "/opt/submit", "/tmp/submit", "." ]
configdir = None
# NOTE(review): the listener/timeout settings below are not used in this
# excerpt; the intervals look like seconds (4*3600 = 4 hours) -- confirm.
listen_ports = [ "tcp://:830", "tls://:831" ]
timeout_interval = 60
heartbeat_interval = 3600
disconnect_time = 0
disconnect_max  = 4*3600
logfile_name = "/var/log/submit/submit.log"
logfile = None  # file object log() writes to; set by openlog()

# MySQL connection settings used by db_connect().
mysql_host = None
mysql_user = None
mysql_password = None
mysql_db = None

# LDAP settings used by ldap_authenticate(); ldap_user_dn is a %-format
# template into which the login name is substituted to form the bind DN.
ldap_hosts = []
ldap_basedn = ""
ldap_user_dn = ""

db = None  # MySQLdb connection; callers open it on demand with db_connect()
sock = None  # client connection; None once closeSocket() has run
listeners = {}
readers = {}  # file/socket -> stream name, for objects to read from
writers = {}  # file/socket -> pending output text (see queueWrite())
writecp = {}
inbuf = ""  # unparsed network input (see doCommand())
old_inbuf = ""
inoffset = 0  # stream offset of inbuf[0] (moved by eat()/uneat())
old_inoffset = 0
args = []  # job arguments received via "arg" commands
argc = 0
envs = {}  # environment received via "var" commands; see filterEnvironment()
workdir = ""  # client's directory ("pwd") or our private /tmp directory
umask = 022  # octal; replaced by the client's "umask" command
token = ""
testfile_inode = None  # inode and path from the "testfile" command, compared
testfile_file = None  # by checkFilesystem() to detect a shared filesystem
shared_filesystem = 0
username = None
password = None
authz = 0
homedir = None
auth_attempts = 0
jobid = None
superjob = 0  # from the client's SUBMIT_JOB variable, when it is numeric
sessnum = 0
uid = None  # identity the job runs as; applied in startCommand()
gid = None
gidlist = None
remote_ip = None  # presumably the client address; log() prefix before jobid is known

command_args = []  # the command line to run (set by parseArgs())
input_files = {}  # "-i" files; values become open files in moveInputFiles()
output_files = {}  # files declared to the client by findOutputFiles()
starttime = None  # time.time() when the child was started

# Job accounting defaults; parseArgs()/parseExitStatus() may override them.
event = "simulation"
ncpus = 1
venue = "any"

# NOTE(review): not used in this excerpt; presumably load-limiting
# parameters consulted by checkLoadAndLaunch() -- confirm.
load_limit = 510
load_halflife = 3600
load_horizon = 86400

child_pid = 0  # pid of the forked distributor (0 = none); the child calls setsid()
writeclose = {}  # files the client asked to close ("close" command)

control = None  # UNIX control socket (see setupControlPort())

command_established = None
ready_to_exit = None
child_has_exited = None  # set by the SIGCHLD handler
wait_for_low_load = 0  # -W/--wait
# Assume that a db entry exists.  That's where the jobid came from.
# If we've improperly assumed it exists, then finalizeJob() will
# indicate failure when a job is aborted prematurely.
database_entry_exists = 1
exit_status = 0  # last status= value seen by parseExitStatus(); sent on exit
# NOTE(review): parseExitStatus() uses local variables with these names and
# does not update these module-level values.
cputime = 0
realtime = 0
waittime = 0
print_metrics = 0  # -M/--metrics
stats_recorded = 0  # keeps parseExitStatus() from recording twice

#=============================================================================
# Debugger support.
#=============================================================================
def info(type, value, tb):
    """Exception hook that opens a pdb post-mortem session when on a tty.

    Suitable for ``sys.excepthook``.  In an interactive interpreter, or
    when stderr is not a terminal, it defers to Python's default hook.

    Args:
      type: the exception class.
      value: the exception instance.
      tb: the traceback object (may be None).
    """
    if hasattr(sys, 'ps1') or not sys.stderr.isatty():
        # You are in interactive mode or don't have a tty-like
        # device, so call the default hook.
        sys.__excepthook__(type, value, tb)
    else:
        import traceback, pdb
        # You are not in interactive mode; print the exception...
        traceback.print_exception(type, value, tb)
        sys.stdout.write("\n")
        # ... then start the debugger in post-mortem mode on this traceback
        # (explicit tb, so it does not depend on sys.last_traceback).
        pdb.post_mortem(tb)

#=============================================================================
# Create the database connection.
#=============================================================================
def db_connect():
  """Return a new MySQLdb connection built from the mysql_* settings.

  If no MySQL host is configured, the server is shut down via serverExit().
  Otherwise connection attempts are retried forever with a delay of
  1, 2, 4, ... seconds; the delay stops doubling once it reaches 256.
  A SystemExit raised while connecting is logged and becomes sys.exit(1).
  """
  global ready_to_exit

  if not mysql_host:
    log("ERROR: MySQL host was not configured.")
    ready_to_exit = 1
    serverExit()

  backoff, backoff_cap = 1, 256
  while True:
    try:
      return MySQLdb.connect(host=mysql_host, user=mysql_user,
                             passwd=mysql_password, db=mysql_db)
    except SystemExit:
      log("SystemExit in db_connect")
      sys.exit(1)
    except Exception:
      log("Exception in db_connect: %s" % str(sys.exc_info()[1]))
    time.sleep(backoff)
    if backoff < backoff_cap:
      backoff *= 2

#=============================================================================
# Dissociate the process from its controlling tty and the invoking shell.
#=============================================================================
def daemonize():
  """Detach from the controlling terminal using the double-fork idiom.

  The first fork lets the parent exit and the child start a new session;
  the second fork ensures the daemon is not a session leader.  The daemon's
  working directory becomes "/" and its pid is written to ``pidfile``
  (a failure to write the pidfile is logged, not fatal).
  """
  log("Backgrounding process.")
  if os.fork():
    os._exit(0)
  os.setsid()
  os.chdir("/")
  if os.fork():
    os._exit(0)
  try:
    f = open(pidfile,"w")
    try:
      f.write("%d\n" % os.getpid())
    finally:
      # Close even if the write fails so the descriptor is not leaked.
      f.close()
  except (IOError, OSError):
    log("Unable to write pid (%d) to %s" % (os.getpid(),pidfile))

#=============================================================================
# Open the log file.
#=============================================================================
def openlog(filename):
  """Route server output to the log file `filename`.

  Unless foreground_flag is set, fds 1 and 2 (stdout/stderr) are pointed at
  the log file and fd 0 (stdin) at /dev/null.  If that fails, or when
  running in the foreground, ``logfile`` is sys.stderr and foreground_flag
  is set.
  """
  global logfile
  global foreground_flag

  if not foreground_flag:
    try:
      logfile = open(filename,"a+")
      # dup2 atomically replaces fds 1 and 2 with the log file.
      os.dup2(logfile.fileno(), 1)
      os.dup2(logfile.fileno(), 2)
      # Point fd 0 at /dev/null.  Use a raw descriptor rather than a file
      # object: a file object would be closed (taking fd 0 with it) as soon
      # as it was garbage-collected on return.
      devnull = os.open("/dev/null", os.O_RDONLY)
      if devnull != 0:
        os.dup2(devnull, 0)
        os.close(devnull)
      return
    except (IOError, OSError):
      pass

  if not foreground_flag:
    os.write(2,"Logfile open failed.  Remaining in foreground.\n")
  logfile = sys.stderr
  foreground_flag = 1

#=============================================================================
# Log a message.
#=============================================================================
def log(msg):
  """Write one timestamped line to the log.

  The prefix is the jobid once known, else the client address (remote_ip),
  else "Startup".  Output goes to ``logfile`` (flushed) or, before openlog()
  has run, to sys.stderr.  I/O errors while logging are ignored.
  """
  try:
    stamp = "[" + time.asctime() + "] "
    if jobid:
      line = stamp + str(jobid) + ": " + msg
    elif remote_ip:
      line = stamp + remote_ip + ": " + msg
    else:
      line = stamp + "Startup: " + msg

    sink = logfile if logfile else sys.stderr
    sink.write(line + "\n")
    if logfile:
      sink.flush()
  except IOError:
    pass

def debug(msg):
  """Log `msg`, but only when debug_flag is set."""
  if not debug_flag:
    return
  log(msg)

#=============================================================================
# MySQL helpers
#=============================================================================
def mysql(c,cmd):
  """Execute a row-returning statement and return all rows.

  Args:
    c: a MySQLdb cursor.
    cmd: complete SQL text (the caller is responsible for quoting).
  Returns:
    c.fetchall(), or () if the statement failed (the failure is logged).
  """
  try:
    c.execute(cmd)
    return c.fetchall()
  except MySQLdb.MySQLError:
    err = sys.exc_info()[1]
    # MySQLdb errors usually carry (errno, message), but not all of them do;
    # unpacking them blindly raised a second exception inside the handler.
    if len(err.args) >= 2:
      expl = err.args[1]
    else:
      expl = str(err)
    log("%s.  SQL was: %s" % (expl,cmd))
    return ()
  except Exception:
    log("Some other MySQL exception.")
    return ()

def mysql_act(c,cmd):
  """Execute a statement that returns no rows (INSERT/UPDATE/...).

  Args:
    c: a MySQLdb cursor.
    cmd: complete SQL text (the caller is responsible for quoting).
  Returns:
    "" on success, otherwise a non-empty error description.  Callers compare
    the result with "", so a failure must never be reported as "".
  """
  try:
    c.execute(cmd)
    return ""
  except MySQLdb.MySQLError:
    err = sys.exc_info()[1]
    # MySQLdb errors usually carry (errno, message), but not all of them do.
    if len(err.args) >= 2 and err.args[1]:
      return err.args[1]
    return str(err) or err.__class__.__name__

#=============================================================================
# LDAP helpers
#=============================================================================
def ldap_authenticate(login,pw):
  """Authenticate `login`/`pw` against the configured LDAP servers.

  Servers in ldap_hosts are tried in order; one that cannot be reached or
  queried is skipped in favour of the next.

  Returns:
    (1, session_limit) when the bind succeeds and the user's entry has a
    jobsAllowed attribute (session_limit is its integer value), else (0, 0).
  """
  if login == None or pw == None:
    return (0,0)
  # With an empty password most LDAP servers perform an "unauthenticated"
  # (anonymous) bind that succeeds, so it must never count as a login.
  if login == "" or pw == "":
    log("Refusing LDAP authentication with an empty login or password")
    return (0,0)

  # Check the following note for more info on ldaps...
  # http://sourceforge.net/mailarchive/forum.php?forum_id=4346&max_rows=25&style=flat&viewmonth=200508
  for ldap_host in ldap_hosts:
    try:
      if ldap_host.startswith("ldaps://"):
        ldap.set_option(ldap.OPT_X_TLS_REQUIRE_CERT,ldap.OPT_X_TLS_ALLOW)
        l = ldap.initialize(ldap_host)
      elif ldap_host.startswith("ldap://"):
        l = ldap.initialize(ldap_host)
      else:
        l = ldap.open(ldap_host)
    except ldap.LDAPError:
      log("%s: %s" % (ldap_host,sys.exc_info()[1]))
      continue

    try:
      result = _ldap_check_user(l, ldap_host, login, pw)
    finally:
      _ldap_unbind(l)
    if result is not None:
      return result

  log("Unable to connect to any LDAP server")
  return (0,0)

def _ldap_check_user(l, ldap_host, login, pw):
  """Bind as `login` on connection `l` and read its jobsAllowed limit.

  Returns (1, limit) on success, (0, 0) on a definitive failure, or None
  when this server could not be used and the next one should be tried.
  """
  import ldap.dn
  import ldap.filter

  try:
    # simple_bind_s() is synchronous and raises on failure.  The
    # asynchronous simple_bind() only returns a message id, so a rejected
    # password could never be detected from its return value.
    l.simple_bind_s(ldap_user_dn % ldap.dn.escape_dn_chars(login), pw)
  except ldap.INVALID_CREDENTIALS:
    log("Unable to bind to LDAP server as user '%s'" % login)
    return (0,0)
  except ldap.LDAPError:
    log("%s: %s" % (ldap_host,sys.exc_info()[1]))
    return None

  try:
    # Escape the client-supplied login so it cannot alter the filter.
    arr = l.search_s(ldap_basedn, ldap.SCOPE_SUBTREE,
                     "uid=%s" % ldap.filter.escape_filter_chars(login))
  except ldap.LDAPError:
    log("%s: %s" % (ldap_host,sys.exc_info()[1]))
    return None

  if len(arr) != 1:
    log("ldap_authenticate: Non-singular result (%d) from LDAP for %s" % (len(arr),login))
    return (0,0)
  (dn,attrs) = arr[0]
  if 'jobsAllowed' in attrs:
    session_limit = int(attrs['jobsAllowed'][0])
    return (1,session_limit)
  log("Can't get LDAP session_limit")
  return (0,0)

def _ldap_unbind(l):
  """Close LDAP connection `l`, ignoring LDAP errors."""
  try:
    l.unbind_s()
  except ldap.LDAPError:
    pass


#=============================================================================
# Things to do file/name mapping.
#=============================================================================
# Two-way maps between open file/socket objects and their stream names
# (e.g. "!stdin", "!control", or a file path).
nametofile = {}
filetoname = {}

def mapfile(file,name):
  """Register `file` under `name` in both nametofile and filetoname."""
  nametofile[name], filetoname[file] = file, name

#=============================================================================
# Functions to deal with reader/writer queues.
#=============================================================================
# Output queued for the client is kept in cpbuf until the client acknowledges
# it (see advance()), so it can be re-sent by replay().  cpoffset is the
# absolute stream offset of cpbuf[0].
cpbuf=""
cpoffset=0
def checkpoint(newstr):
  """Append `newstr` to the unacknowledged-output buffer ``cpbuf``."""
  global cpbuf
  cpbuf = cpbuf + newstr

def advance(offset):
  """Forget checkpointed output before absolute stream offset `offset`.

  Called when the client acknowledges data ("ack" command).  Offsets
  outside [cpoffset, cpoffset + len(cpbuf)] are logged and ignored.
  """
  global cpbuf
  global cpoffset

  if offset < cpoffset:
    log("Big advance problem: offset(%d) < cpoffset(%d)" % (offset,cpoffset))
    return
  if offset > cpoffset + len(cpbuf):
    log("Big advance problem: offset(%d) > cpoffset(%d) + len(%d)" % (offset, cpoffset, len(cpbuf)))
    return

  cpbuf = cpbuf[offset - cpoffset:]
  cpoffset = offset

def replay(offset):
  """Re-send all checkpointed output from absolute stream offset `offset` on.

  Output before `offset` is discarded first.  The rest is queued as a
  "replay <offset> <length>" header followed by the data; neither write is
  checkpointed again.  Out-of-range offsets are only logged.
  """
  global cpbuf
  global cpoffset

  if offset < cpoffset:
    log("Big replay problem: offset(%d) < cpoffset(%d)" % (offset,cpoffset))
    return
  if offset > cpoffset + len(cpbuf):
    log("Big replay problem: offset(%d) > cpoffset(%d) + len(%d)" % (offset, cpoffset, len(cpbuf)))
    return

  cpbuf = cpbuf[offset - cpoffset:]
  cpoffset = offset
  header = "replay %d %d\n" % (cpoffset,len(cpbuf))
  # cpt=0: these writes are already in the checkpoint buffer.
  queueCommand(header, 0)
  queueCommand(cpbuf, 0)

def queueWrite(x, str):
  """Append `str` to the pending output for file/socket `x` in ``writers``.

  `str` keeps its historical name (it shadows the builtin) so keyword
  callers are not broken.
  """
  # dict.get replaces the deprecated has_key() lookup.
  writers[x] = writers.get(x, "") + str

def queueCommand(msg, cpt=1):
  """Queue protocol text `msg` for the client socket.

  When the socket is gone the message is only logged.  Unless `cpt` is
  false, `msg` is also checkpointed so it can be replayed later.
  """
  if not sock:
    log("socket closed.  can't send '%s'" % msg)
  else:
    queueWrite(sock, msg)

  if cpt:
    checkpoint(msg)

def sendMessage(msg):
  """Send `msg` to the client as a "message <len>" header plus the text."""
  header = "message %d\n" % len(msg)
  for chunk in (header, msg):
    queueCommand(chunk)

def deleteReader(x):
  """Stop watching `x` for input; a no-op if it is not in ``readers``."""
  # pop() with a default replaces the deprecated has_key() check.
  readers.pop(x, None)

def deleteWriter(x):
  """Drop any pending output for `x`; a no-op if it is not in ``writers``."""
  # pop() with a default replaces the deprecated has_key() check.
  writers.pop(x, None)

#=============================================================================
# Kill the child process
#=============================================================================
def killChild():
  """Escalate SIGINT, then SIGTERM, then SIGKILL to the child's process group.

  Waits 120 s after SIGINT and 60 s after SIGTERM.  The first failing step
  (e.g. the group no longer exists) ends the sequence silently.  Does
  nothing when there is no child (child_pid == 0).
  """
  if child_pid == 0:
    return
  escalation = ((signal.SIGINT, 120), (signal.SIGTERM, 60), (signal.SIGKILL, 0))
  try:
    for (signum, grace) in escalation:
      # Negative pid: signal the whole process group (the child ran setsid()).
      os.kill(-child_pid, signum)
      if grace:
        time.sleep(grace)
  except:
    pass

#=============================================================================
# Invoked when everything exits.
#=============================================================================
def serverExit():
  """Terminate the server process immediately with os._exit(0).

  If a jobid is known, the control socket is removed first.  When
  ready_to_exit is set, the child process group is also signalled
  (killChild()), the private work directory is removed, and a final
  "exit <status>" / "ackexit" pair is sent to the client before its socket
  is closed.
  """
  if jobid:
    deleteControlPort()

  if ready_to_exit:
    killChild()

    unsetupWorkDirectory()

    try:
      sock.send("exit %d\n" % exit_status)
      sock.send("ackexit\n")
      sock.close()
    except:
      # Best effort: sock may already be None (closeSocket()) or the peer gone.
      pass
  log("Server exiting.")
  # Really get out without doing work to close anything.
  os._exit(0)

#=============================================================================
# Close the socket.  If nothing else is going on, exit.
#=============================================================================
def closeSocket(errors=None):
  """Close the client socket and exit if there is nothing left to do.

  `errors` describes why the connection ended; it is reported as an
  unexpected departure unless the server was already ready to exit.  The
  server exits afterwards when ready_to_exit is set or no command was ever
  started.
  """
  global sock

  if not errors or ready_to_exit:
    log("Client left politely.")
  else:
    log("Client left unexpectedly: '%s'" % errors)

  closeFile(sock)
  sock = None

  if command_established and not ready_to_exit:
    return
  debug("closeSocket() invoking serverExit()")
  serverExit()

#=============================================================================
# Notify the client that the child process has exited.
#=============================================================================
def childExit():
  """Forget the exited child and close the client connection, if any."""
  global child_pid
  log("Doing child exit")
  child_pid = 0
  if not sock:
    return
  log("childExit() closeSocket()")
  closeSocket()

#=============================================================================
# On a termination signal (HUP, INT, QUIT, TERM), try to exit cleanly;
# SIGCHLD only records that the child has exited.
#=============================================================================
def handleSignal(sig, frame):
  """Handler for SIGHUP/SIGINT/SIGQUIT/SIGTERM (installed by setupSignals()).

  Records a synthetic exit status of 65534 for the job, calls finalizeJob(),
  then exits through serverExit() with ready_to_exit set so that the child
  is killed and the work directory removed.
  """
  global ready_to_exit
  log("Server was terminated by a signal.")
  parseExitStatus("venue=%s status=65534" % venue)
  finalizeJob()
  ready_to_exit = 1
  debug("handleSignal() invoking serverExit()")
  serverExit()

def handleChild(sig, frame):
  """SIGCHLD handler: only raises the child_has_exited flag.

  NOTE(review): the flag is presumably polled by the main loop, which is not
  part of this excerpt.
  """
  global child_has_exited
  debug("child_has_exited")
  child_has_exited = 1

def setupSignals():
  """Install handleSignal for HUP/INT/QUIT/TERM and handleChild for CHLD."""
  for signum in (signal.SIGHUP, signal.SIGINT, signal.SIGQUIT, signal.SIGTERM):
    signal.signal(signum, handleSignal)
  signal.signal(signal.SIGCHLD, handleChild)

def unsetupSignals():
  """Ignore HUP, INT, QUIT, TERM and CHLD from now on."""
  for signum in (signal.SIGHUP, signal.SIGINT, signal.SIGQUIT,
                 signal.SIGTERM, signal.SIGCHLD):
    signal.signal(signum, signal.SIG_IGN)

#=============================================================================
# Set up a control socket and listen to it.
#=============================================================================
def setupControlPort():
  """Create and listen on the UNIX control socket /tmp/control<jobid>.

  A stale file or symlink left at that path is removed first.  On success
  the socket is stored in ``control`` and registered in ``readers`` under
  the stream name "!control".  Failures are logged, not raised.
  """
  global control

  control_name = "/tmp/control%d" % jobid
  # unlink() alone suffices (a missing file is fine).  The previous
  # stat()-then-unlink() skipped dangling symlinks, leaving bind() to fail.
  try:
    os.unlink(control_name)
  except OSError:
    pass

  try:
    control = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    control.bind(control_name)
    control.listen(1)
    mapfile(control,"!control")
    readers[control] = "!control"
    debug("Control port is set up.")
  except (socket.error, OSError):
    log("ERROR: Failed to set up control port: %s" % sys.exc_info()[1])

def deleteControlPort():
  """Close the control socket and remove /tmp/control<jobid>.

  Safe to call repeatedly: any error while removing the path (including a
  missing file or an unset jobid) is ignored.
  """
  global control
  closeFile(control)
  control = None
  try:
    sockpath = "/tmp/control%d" % jobid
    os.unlink(sockpath)
    debug("Unlinked control port")
  except:
    pass

#=============================================================================
# Parse the argument array.
#=============================================================================
def parseArgs():
  """Parse the client-supplied argument list (global ``args``) into job settings.

  args[0] is skipped (NOTE(review): presumably the client program name).
  Options update the globals ncpus, venue, wait_for_low_load, print_metrics,
  help_flag and input_files.  -e, -m, -w and -r are consumed together with
  their value; -N's value is parsed as an int but otherwise unused.  The
  first non-option argument and everything after it become
  ``command_args``, and ``event`` becomes "/" plus the command's last path
  component.

  A user-supplied -j/--jobid and the value after it are deleted from
  ``args``; invokeChild() inserts the real job id itself.

  Also makes sure the global ``db`` connection is open.
  """
  global command_args
  global ncpus
  global venue
  global event
  global wait_for_low_load
  global print_metrics
  global help_flag
  global db
  if db == None:
    db = db_connect()

  help_flag=0
  ncpus=1
  ppn=1

  log("Args are:" + str(args))

  i=0
  try:
    while i+1 < len(args):
      i=i+1
      if args[i] == "-e":
        i=i+1
      elif args[i] == "-h" or args[i] == "--help":
        help_flag=1
      elif args[i] == "-i":
        i=i+1
        input_files[args[i]] = 0
      elif args[i] == "-j" or args[i] == "--jobid":
        log("User specified --jobid.  Deleted.")
        del args[i]
        if i < len(args):
          del args[i] # And delete whatever follows it.
        # The next unparsed argument now sits at index i; step back so the
        # increment at the top of the loop does not skip it.
        i=i-1
      elif args[i] == "-l" or args[i] == "--local":
        venue = ""
      elif args[i] == "-M" or args[i] == "--metrics":
        print_metrics=1
      elif args[i] == "-m" or args[i] == "--manager":
        i=i+1
      elif args[i] == "-N" or args[i] == "--ppn":
        i=i+1
        ppn = int(args[i])
      elif args[i] == "-n" or args[i].lower() == "--ncpus":
        i=i+1
        ncpus = int(args[i])
      elif args[i] == "-v" or args[i] == "--venue":
        i=i+1
        venue = args[i]
      elif args[i] == "-W" or args[i] == "--wait":
        wait_for_low_load=1
      elif args[i] == "-w" or args[i] == "--walltime":
        i=i+1
      elif args[i] == "-r" or args[i] == "--redundancy":
        i=i+1
      elif args[i].startswith("-"):
        log("parseArgs: Bad argument: " + args[i])
      else:
        command_args = args[i:]
        if help_flag:
          command_args = ['-h'] + command_args
        # Keep only the last path component: "/usr/bin/tool" -> "tool".
        event = command_args[0].split('/')[-1]
        #
        # Note: If this is a well-known executable, we should not prepend
        # the '/' to it.  Someday, we'll fix that.
        #
        event = "/" + event
        break
  except IndexError:
    # An option that needs a value was the last argument.
    log("parseArgs: Bad argument index: " + str(i))
  except Exception:
    # e.g. a non-numeric -n/-N value.
    log("parseArgs: Bad argument: '%s'" % str(args[i]))

#=============================================================================
# Determine whether the filesystem is shared by comparing the inode number
# of the resources file.
#=============================================================================
def checkFilesystem():
  """Set shared_filesystem when testfile_file has the inode the client reported.

  The client supplies testfile_inode/testfile_file with its "testfile"
  command.  A matching inode here is taken to mean both sides share the
  filesystem.  Any error (missing file, unset path) leaves the flag alone.
  """
  global shared_filesystem
  try:
    same_inode = os.lstat(testfile_file).st_ino == testfile_inode
    if not same_inode:
      log("The filesystem is NOT shared.")
      return
    shared_filesystem = 1
    log("The filesystem is shared.")
  except:
    pass

#=============================================================================
# Move any files if the filesystem is not shared.
#=============================================================================
def moveInputFiles():
  """Ask the client to upload each input file when the filesystem is not shared.

  For every name in input_files, a local file with the same basename is
  opened for writing in the current directory.  It is stored in
  input_files, mapped under the client's name, and an "importfile <name>"
  request is queued.  Files that cannot be opened are reported to the
  client with a message instead.
  """
  if shared_filesystem:
    return

  for name in input_files.keys():
    local_name = os.path.basename(name)
    try:
      handle = open(local_name,"w")
      input_files[name] = handle
      mapfile(handle, name)
    except:
      sendMessage("Can't open file to transfer: %s" % name)
      continue
    log("importfile %s" % name)
    queueCommand("importfile %s\n" % name)

def setupWorkDirectory():
  global workdir
  if shared_filesystem:
    os.chdir(workdir)
  else:
    workdir = "/tmp/dir_%s_%d" % (username,jobid)
    os.mkdir(workdir, 0700)
    os.chdir(workdir)

#=============================================================================
# Delete the working directory if we've created one.
# NOTA BENE: We must ONLY delete a specific directory that we've created.
#=============================================================================
def unsetupWorkDirectory():
  """Remove exactly this job's /tmp/dir_<username>_<jobid> directory.

  The directory is removed with shutil.rmtree rather than a shell command,
  because username comes from the client ("username" command).  Through a
  shell it allowed command injection and glob expansion (e.g. a username of
  "*" would match other users' directories).  Usernames containing "/" are
  refused so the path cannot leave /tmp.  Removal errors, including a
  missing directory, are ignored, as with "rm -rf".
  """
  import shutil

  if username == None:
    log("ERROR: unsetupWorkDirectory does not know the username.")
    return
  if jobid == None:
    log("ERROR: unsetupWorkDirectory does not know the jobid.")
    return
  if "/" in str(username):
    log("ERROR: unsetupWorkDirectory refusing unsafe username '%s'." % username)
    return
  dir = "/tmp/dir_%s_%d" % (username,jobid)
  debug("unsetupWorkDirectory '%s'" % dir)
  # Step out of the directory before deleting it.
  os.chdir("/tmp")
  shutil.rmtree(dir, ignore_errors=True)
  debug("removed=%d" % (not os.path.lexists(dir)))

#=============================================================================
# Find all of the files newer than the command start timestamp.
# Treat them as output files.
#=============================================================================
def findOutputFiles():
  """Declare every regular file under "." changed after starttime as output.

  Each such file (symlinks excluded) is announced to the client with
  "declare_output_file <path>" and recorded in output_files.  Files that
  cannot be examined (e.g. vanished during the walk) are skipped.
  """
  for (dirpath, _subdirs, names) in os.walk('.'):
    for path in [os.path.join(dirpath, n) for n in names]:
      try:
        if not os.path.getctime(path) > starttime:
          continue
        if not os.path.isfile(path) or os.path.islink(path):
          continue
        queueCommand("declare_output_file %s\n" % path)
        output_files[path] = 0
      except:
        pass

#=============================================================================
# Filter any bad environment variables out.
# For now, we only pass SUBMIT_JOB, SUBMITVENUES and *_CHOICE
# and set HOME, USER, LOGNAME and PATH to something sane.
#=============================================================================
def filterEnvironment():
  """Reduce the client-supplied ``envs`` to a whitelist for the job.

  Keeps SUBMIT_JOB, SUBMITVENUES and any name ending in "_CHOICE"; names
  must start with a letter, and empty names are skipped.  Then HOME, USER,
  LOGNAME and PATH are set to server-chosen values.
  """
  global envs
  newenvs={}

  for name in envs:
    name=str(name) # Just in case the name is not a string.
    # name[:1] rather than name[0]: an empty name sent by the client must be
    # skipped, not raise IndexError and abort the job start.
    if not name[:1].isalpha():
      continue
    if name in ("SUBMIT_JOB", "SUBMITVENUES") or name.endswith("_CHOICE"):
      newenvs[name] = envs[name]

  newenvs['HOME'] = homedir
  newenvs['USER'] = username
  newenvs['LOGNAME'] = username
  newenvs['PATH'] = "/usr/bin:/bin"
  envs = newenvs

#=============================================================================
# Invoke the child command.
#=============================================================================
def invokeChild():
  """Fork and exec <configdir>/distributor for the current job.

  Closes the database connection, inserts "-j <jobid>" after args[0] and
  records starttime.  The child starts a new session and execs the
  distributor with fds 0-3 attached to pipes; the parent maps their other
  ends as "!stdin", "!stdout", "!stderr" and "!status".  The parent then
  watches !stdout and !stderr, sets up the control port and heartbeat
  (unless help_flag is set), and tells the client "server_ready_for_io".
  """
  global db
  global child_pid
  global starttime

  # At this point we are done with the database.
  if db != None:
    db.close()
    db = None

  distributor = os.path.join(configdir, "distributor")
  args.insert(1, "-j")
  args.insert(2, str(jobid))

  starttime = time.time()

  inp = os.pipe()
  outp = os.pipe()
  errp = os.pipe()
  statp = os.pipe()
  child_pid = os.fork()
  if child_pid == 0:
    # Child: it must never fall back into the server's own code.  execvpe
    # raises OSError on failure (it does not return), so the error message
    # is written from the handler, and every path ends in os._exit(), which
    # also avoids re-flushing stdio buffers copied from the parent.
    try:
      try:
        os.setsid()
        os.dup2(inp[0], 0)
        os.dup2(outp[1], 1)
        os.dup2(errp[1], 2)
        os.dup2(statp[1], 3)
        os.close(inp[1])
        os.close(outp[0])
        os.close(errp[0])
        os.close(statp[0])
        os.execvpe(distributor, args, envs)
      except OSError:
        os.write(2,"Cannot invoke distributor.")
    finally:
      os._exit(1)
  else:
    mapfile( os.fdopen(inp[1],'w',0), "!stdin" )
    mapfile( os.fdopen(outp[0],'r',0), "!stdout" )
    mapfile( os.fdopen(errp[0],'r',0), "!stderr" )
    mapfile( os.fdopen(statp[0],'r',0), "!status" )
    os.close(inp[0])
    os.close(outp[1])
    os.close(errp[1])
    os.close(statp[1])
    readers[nametofile["!stdout"]] = "!stdout"
    readers[nametofile["!stderr"]] = "!stderr"

    if not help_flag:
      setupControlPort()

    # At this point, the I/O is ready to receive things from the client.
    # Inform the client.
    queueCommand("server_ready_for_io\n")
    if not help_flag:
      scheduleHeartbeat()


def invokeLocal():
  """Start a local-mode command: no child is forked here.

  Marks the command as established, tells the client the server is ready
  for I/O, and starts the heartbeat.
  """
  global command_established
  command_established = 1
  ready_msg = "server_ready_for_io\n"
  queueCommand(ready_msg)
  scheduleHeartbeat()

def startCommand():
  """Drop privileges and launch the user's command.

  When running as root, switch to the job's gid, supplementary groups and
  uid first.  With help_flag set the distributor is invoked directly.
  Otherwise the umask is applied, filesystem sharing is detected, the work
  directory is prepared and the environment filtered.  The child is then
  started right away, or, when input files must first be copied to a
  non-shared filesystem, after moveInputFiles() has requested them.
  """
  global command_established
  command_established = 1

  debug("startCommand")

  if os.getuid() == 0:
    # Groups must be changed while we still have root privileges.
    os.setregid(gid,gid)
    os.setgroups(gidlist)
    os.setreuid(uid,uid)

  if help_flag:
    invokeChild()
    return

  if command_args == None:
    startInteractive()

  os.umask(umask)
  checkFilesystem()
  setupWorkDirectory()
  filterEnvironment()

  if input_files and not shared_filesystem:
    # When the input_files have all moved, invokeChild() will be called.
    moveInputFiles()
  else:
    invokeChild()


def startInteractive():
  """Placeholder for interactive sessions: log the fact and exit with status 1."""
  # Don't really know what to do with this yet...
  log("startInteractive: Don't know what to do")
  raise SystemExit(1)

#=============================================================================
# Parse statistics
#=============================================================================
def _mysqlActParams(c, cmd, params):
  """Run `cmd` with `params` quoted by MySQLdb; return "" or an error text."""
  try:
    c.execute(cmd, params)
    return ""
  except MySQLdb.MySQLError:
    err = sys.exc_info()[1]
    if len(err.args) >= 2 and err.args[1]:
      return err.args[1]
    return str(err) or err.__class__.__name__

def parseExitStatus(status_line):
  """Record a finished job's exit statistics in the ``joblog`` table.

  Args:
    status_line: whitespace-separated ``name=value`` items; recognized
      names are status, cputime, realtime, waittime, ncpus and venue.
      Unknown or malformed items are logged and skipped.

  Sets the globals exit_status and venue from the line.  If print_metrics
  is set, a "=SUBMIT-METRICS=>" line is queued for the client's stdout.
  Inserts a "[waiting]" joblog row when waittime > 0, plus the main joblog
  row; both are copied from the job's ``job`` row.  Then marks
  stats_recorded and closes the database connection.  Does nothing when
  help_flag is set or the stats were already recorded.

  venue and event originate from the client, so they are passed to MySQLdb
  as query parameters instead of being pasted into the SQL text.
  """
  global stats_recorded
  global exit_status
  global venue
  global db

  if help_flag or stats_recorded:
    return

  if db == None:
    db = db_connect()
  c = db.cursor()
  status=0
  waittime=0.0
  cputime=0.0
  realtime=0.0
  ncpus=0

  for arg in status_line.split():
    try:
      (name,value) = arg.split("=")

      if name == "status":
        status = int(value)
        exit_status = status
      elif name == "cputime":
        cputime = float(value)
      elif name == "realtime":
        realtime = float(value)
      elif name == "waittime":
        waittime = float(value)
      elif name == "ncpus":
        ncpus = int(value)
      elif name == "venue":
        venue = value
      else:
        log("Unknown status item: '%s'" % arg)
        continue

    except ValueError:
      log("Erroneous status item (value): '%s'" % arg)
    except Exception:
      log("Erroneous status item: '%s'" % arg)

  metrics = " venue=%s status=%d cpu=%f real=%f wait=%f"%(venue,status,cputime,realtime,waittime)
  log("Job Status:" + metrics)
  if print_metrics and venue != "":
    log("Sending requested metrics.")
    msg="=SUBMIT-METRICS=>"
    msg += " job=%d" % jobid
    msg += metrics
    msg += "\n"
    queueCommand("write %s %d\n" % ("!stdout", len(msg)))
    queueCommand(msg)

  if waittime > 0:
    err = _mysqlActParams(c, """
      INSERT INTO
      joblog(sessnum,job,superjob,event,start,walltime,venue)
      SELECT sessnum,jobid,superjob,%s,start,%s,%s
      FROM job WHERE jobid=%s
      """, ("[waiting]", waittime, venue, jobid))
    if err != "":
      log("ERROR: Unable to create wait time record. (%f)" % waittime)
      log("Error was: %s" % err)

  err = _mysqlActParams(c, """
    INSERT INTO
    joblog(sessnum,job,superjob,event,start,walltime,cputime,ncpus,status,venue)
    SELECT sessnum,jobid,superjob,%s,start,%s,%s,ncpus,%s,%s
    FROM job WHERE jobid=%s
    """, (event, realtime, cputime, status, venue, jobid))
  if err != "":
    log("ERROR: Unable to copy job %d to joblog" % jobid)
    log("Error was: %s" % err)

  stats_recorded = 1

  db.close()
  db = None
  return

#=============================================================================
# Interpret a command from the network.
#=============================================================================
def eat(ilen):
  """Consume the first `ilen` characters of inbuf, advancing inoffset."""
  global inoffset, inbuf
  inbuf, inoffset = inbuf[ilen:], inoffset + ilen

def uneat(s):
  """Push `s` back onto the front of inbuf, rewinding inoffset accordingly."""
  global inoffset, inbuf
  inbuf, inoffset = s + inbuf, inoffset - len(s)

def doCommand(more):
  """Append `more` to the input buffer and execute every complete command.

  Commands are newline-terminated text lines.  The arg, var, read and
  replay lines are followed by a payload whose length appears on the line.
  If the payload has not fully arrived, the line is pushed back with
  uneat() and processing stops until doCommand() is called with more data.

  At the end an "ack <recvoffset> <sendoffset>" is queued when
  ackcount < cmdcount, i.e. roughly when something other than acks was
  processed ("resume" and an incomplete "replay" set cmdcount to -1 to
  avoid it).

  Malformed commands (missing fields, non-numeric values, or an unknown
  stream name in read/close) are logged and skipped rather than raised.
  """
  global inbuf
  global old_inbuf
  global inoffset
  global args
  global argc
  global envs
  global workdir
  global umask
  global token
  global jobid
  global testfile_file
  global testfile_inode
  global username
  global password
  global auth_attempts
  global ready_to_exit
  global superjob
  global authz

  inbuf = inbuf + more

  ackcount=0
  cmdcount=0
  while not inbuf == "":
    cmdcount = cmdcount + 1

    try:
      nl=inbuf.index("\n")
    except ValueError:
      # No complete command line yet.
      break

    arg=inbuf[0:nl].split()
    line = inbuf[0:nl]
    eat(nl+1)

    debug("Parsing '%s'" % line)

    try:
      if arg[0] == "null":
        debug("null")

      elif arg[0] == "signon":
        try:
          authz = signon()
          queueCommand("authz %d\n" % authz)
        except (ValueError,IndexError):
          log("Signon error: %s" % sys.exc_info()[0])

      elif arg[0] == "resume":
        token = arg[1]
        jobid = int(arg[2])
        lastsent = int(arg[3])
        resumeSession(lastsent)
        cmdcount = -1 # Make sure we do not send an "ack"

      elif arg[0] == "arg":
        alen = int(arg[1])
        if len(inbuf) < alen:
          uneat(line + "\n")
          ackcount=0
          break
        arg = inbuf[0:alen]
        eat(alen)
        args.append(arg)
        debug("Arg: %s" % arg)

      elif arg[0] == "var":
        nlen = int(arg[1])
        vlen = int(arg[2])
        if len(inbuf) < nlen + vlen:
          uneat(line + "\n")
          ackcount=0
          break
        name = inbuf[0:nlen]
        eat(nlen)
        value = inbuf[0:vlen]
        eat(vlen)
        envs[name] = value
        if name == "SUBMIT_JOB":
          try:
            superjob = int(value)
          except ValueError:
            pass

      elif arg[0] == "token":
        token = arg[1]
        auth_attempts = auth_attempts + 1

      elif arg[0] == "username":
        username = arg[1]

      elif arg[0] == "password":
        password = arg[1]
        auth_attempts = auth_attempts + 1

      elif arg[0] == "testfile":
        testfile_inode = int(arg[1])
        testfile_file = arg[2]

      elif arg[0] == "pwd":
        workdir = arg[1]

      elif arg[0] == "umask":
        umask = int(arg[1],0)

      elif arg[0] == "startcmd":
        parseArgs()
        checkLoadAndLaunch()

      elif arg[0] == "startlocal":
        parseArgs()
        checkLoadAndLaunch()

      elif arg[0] == "exportfile":
        name = arg[1]
        # Only files previously declared as outputs may be read back.
        if shared_filesystem or name not in output_files:
          log("Improper attempt by client to read %s" % name)
          f = open("/dev/null","r")
        else:
          try:
            f = open(name,"r")
          except (IOError, OSError):
            sendMessage("Server is unable to open '%s' for export." % name)
            f = open("/dev/null","r")
        mapfile(f,name)
        readers[f] = name

      elif arg[0] == "read":
        name = arg[1]
        ilen = int(arg[2])
        if len(inbuf) < ilen:
          uneat(line + "\n")
          ackcount=0
          break
        buf = inbuf[0:ilen]
        eat(ilen)
        queueWrite(nametofile[name], buf)

      elif arg[0] == "ack":
        sendoffset = int(arg[1])
        recvoffset = int(arg[2])
        ackcount = ackcount + 1
        if recvoffset != inoffset - (len(line) + 1):
          log("INBUF IS OUT OF SYNC")
          log("I think recvoffset=%d" % (inoffset - (len(line) + 1)))
          log("However recvoffset=%d" % recvoffset)
        advance(sendoffset)

      elif arg[0] == "replay":
        roffset = int(arg[1])
        rlen = int(arg[2])
        if len(inbuf) < rlen:
          uneat(line + "\n")
          cmdcount = -1 # Make sure we don't send an "ack"
          break
        else:
          buf = inbuf[0:rlen]
          eat(rlen)
          old_inbuf = old_inbuf + buf + inbuf
          inbuf = old_inbuf
          inoffset = roffset
          old_inbuf = ""

      elif arg[0] == "close":
        name = arg[1]
        writeclose[nametofile[name]] = 1

      elif arg[0] == "exit":
        ready_to_exit = 1
        if command_established:
          unsetupWorkDirectory()
          parseExitStatus("venue=%s status=65534" % venue)
          finalizeJob()
        queueCommand("ackexit\n")

      elif arg[0] == "ackexit":
        debug("ackexit")
        queueCommand("ackexit\n")
        log("ackexit closeSocket()")
        closeSocket()

      elif arg[0] == "localexit":
        status = int(arg[1])
        cputime = float(arg[2])
        realtime = float(arg[3])
        debug("localexit %d %f %f" % (status, cputime, realtime))
        parseExitStatus("venue=%s status=%d cputime=%f realtime=%f" % (venue,status,cputime,realtime))
        finalizeJob()
        ready_to_exit = 1
        queueCommand("ackexit\n")

      else:
        log("Unrecognized server command: '%s'" % line)

    except IndexError:
      log("Improper server command (index): '%s'" % line)

    except ValueError:
      log("Improper server command (value): '%s'" % line)

    except KeyError:
      # e.g. "read"/"close" naming a stream that was never mapped; the
      # client controls the name, so this must not take the server down.
      log("Improper server command (unknown name): '%s'" % line)

  if ackcount < cmdcount:
    recvoffset = inoffset + len(inbuf)
    sendoffset = cpoffset + len(cpbuf)
    debug("queueing ack %d %d" % (recvoffset,sendoffset))
    queueCommand("ack %d %d\n" % (recvoffset,sendoffset))


#=============================================================================
# Take over a session that's already running.
#=============================================================================
def resumeSession(lastsent):
  """Take over a session that another server process is already running.

  Connects to the original server's control socket /tmp/control<jobid>
  (one attempt per second, at most 10), sends "getinfo" and loads the reply
  into this process's globals: start time, child pid, working directory
  (also chdir'd into), argument list, gid/uid (applied when running as
  root), username, checkpoint buffer, buffered client input and the
  shared_filesystem flag.  On "readytoquit" it sends "finish", receives the
  stdin/stdout/stderr/status descriptors through passfd, opens this
  server's own control port, tells the client where to resume, calls
  replay(lastsent), re-enables stdout/stderr reading, re-parses the
  arguments and starts the heartbeat.

  lastsent -- offset handed to replay(); presumably the last output offset
              the client has seen (TODO confirm against replay()).

  Returns None when the takeover completes, or after serverExit() when the
  original server cannot be reached; returns 0 if the original server
  closes the control connection before sending "readytoquit".
  """
  global args
  global starttime
  global child_pid
  global workdir
  global username
  global uid
  global gid
  global inbuf
  global inoffset
  global old_inbuf
  global old_inoffset
  global cpbuf
  global cpoffset
  global control
  global ready_to_exit
  global command_established
  # Set from the "shared_filesystem" control line below; without this
  # declaration the value would only land in a local variable.
  global shared_filesystem

  log("resumeSession")

  # Try this 10 times before failing.
  for retry in range(0,10):
    attempt = None
    try:
      attempt = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
      attempt.connect("/tmp/control%d" % jobid)
    except:
      # Don't leak one descriptor per failed attempt.
      if attempt is not None:
        attempt.close()
      time.sleep(1)
    else:
      control = attempt
      break
  else:
    # Runs only when no attempt reached "break".
    # Failure.  No way to contact original server.  (Maybe it doesn't exist?)
    debug("resumeSession: Can't reconnect to original server")
    serverExit()
    return

  os.write(control.fileno(),"getinfo\n")
  cbuf = ""
  while 1:
    s = os.read(control.fileno(),2048)
    if s == "":
      # Original server hung up before "readytoquit".
      return 0
    cbuf = cbuf + s

    # Consume every complete line.  Lines announcing a payload put
    # themselves back and wait for more data when the payload is short.
    while cbuf != "":
      try:
        nl=cbuf.index("\n")
      except ValueError:
        break

      line = cbuf[0:nl]
      cbuf = cbuf[nl+1:]
      arg = line.split(" ")

      debug("resumeSession control line: %s" % line)

      # NOTE(review): handleController also sends "umask %d", which no
      # branch below handles; confirm whether it should be applied.
      if arg[0] == "starttime":
        starttime = float(arg[1])

      elif arg[0] == "childpid":
        child_pid = int(arg[1])

      elif arg[0] == "pwd":
        workdir = arg[1]
        os.chdir(workdir)

      elif arg[0] == "arg":
        alen = int(arg[1])
        if len(cbuf) < alen:
          cbuf = line + "\n" + cbuf
          break
        a = cbuf[0:alen]
        cbuf = cbuf[alen:]
        args.append(a)

      elif arg[0] == "gid":
        gid = int(arg[1])
        if os.getuid() == 0:
          os.setregid(gid,gid)

      elif arg[0] == "uid":
        uid = int(arg[1])
        if os.getuid() == 0:
          os.setreuid(uid,uid)

      elif arg[0] == "username":
        username = arg[1]

      elif arg[0] == "checkpoint":
        cpoffset = int(arg[1])
        clen = int(arg[2])
        if len(cbuf) < clen:
          cbuf = line + "\n" + cbuf
          log("Not enough to read")
          break
        else:
          cpbuf = cbuf[0:clen]
          cbuf = cbuf[clen:]

      elif arg[0] == "inbuf":
        old_inoffset = int(arg[1])
        ilen = int(arg[2])
        if len(cbuf) < ilen:
          cbuf = line + "\n" + cbuf
          break
        else:
          old_inbuf = cbuf[0:ilen]
          cbuf = cbuf[ilen:]

      elif arg[0] == "shared_filesystem":
        shared_filesystem = int(arg[1])

      elif arg[0] == "readytoquit":
        cfd = control.fileno()
        os.write(cfd, "finish\n")
        for name in [ "!stdin", "!stdout", "!stderr", "!status" ]:
          fd=passfd.recvfd(cfd)
          if name == "!stdin":
            f = os.fdopen(fd,"w",0)
          else:
            f = os.fdopen(fd,"r",0)
          mapfile(f,name)
        log("resumeSession: received all of the descriptors")

        setupControlPort()

        lastreceived = old_inoffset + len(old_inbuf)
        # Don't checkpoint this write:
        queueCommand("resume %d\n" % lastreceived, 0)
        replay(lastsent)
        # Now that checkpoint is restored, enable stdout and stderr
        for name in [ "!stdout", "!stderr" ]:
          readers[nametofile[name]] = name

        # Find out about output_files.
        parseArgs()
        command_established = 1
        scheduleHeartbeat()
        log("resumeSession: complete")
        return

#=============================================================================
# Synchronously handle a connection from a new controller.
#=============================================================================
def handleController(c):
  """Serve one connection accepted on this server's control socket.

  Runs synchronously, reading newline-terminated commands from `c` until the
  peer closes it:
    abort   -- kill the child if there is one, otherwise serverExit().
    getinfo -- send this server's state (start time, child pid, umask,
               shared_filesystem, args, gid/uid, username, working
               directory, checkpoint buffer, buffered input), then
               "readytoquit".
    finish  -- pass the stdin/stdout/stderr/status descriptors over `c`
               (a /dev/null descriptor stands in for any that cannot be
               duplicated) and exit this process.
  Any other command is logged.
  """
  cfd = c.fileno()
  cbuf = ""
  while 1:
    ret = os.read(cfd,4096)
    if ret == "":
      c.close()
      return
    cbuf = cbuf + ret

    # One read may carry several commands; handle every complete line
    # before blocking in os.read() again.
    while "\n" in cbuf:
      line, cbuf = cbuf.split("\n", 1)

      arg=line.split(" ")
      if arg[0] == "abort":
        if child_pid:
          try:
            killChild()
          except:
            pass
          log("Server instructed to kill child %d." % child_pid)
        else:
          debug("received client abort")
          serverExit()
      elif arg[0] == "getinfo":
        os.write(cfd,"starttime %f\n" % starttime)
        os.write(cfd,"childpid %d\n" % child_pid)
        os.write(cfd,"umask %d\n" % umask)
        os.write(cfd,"shared_filesystem %d\n" % shared_filesystem)
        for a in args:
          os.write(cfd,"arg %d\n" % len(a))
          os.write(cfd,a)
        os.write(cfd, "gid %d\n" % gid) # must set gid before uid.
        os.write(cfd, "uid %d\n" % uid)
        os.write(cfd, "username %s\n" % username)
        os.write(cfd,"pwd %s\n" % workdir)
        os.write(cfd,"checkpoint %d %d\n" % (cpoffset, len(cpbuf)))
        os.write(cfd,cpbuf)
        os.write(cfd,"inbuf %d %d\n" % (inoffset, len(inbuf)))
        os.write(cfd,inbuf)
        os.write(cfd,"readytoquit\n")
      elif arg[0] == "finish":
        fd=-1
        for name in [ "!stdin", "!stdout", "!stderr", "!status" ]:
          try:
            f = nametofile[name]
            fd=f.fileno()
            fd=os.dup(fd) # easy way to tell if the descriptor is valid
          except:
            # `name` is a string: format it with %s.
            log("Can't dup fd %s (%d)" % (name,fd))
            # "r+" really opens read/write ("rw" would be read-only).
            f=open("/dev/null","r+")
            fd=f.fileno()
          try:
            passfd.sendfd(cfd, fd)
          except:
            log("Unable to pass descriptor %s (%d):" % (name,fd))
            log("Error is: %s" % sys.exc_info()[0])
        c.close()
        sys.exit(0)

      else:
        log("Control doesn't recognize: '%s'" % line)

#=============================================================================
# Tell an existing server to exit.
#=============================================================================
def killServer(id):
  """Ask the server running job `id` to abort, via its control socket.

  Connects to /tmp/control<id> and sends "abort".  The receiving server
  kills its child, or exits if it has none.  If the control socket cannot
  be reached, logs the failure and exits with status 1.
  """
  control = None
  try:
    control = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    # Use the parameter, not the global jobid.
    control.connect("/tmp/control%d" % id)
  except socket.error:
    log("Unable to kill server with jobid %d" % id)
    if control is not None:
      control.close()
    sys.exit(1)

  try:
    os.write(control.fileno(),"abort\n")
  finally:
    control.close()

#=============================================================================
# Start the heartbeat timer.
#=============================================================================
def scheduleHeartbeat():
  """Arm SIGALRM so that heartbeat() runs in heartbeat_interval seconds."""
  # scheduleTimeout() points SIGALRM at handleTimeout, so reinstall the
  # heartbeat handler before arming the alarm.
  seconds = heartbeat_interval
  signal.signal(signal.SIGALRM, heartbeat)
  signal.alarm(seconds)
  debug("Heartbeat scheduled for %d seconds" % seconds)

def heartbeat(sig, frame):
  """SIGALRM handler: stamp this job's heartbeat and re-arm the timer.

  Sets job.heartbeat=now() for `jobid` and re-schedules itself via
  scheduleHeartbeat().  While the client is disconnected (`sock` is None),
  each heartbeat adds heartbeat_interval to disconnect_time.  Once that
  reaches disconnect_max, the job is finalized with status 65534 and
  serverExit() is called.  While connected, the counter is reset to 0.
  The global database connection is closed and cleared on the way out.

  NOTE(review): as a signal handler this can interrupt code that is using
  the shared global `db` and then close it underneath that code; confirm
  no such caller exists, or give the handler its own connection.
  """
  global disconnect_time
  global disconnect_max
  global ready_to_exit
  global db
  if db == None:
    db = db_connect()
  c = db.cursor()

  log("Heartbeat for job %d" % jobid)
  err = mysql_act(c, """UPDATE job SET heartbeat=now()
                        WHERE jobid='%d'
                     """ % jobid)
  if err != "":
    log("ERROR: Heartbeat: %s" % err)
  scheduleHeartbeat()
  if sock == None:
    log("Server is in disconnected state.")
    disconnect_time += heartbeat_interval
    if disconnect_time >= disconnect_max:
      log("disconnect time (%d) > disconnect max (%d)" % (disconnect_time, disconnect_max))
      parseExitStatus("venue=%s status=65534" % venue)
      finalizeJob()
      ready_to_exit = 1
      serverExit()
  else:
    debug("still connected")
    disconnect_time = 0
  # finalizeJob() closes the global connection and sets db to None, so
  # only close it here if it is still open.
  if db is not None:
    db.close()
  db = None

def scheduleTimeout(interval=0):
  """Arm SIGALRM so that handleTimeout() runs after `interval` seconds.

  interval -- seconds until the timeout.  0 (the default) means
              timeout_interval, and -1 cancels any pending alarm.
              Fractional values are rounded up, because signal.alarm()
              takes whole seconds.

  If auth_attempts already exceeds 3, the timeout is handled immediately
  instead of waiting for the alarm.
  """
  if interval == -1:
    signal.alarm(0)
    return
  if interval == 0:
    interval = timeout_interval

  if auth_attempts > 3:
    # Act as if the alarm had already fired: call the SIGALRM handler that
    # is installed below (there is no `timeout` function to call).
    handleTimeout(signal.SIGALRM, None)
  signal.signal(signal.SIGALRM, handleTimeout)
  interval = int(math.ceil(interval))
  signal.alarm(interval)
  log("Server will time out in %d seconds." % interval)
  return

def handleTimeout(sig,frame):
  """SIGALRM handler installed by scheduleTimeout().

  - Not yet authorized (authz == 0): log the failed sign-on and exit.
  - Waiting for the user's load to fall (wait_for_low_load): re-check the
    load and launch the job if it is now low enough.
  - Otherwise: log the unexpected timeout and exit.
  """
  global ready_to_exit
  global authz

  if authz != 0:
    if wait_for_low_load:
      checkLoadAndLaunch()
      return
    log("Unknown reason for timeout.")
    serverExit()
    return

  log("User '%s' failed to authenticate and timed out." % username)
  ready_to_exit=1
  serverExit()

#=============================================================================
# Handle the session authentication.
#=============================================================================
def signon():
  """Authenticate the connected user and switch this process to them.

  Looks up `username` with pwd.getpwnam() (setting uid, gid and homedir),
  collects the user's group ids from /usr/bin/id into gidlist, then looks
  for the session row matching (username, token).  If there is none, it
  calls createSession() with `password` and sends the new session token to
  the client.  When running as root, the process then adopts the user's
  gid, supplementary groups and uid.

  Returns the user's uid on success; 0 on failure (the sign-on timeout is
  re-armed first); -1 when createSession() reports that the session limit
  was exceeded.
  """
  global username
  global password
  global homedir
  global token
  global jobid
  global uid
  global gid
  global gidlist
  global sessnum
  global db
  try:
    from shlex import quote as shell_quote
  except ImportError:
    from pipes import quote as shell_quote

  if db == None:
    db = db_connect()
  c = db.cursor()

  #
  # Check the user with normal Unix conventions.
  #
  try:
    (login,pw,uid,gid,name,homedir,shell) = pwd.getpwnam(username)
  except:
    log("Unable to get info for user '%s@%s'" % (username,remote_ip))
    log("Error is: %s" % sys.exc_info()[0])
    scheduleTimeout()
    return 0

  debug("'%s' is a valid username." % username)

  try:
    # Quote the name for the shell: getpwnam() accepting it does not
    # guarantee it is free of shell metacharacters.
    r=os.popen("/usr/bin/id %s | sed 's/^.* groups=//' | sed 's/([^)]*)//g'" % shell_quote(username))
    try:
      grouptext=r.readline()
    finally:
      r.close()
    grouptext=grouptext.strip()
    arr=grouptext.split(',')
    gidlist=[]
    for g in arr:
      gidlist.append(int(g))
    debug("gidlist is %s" % str(gidlist))
  except:
    # Not fatal...
    log("Unable to get gidlist for user '%s@%s'" % (username,remote_ip))
    log("Error is: %s" % sys.exc_info()[0])

  # The token comes straight from the client: escape it (and the username)
  # before splicing them into SQL.
  arr = mysql(c, """SELECT sessnum FROM session
                    WHERE username='%s' AND sesstoken='%s'
                 """ % (MySQLdb.escape_string(str(username)),
                        MySQLdb.escape_string(str(token))))
  if len(arr) != 0:
    row = arr[0]
    sessnum = int(row[0])
  else:
    arr = createSession(c, username, password, remote_ip, "Submit")
    if len(arr) == 0:
      sendMessage("Unable to create a new session.")
      scheduleTimeout()
      return 0
    elif len(arr) == 2:
      # createSession()'s ['limit','exceeded'] marker.
      return -1
    else:
      row = arr[0]
      sessnum = int(row[0])
      token = row[1]
      queueCommand("token %s\n" % token)

  if os.getuid() == 0:
    # Group changes must happen while still root, i.e. before setreuid().
    os.setregid(gid,gid)
    os.setgroups(gidlist)
    os.setreuid(uid,uid)

  return uid

#=============================================================================
# Check the load.
#=============================================================================
def checkLoad():
  """Decide whether this user's recent job load allows a new job.

  The load is the sum, over the user's jobs started within load_horizon
  seconds, of 0.5 ** (age / load_halflife): each job's weight halves every
  load_halflife seconds.

  Returns 1 when the load is below load_limit.  Otherwise returns 0 after
  doing one of the following:
  - refusing the job (a message plus "jobid 0" to the client) and
    re-arming the timeout;
  - when wait_for_low_load is set, telling the client how long it will
    sleep and scheduling a timeout for roughly the time the load needs to
    decay to load_limit.
  A failed load query also returns 0, after re-arming the timeout.
  """
  global db
  if db is None:
    db = db_connect()
  cursor = db.cursor()

  #
  # Check that the load is low enough
  #
  rows = mysql(cursor,"""SELECT
                   SUM( POW(.5, timestampdiff(second,start,now())/%d) )
                   FROM job WHERE username='%s' AND
                   timestampdiff(second,start,now()) < %d
                """ % (load_halflife,username,load_horizon))

  if len(rows) != 1:
    log("Error retrieving load for user %s" % username)
    scheduleTimeout()
    return 0

  try:
    load = float(rows[0][0])
  except TypeError:
    # SUM() over no matching jobs yields NULL (None).
    load = 0.0

  if load < load_limit:
    log("Cumulative job load is %.2f.  (Max: %.2f)" % (load,load_limit))
    return 1

  log("User %s cannot exceed load limit %d." % (username,load_limit))
  if not wait_for_low_load:
    sendMessage("You cannot exceed your active job limit of %f. (cur: %f)" % (load_limit,load))
    queueCommand("jobid 0\n")
    scheduleTimeout()
    return 0

  # Solve load * 0.5 ** (delay / load_halflife) == load_limit for delay.
  delay = load_halflife * (math.log(load / load_limit) / math.log(2))
  msg = "Cumulative job load is %.2f.  (Max: %.2f)  Sleeping %d seconds." % (load,load_limit,math.ceil(delay))
  log(msg)
  sendMessage(msg)
  scheduleTimeout(delay)
  return 0

#=============================================================================
# Insert the job into the job table.
#=============================================================================
def insertJob():
  """Create the job row for this submission and return its jobid.

  If the requested superjob is not a job owned by this session, logs an
  error and drops it.  Then inserts a row for sessnum, superjob, username,
  event, ncpus and venue, with start and heartbeat set to now(), and
  touches the session's accesstime.

  Returns the new jobid, or 0 if the row could not be created or read back.
  """
  global superjob
  global sessnum
  global username
  global event
  global ncpus
  global venue
  global db
  if db == None:
    db = db_connect()
  c = db.cursor()
  esc = MySQLdb.escape_string

  #
  # Check that the superjob is a job that we own.
  #
  if superjob != 0:
    arr = mysql(c,"""SELECT jobid FROM job WHERE jobid=%d AND sessnum=%d
                  """ % (superjob,sessnum))
    if len(arr) != 1:
      log("ERROR: Trying to claim superjob of %d." % superjob)
      superjob=0

  #
  # Insert the job into the job table.  event and venue are presumably
  # derived from the client's request (TODO confirm), so every string value
  # is escaped before it is spliced into SQL.
  #
  err = mysql_act(c,
                  """INSERT INTO job(sessnum,superjob,username,event,ncpus,venue,start,heartbeat)
                     VALUES(%d,%d,'%s','%s',%d,'%s',now(),now())
                  """ % (sessnum, superjob, esc(str(username)), esc(str(event)),
                         ncpus, esc(str(venue))))
  if err != "":
    log("ERROR: Unable to insert job.")
    log("Error was: %s" % err)
    return 0

  arr = mysql(c, """SELECT LAST_INSERT_ID()""")
  if not arr:
    log("ERROR: Unable to read back the new jobid.")
    return 0
  row = arr[0]
  jobid = int(row[0])

  err = mysql_act(c, """UPDATE session SET accesstime=now()
                        WHERE sessnum=%d""" % sessnum)
  if err != "":
    log("ERROR: Unable to update session accesstime for jobid %d." % jobid)
    log("Error was: %s" % err)

  return jobid


#=============================================================================
# Tell the client that the job can start.
#=============================================================================
def launchJob():
  """Create the job record, tell the client its jobid, and start the job.

  With no venue the job runs locally through invokeLocal(); otherwise
  startCommand() is used.  If the job row cannot be created, serverExit()
  is called and nothing is launched.
  """
  global venue
  global jobid
  jobid = insertJob()
  if jobid == 0:
    serverExit()
    # Other callers set ready_to_exit before serverExit(), which suggests
    # it may return; never fall through and launch with jobid 0.
    return
  updateJob()
  queueCommand("jobid %d\n" % jobid)
  if venue == "":
    invokeLocal()
  else:
    startCommand()

#=============================================================================
# Combination function to check the load and launch job if OK.
#=============================================================================
def checkLoadAndLaunch():
  """Start the command at once for help requests, else when load permits.

  With help_flag set, the command is started with jobid 0 and no load
  check.  Otherwise the job is launched only if checkLoad() approves.
  When it does not, checkLoad() has already messaged the client and/or
  re-armed the timeout itself.
  """
  global help_flag
  global jobid

  if not help_flag:
    if checkLoad():
      launchJob()
    return

  jobid = 0
  startCommand()

#=============================================================================
# Update the job to represent the user's command.
#=============================================================================
def updateJob():
  """Store the current event and ncpus on this job's row; errors are logged."""
  global event
  global ncpus
  global jobid
  global db
  if db == None:
    db = db_connect()
  c = db.cursor()

  # event is a free-form string; escape it before splicing it into SQL.
  err = mysql_act(c, """UPDATE job
                        SET event='%s',ncpus=%d
                        WHERE jobid=%d""" % (MySQLdb.escape_string(str(event)),ncpus,jobid))
  if err != "":
    log("ERROR: Unable to update job fields for jobid %d." % jobid)
    log("Error was: %s" % err)

#=============================================================================
# Create a session when one doesn't already exist.
#=============================================================================
def createSession(c, login, pw, ip, appname):
  """Create a new session (and its viewperm row) for `login`.

  c       -- database cursor to use.
  login   -- user name; it is checked together with `pw` through
             ldap_authenticate().
  ip      -- client address, stored as remoteip.
  appname -- stored as both appname and sessname.

  Returns one of:
  - the rows read back for the new session, each (sessnum, sesstoken);
  - [] on authentication or database failure;
  - the two-element marker ['limit','exceeded'] when the user already has
    session_limit sessions.
  """
  (status,session_limit) = ldap_authenticate(login, pw)
  if status != 1:
    log("LDAP authentication failed for user '%s'" % login)
    return []

  # Escape the user name before splicing it into SQL.
  esc_login = MySQLdb.escape_string(str(login))

  arr = mysql(c,"""
    SELECT COUNT(*) FROM session WHERE username="%s"
              """ % esc_login)
  if len(arr) != 1:
    log("createSession: Non-singleton result from MySQL")
    return []
  row=arr[0]
  session_count = int(row[0])
  if session_count >= session_limit:
    log("User %s cannot exceed session limit of %d" % (login,session_limit))
    sendMessage("You cannot exceed your session limit of %d." % session_limit)
    return ['limit','exceeded']

  err = mysql_act(c,"""
    INSERT INTO
    session(username,remoteip,exechost,dispnum,start,appname,sessname,sesstoken)
    VALUES('%s',     '%s',    '%s',   0,   now(),'%s',   '%s',md5(rand()))
    """ % (esc_login,    ip,      "submit",  appname,appname))

  if err != "":
    log("ERROR: Unable to create session record for '%s'" % login)
    log("Error was: %s" % err)
    sendMessage("Internal database problem.")
    return []

  arr = mysql(c,"""SELECT sessnum,sesstoken
                   FROM session
                   WHERE sessnum=LAST_INSERT_ID()""")
  if not arr:
    # Nothing to read back: report it instead of failing on arr[0].
    log("ERROR: Unable to read back new session for '%s'" % login)
    sendMessage("Internal database problem.")
    return []

  row = arr[0]
  sessnum = int(row[0])

  err = mysql_act(c,"""INSERT INTO
                       viewperm(sessnum, viewuser, viewtoken)
                       VALUES(  %d,      '%s',     md5(rand()))
                    """ %   (sessnum,   esc_login))

  if err != "":
    log("ERROR: Unable to create viewperm record for '%s'" % login)
    log("Error was: %s" % err)

  return arr

#=============================================================================
# Close the job record.
#=============================================================================
def finalizeJob():
  """Close out this job's database record.

  Does nothing for help requests, or when no job row exists
  (database_entry_exists is false).  Otherwise it:
  - touches the owning session's accesstime;
  - marks this job inactive;
  - deletes inactive jobs that started more than load_horizon seconds ago;
  - marks inactive any active job whose heartbeat is older than three
    heartbeat intervals.
  It then closes the global database connection and clears
  database_entry_exists, so a second call is a no-op.
  """
  global help_flag
  global database_entry_exists
  global db
  if help_flag:
    return
  # Check before connecting, so a no-op call does not open (and leave
  # open) a database connection.
  if not database_entry_exists:
    return
  if db == None:
    db = db_connect()
  c = db.cursor()

  arr = mysql(c,"""SELECT sessnum FROM job WHERE jobid=%d""" % jobid)
  if len(arr) == 0:
    log("ERROR: Unable to find session for job.")
  else:
    row = arr[0]
    sessnum = row[0]

    err = mysql_act(c, """UPDATE session SET accesstime=now()
                          WHERE sessnum=%d""" % sessnum)
    if err != "":
      log("ERROR: Unable to update session accesstime for job.")
      log("Error was: %s" % err)

  err = mysql_act(c,"""UPDATE job SET active=0 WHERE jobid=%d""" % jobid)
  if err != "":
    log("Unable to deactivate job.")
    log("Error was: " + err)

  #
  # Clear inactive jobs beyond the load horizon.
  #
  err = mysql_act(c,"""DELETE FROM job WHERE active=0 AND
                       timestampdiff(second,job.start,now()) > %d
                    """ % load_horizon)
  if err != "":
    log("Unable to clear old inactive jobs.")
    log("Error was: " + err)

  #
  # Mark inactive any jobs that have not had a recent heartbeat.
  #
  err = mysql_act(c,"""UPDATE job SET active=0
                       WHERE active=1
                       AND timestampdiff(second,job.heartbeat,now()) > %d
                    """ % (heartbeat_interval*3))
  if err != "":
    log("Unable to deactivate moribund job entries.")
    log("Error was: " + err)

  db.close()
  db = None

  database_entry_exists = 0

#=============================================================================
# Send and receive a string so we know we have a valid connection.
#=============================================================================
def handshake(f):
  """Check that the peer speaks our protocol by echoing a greeting.

  Sends "Hello.\n" on `f` (a socket or SSL.Connection) and expects the same
  bytes back.  Returns on success; otherwise logs the failure and calls
  serverExit().
  """
  string = "Hello.\n"
  reply = ""

  try:
    # send() may write only part of the greeting; sendall() writes it all.
    f.sendall(string)

    # recv() may deliver the echo in pieces.  Read until we have exactly as
    # many bytes as we sent, or the peer closes the connection.
    while len(reply) < len(string):
      chunk = f.recv(len(string) - len(reply))
      if not chunk:
        break
      reply = reply + chunk
    if reply == string:
      return

  except Exception as err:
    log("handshake(): %s" % reply)
    log("err = %s" % str(err))

  log("ERROR: Connection handshake failed.  Protocol mismatch?")
  serverExit()


#=============================================================================
# Accept the new connection and fork an independent server.
#=============================================================================
def acceptAndFork(f):
  """Accept a client on listener `f` and hand it to a detached server.

  With should_fork set, this is a double fork:
  - The listening server closes its copy of the client socket, waits for
    the intermediate child and returns to listening.
  - The intermediate child exits immediately.
  - The grandchild starts a new session, arms the sign-on timeout,
    performs the handshake, registers the client socket (non-blocking) as
    "!socket" for reading, and closes the listening sockets.
  Without should_fork, the current process does the grandchild's work.
  """
  global db
  global sock
  global remote_ip

  cli,addr = f.accept()
  log("==============================")
  log("Connection to %s from %s" % (filetoname[f],addr))
  remote_ip, remote_port = addr

  # Listening server: drop our copy of the client and reap the
  # intermediate child.
  if should_fork and os.fork() != 0:
    cli.close()
    os.wait()
    return

  # Intermediate child: exit so the real server is not our child.
  if should_fork and os.fork() != 0:
    sys.exit(0)

  # The real per-connection server.
  if should_fork:
    os.setsid()
  scheduleTimeout()
  sock = cli
  handshake(sock)
  sock.setblocking(0)
  mapfile(sock,"!socket")
  readers[cli] = "!socket"
  closeListeners()

#=============================================================================
#=============================================================================
def acceptController():
  """Accept one connection on the control socket and serve it synchronously."""
  global readers

  log("acceptController")
  conn = control.accept()[0]
  handleController(conn)

#=============================================================================
#=============================================================================
def closeFile(f):
  """Drop file `f` from every bookkeeping table, then close it.

  Removes `f` from writers, readers and filetoname, and its name from
  nametofile, input_files and output_files.  If that empties input_files,
  invokeChild() is called.  Errors raised by close() are ignored.
  """
  writers.pop(f, None)
  readers.pop(f, None)

  if f in filetoname:
    name = filetoname.pop(f)
    nametofile.pop(name, None)
    if name in input_files:
      del input_files[name]
      # The last pending input file has arrived.
      if not input_files:
        invokeChild()
    output_files.pop(name, None)

  try:
    f.close()
  except:
    pass

#=============================================================================
#=============================================================================
def readItem(f):
  """Read whatever is available from `f` and route it.

  For the client socket ("!socket"): read 1024-byte chunks, passing each to
  doCommand(), until recv() would block.  EOF closes the file with
  closeFile(); SSL "want" conditions stop the loop; an SSL zero-return or
  other SSL error closes the connection with closeSocket().

  For any other mapped file: read up to 1024 bytes with os.read() and
  forward them to the client as "write <name> <len>\n" followed by the
  data.  EOF closes the file; a read error closes it and queues
  "close <name>".

  NOTE(review): on plain EOF of "!socket" this calls closeFile(), not
  closeSocket() as the SSL close paths do, so the global `sock` may stay
  set.  Confirm that heartbeat()'s disconnect tracking still notices.
  """
  ret = ""

  name = filetoname[f]

  if name == "!socket":
    # Drain the non-blocking client socket until it has nothing more.
    while 1:
      try:
        ret = f.recv(1024)
        if ret == "":
          # Peer closed the connection.
          closeFile(f)
          return
      except socket.error:
        # Happens on non-blocking TCP socket when there's nothing to read
        break
      except (SSL.WantReadError, SSL.WantWriteError, SSL.WantX509LookupError):
        break
      # ZeroReturnError must precede SSL.Error, which would also catch it.
      except SSL.ZeroReturnError:
        log("SSL.ZeroReturnError closeSocket()")
        closeSocket()
        break
      except SSL.Error, errors:
        log("SSL.Error closeSocket()")
        closeSocket(errors)
        break
      else:
        doCommand(ret)

  else:
    try:
      ret = os.read(f.fileno(), 1024)
      if ret == "":
        # EOF: closed here without queueing "close" to the client.
        closeFile(f)
        return
    except:
      log("Unable to read from '%s'" % name)
      closeFile(f)
      queueCommand("close %s\n" % name)
      return
    # NOTE(review): unreachable -- EOF already returned above without
    # queueing "close".  Confirm which of the two behaviors is intended.
    if ret == "":
      closeFile(f)
      queueCommand("close %s\n" % name)
      return
    #print "Read from %s: %s" % (name,ret)
    #print "=======>>>"
    #print "write %s %d" % (name, len(ret))
    #print ret
    #print "======="
    # Forward the data to the client as a length-prefixed "write" command.
    queueCommand("write %s %d\n" % (name, len(ret)))
    queueCommand(ret)
    return


#=============================================================================
#=============================================================================
def writeItem(f):
  """Flush as much of writers[f] as `f` will accept right now.

  For the client socket ("!socket"), SSL "want" conditions are ignored
  (retried on a later pass), and SSL zero-return or other SSL errors close
  the connection through closeSocket().

  For any other file, an empty pending buffer means the file was queued for
  closing, so it is closed.  A KeyError during the write is logged, and any
  other write error is logged and closes the file.

  A buffer that has been fully written is removed from writers.
  """
  name = filetoname[f]

  if name != "!socket":
    pending = writers[f]
    if pending == "":
      closeFile(f)
      return
    try:
      count = os.write(f.fileno(), pending)
    except KeyError:
      log("Can't write: KeyError")
    except:
      log("ERROR: writing to %s" % f)
      log("Error is: %s" % sys.exc_info()[0])
      closeFile(f)
    else:
      if f in writers:
        remaining = writers[f][count:]
        if remaining == "":
          del writers[f]
        else:
          writers[f] = remaining
    return

  try:
    count = f.send(writers[f])
  except (SSL.WantReadError, SSL.WantWriteError, SSL.WantX509LookupError):
    pass
  except SSL.ZeroReturnError:
    log("SSL.ZeroReturnError closeSocket()")
    closeSocket()
  except SSL.Error as errors:
    log("SSL.Error closeSocket()")
    closeSocket(errors)
  else:
    remaining = writers[f][count:]
    if remaining == "":
      del writers[f]
    else:
      writers[f] = remaining

#=============================================================================
# Start the server
#=============================================================================
def parseURL(item):
  """Split a listen spec of the form "proto://host:port[/anything]".

  Returns (proto, host, port) with port as an int; host may be empty,
  meaning all interfaces.  For a malformed spec, logs it and returns
  ("", "", 0).
  """
  try:
    proto, remainder = item.split(":", 1)
    remainder = remainder.split("//", 1)[1]
    host, portspec = remainder.split(":", 1)
    # Anything after a "/" following the port is ignored.
    port = int(portspec.split("/", 1)[0])
  except:
    log("Improper network specification: %s" % item)
    return ("", "", 0)

  return (proto, host, port)
    

def startListeners():
  """Open a listening socket for every entry in listen_ports.

  "tcp://host:port" entries get a plain TCP socket.  "tls://host:port"
  entries get a pyOpenSSL connection using submit_server.key,
  submit_server.crt and submit_server_ca.crt from configdir.  Each listener
  is made non-blocking, registered with mapfile(), and added to readers and
  listeners.  An entry that cannot be set up is logged and skipped, so one
  bad port does not stop the others.
  """
  for item in listen_ports:
    (proto,host,port) = parseURL(item)

    if proto.lower() == "tcp":
      log("Listening: protocol='%s', host='%s', port=%d" % (proto,host,port))
      s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
      s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
      try:
        s.bind( (host, port) )
      except Exception:
        log("Can't bind to port %d: %s" % (port, sys.exc_info()[0]))
        s.close()
        continue
      s.listen(10)
      s.setblocking(0)
      mapfile(s,item)
      readers[s] = item
      listeners[s] = item

    elif proto.lower() == "tls":
      log("Listening: protocol='%s', host='%s', port=%d" % (proto,host,port))

      # Like the TCP branch: a missing key/certificate or a busy port is
      # logged and skipped instead of aborting the whole server.
      # NOTE(review): TLSv1_METHOD pins TLS 1.0; consider a negotiating
      # method once the clients are known to support it.
      s = None
      try:
        ctx = SSL.Context(SSL.TLSv1_METHOD)
        ctx.use_privatekey_file(os.path.join(configdir,"submit_server.key"))
        ctx.use_certificate_file(os.path.join(configdir,"submit_server.crt"))
        ctx.load_verify_locations(os.path.join(configdir,"submit_server_ca.crt"))
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        server = SSL.Connection(ctx, s)
        server.bind( (host, port) )
      except Exception:
        log("Can't set up TLS listener on port %d: %s" % (port, sys.exc_info()[0]))
        if s is not None:
          s.close()
        continue
      server.listen(10)
      server.setblocking(0)
      mapfile(server,item)
      readers[server] = item
      listeners[server] = item

    else:
      log("Unknown protocol: %s" % item)

def closeListeners():
  """Close every listening socket (a per-connection server needs none)."""
  for listener in list(listeners.keys()):
    closeFile(listener)

#=============================================================================
#=============================================================================
# Main loop.  Wait for events.
#=============================================================================
#=============================================================================
def mainloop():
  """Event loop: multiplex listeners, the control socket, the client socket
  and the child's pipes with select(), and handle the child's exit.

  Each pass does the following:
  - turns writeclose entries into empty writers entries once their pending
    output is gone (writeItem() closes a file whose entry is empty);
  - select()s with a 0.1 s timeout once child_has_exited is set, otherwise
    15 minutes;
  - accepts connections, reads and writes ready files;
  - calls serverExit() when ready_to_exit is set;
  - polls child_pid.
  When the child has exited, it feeds the "!status" lines to
  parseExitStatus(), drains stdout/stderr, reports "childexit", finalizes
  the job, and either asks the client for output files or queues "exit".
  An unexpected exception from the select() phase breaks out of the loop.

  NOTE(review) comments below flag lines that look suspicious; confirm
  their intent before changing them.
  """
  global db
  global child_pid
  # NOTE(review): 'chlld_exited' looks like a typo and is unused here;
  # child_has_exited is only read below, so it needs no global statement.
  global chlld_exited
  global ready_to_exit
  global stats_recorded

  while 1:
    try:
      #print "select", readers.keys(), writers.keys()
      for x in writeclose.keys():
        if not writers.has_key(x):
          #print "Setting up for close:", str(x)
          writers[x] = ""
          del writeclose[x]
      # Poll quickly once the child has exited; otherwise idle for 15 min.
      if child_has_exited:
        select_timeout = .1
      else:
        select_timeout = 15*60
      r,w,_ = select.select(readers.keys(), writers.keys(),[],select_timeout)
      # If the timeout occurs (nothing to read/write) send a keepalive
      # NOTE(review): writers is keyed by file objects (see writeItem), so
      # has_key("!socket") looks always false and no keepalive is ever
      # queued; nametofile.has_key("!socket") may have been intended.
      if r == [] and w == []:
        if writers.has_key("!socket"):
          queueCommand("null\n")
        continue
    except select.error, err:
      # benign (presumably EINTR when a signal such as SIGALRM interrupts
      # select())
      r = []
      w = []
      pass
    except IOError, err:
      log("IOError in mainloop: %s" % str(err))
      continue
    except Exception, err:
      log("EXCEPTION IN MAINLOOP: %s" % str(err))
      break

    try:
      for i in r:
        if listeners.has_key(i):
          acceptAndFork(i)

        elif i == control:
          acceptController()

        else:
          readItem(i)

      for o in w:
        writeItem(o)

    except IOError, err:
      log("IOError in mainloop I/O handling: %s" % str(err))
      continue

    if ready_to_exit:
      serverExit()

    if child_pid:
      pid = 0
      #print "Checking child"
      try:
        (pid,status) = os.waitpid(child_pid, os.WNOHANG)
      except:
        # waitpid() failed (e.g. the child belongs to another process after
        # a resumeSession() takeover): probe the pid with signal 0 instead.
        try:
          os.kill(child_pid,0)
        except:
          log("Child has exited.")
          pid = child_pid
          status = 1 << 8
      if pid != 0:
        child_pid = 0
        # NOTE(review): status is not used below; the reports use the
        # exit_status global (presumably set by parseExitStatus()).
        status = status >> 8
        # NOTE(review): this makes the next pass call serverExit() even on
        # the get_output_files path below; confirm serverExit() waits for
        # pending transfers.
        ready_to_exit = 1
        debug("Child exited.  Collecting stats.")

        if nametofile.has_key("!status"):
          f = nametofile["!status"]
          stats = ''
          while 1:
            chunk = os.read(f.fileno(), 10000)
            if chunk == '':
              break
            stats += chunk
          status_lines = stats.split('\n')
          parsed=0
          for line in status_lines:
            if line != "":
              stats_recorded = 0
              parseExitStatus(line)
              parsed=1
          if not parsed:
            parseExitStatus("status=%d" % exit_status)
          closeFile(f)

        debug("Closing outputs.")
        # Drain stdout/stderr until readItem() hits EOF and closes them.
        for name in [ "!stdout", "!stderr" ]:
          while nametofile.has_key(name):
            readItem(nametofile[name])
          #print "Done clearing %s" % name

        queueCommand("childexit %d\n" % exit_status)

        if help_flag:
          debug("Queuing help exit.");
          queueCommand("exit %d\n" % exit_status)
          continue

        if not shared_filesystem:
          findOutputFiles()
        finalizeJob()
        # NOTE(review): "childexit" was already queued above; this queues
        # it a second time.
        queueCommand("childexit %d\n" % exit_status)
        if output_files and not shared_filesystem:
          queueCommand("get_output_files\n")
          log("There are output files to transfer.")
        else:
          # Indicate readiness to exit, but don't do serverExit().
          # The client should reply to tell us it got the exit status.
          ready_to_exit = 1
          queueCommand("exit %d\n" % exit_status)

#=============================================================================
# Handle the arguments that were given to the server.
#=============================================================================
def parseServerArgs():
  """Hand-rolled parser for the server's command-line options.

    -k <jobid>      abort the server running <jobid>, then exit
    -c <configdir>  look for the config only in <configdir>
    -d              enable debug output
    -f              stay in the foreground
  Unknown arguments are reported on stderr and otherwise ignored.

  NOTE(review): the main program parses sys.argv with getopt and does not
  appear to call this; confirm whether it is still needed.
  """
  global debug_flag
  global foreground_flag
  global jobid
  global configdirs
  argv = sys.argv
  pos = 1
  while pos < len(argv):
    opt = argv[pos]
    if opt == "-k":
      pos += 1
      jobid = int(argv[pos])
      killServer(jobid)
      sys.exit(0)
    elif opt == "-c":
      pos += 1
      configdirs = [ argv[pos] ]
    elif opt == "-d":
      debug_flag = 1
      log("Debug output enabled.")
    elif opt == "-f":
      foreground_flag = 1
      log("Remaining in the foreground.")
    else:
      sys.stderr.write("Unknown argument %s" % opt)
    pos += 1

#=============================================================================
#=============================================================================
# Main program begins here...
#=============================================================================
#=============================================================================
#sys.exechook = info

def Usage():
  """Print the command-line synopsis (with the program name) and exit 1."""
  print("%s [-c <configdir>][-d][-f][-k <jobid>]" % sys.argv[0])
  sys.exit(1)

try:
  optlist, cmdlist = getopt.getopt(sys.argv[1:], ':c:dfk:')
except getopt.GetoptError:
  Usage()

for opt in optlist:
  if opt[0] == '-c':
    configdirs=[ opt[1] ]
  elif opt[0] == '-d':
    debug_flag = 1
  elif opt[0] == '-f':
    foreground_flag = 1
  elif opt[0] == '-k':
    # -k only asks an existing server to abort; it must never go on to
    # start a new server.  killServer() itself exits with status 1 when the
    # target's control socket cannot be reached.
    try:
      jobid = int(opt[1])
    except ValueError:
      print("Unable to kill '%s'" % opt[1])
      sys.exit(1)
    killServer(jobid)
    sys.exit(0)
  else:
    Usage()

#=============================================================================
# Load configuration values
#=============================================================================
for configdir in configdirs:
  try:
    if debug_flag:
      log("Searching for config in %s" % os.path.join(configdir,"config"))
    execfile(os.path.join(configdir,"config"))
    if debug_flag:
      log("Found config.")
    break
  except IOError:
    pass

if not foreground_flag:
  openlog(logfile_name)

log("Using configdir %s" % configdir)

if load_limit <= 0:
  log("load_limit must be positive.  Setting to 1.0.")
  load_limit = 1.0

if not foreground_flag:
  daemonize()

startListeners()
setupSignals()
mainloop()
log("Server fell out of mainloop().")
serverExit()

