#!/usr/bin/python
#
# @package      hubzero-expire-sessions
# @file         expire-sessions
# @author       Rick Kennell <kennell@purdue.edu>
# @author       Pascal Meunier <pmeunier@purdue.edu>
# @author       Nicholas J. Kisseberth <nkissebe@purdue.edu>
# @author       David Benham <dbenham@purdue.edu>
# @copyright    Copyright (c) 2005-2015 HUBzero Foundation, LLC.
# @license      http://opensource.org/licenses/MIT MIT
#
# Copyright (c) 2005-2015 HUBzero Foundation, LLC.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
# HUBzero is a registered trademark of HUBzero Foundation, LLC.
#
import os, pwd, sys, time, MySQLdb, signal

# Work from "/" (daemonize() also chdirs to "/" after setsid()).
os.chdir("/")

# Run mode.  With foreground = 0 the program daemonizes and log() writes to
# the log file; otherwise log() prints to stdout.  DEBUG = 1 makes the main
# loop log how long each candidate session has left before it expires.
foreground = 0
DEBUG = 0

# Upper bound, in seconds, on the sleep between expiry passes; also the
# look-ahead window used when selecting sessions that may expire soon.
expire_interval = 300
# Log file used in background mode; `logfile` holds the open file object once
# openlog() has run (None until then).
logfile_name = "/var/log/expire-sessions/expire-sessions.log"
logfile = None
# Pid file written by daemonize(), checked by checkpidfile() and removed by
# rmpidfile().
pidfile = "/var/run/expire-sessions.pid"

# Path of the maxwell command that terminate() runs to stop a session.
# The first existing candidate wins; empty string if none is installed.
if os.path.exists("/usr/lib/mw/bin/maxwell"):
  mwprog = "/usr/lib/mw/bin/maxwell"
elif os.path.exists("/usr/lib/hubzero/bin/maxwell"):
  mwprog = "/usr/lib/hubzero/bin/maxwell"
elif os.path.exists("/usr/bin/maxwell"):
  mwprog = "/usr/bin/maxwell"
else:
  mwprog = ""

# Middleware configuration file, executed further below.
# The first existing candidate wins; empty string if none is found.
if os.path.exists("/etc/mw-client/mw-client.conf"):
  config_file = "/etc/mw-client/mw-client.conf"
elif os.path.exists("/etc/mw-www/maxwell.conf"):
  config_file = "/etc/mw-www/maxwell.conf"
elif os.path.exists("/etc/hubzero/maxwell.conf"):
  config_file = "/etc/hubzero/maxwell.conf"
elif os.path.exists("/etc/mw/maxwell.conf"):
  config_file = "/etc/mw/maxwell.conf"
else:
  config_file = ""

#==============================================================================
# Default values.  The config file (executed below in this module's
# namespace) may reassign any of these, or any variable defined above.
#==============================================================================

# MySQL connection parameters passed to MySQLdb.connect() by db_connect(),
# and how many connection attempts (one second apart) it makes.
mysql_host = ""
mysql_user =""
mysql_password = ""
mysql_db = ""
mysql_connect_attempts = 120

# When True, each pass also terminates sessions owned by 'anonymous' that
# have no row in the view table.
ANONYMOUS_SESSIONS = False

#=============================================================================
# Load the configuration and override the variables above.
# execfile() runs the file as Python code in this module's namespace, so the
# config file must be trusted.  A missing file (including config_file == "")
# raises IOError, which is ignored, leaving the defaults in place.
#=============================================================================
try:
  execfile(config_file)
except IOError:
  pass

#=============================================================================
# Set up errors to go to the log file.
#=============================================================================
def openlog(filename):
  """Redirect the process's standard streams to the log file.

  Does nothing in foreground mode.  Otherwise opens `filename` for appending
  (stored in the global `logfile`, which log() writes to), points fd 0 at
  /dev/null and fds 1 and 2 at the log file so that stray prints and
  tracebacks also land in the log.  Exits with status 3 if any step fails.
  """
  global logfile

  if foreground:
    return

  try:
    logfile = open(filename, "a+")
    devnull = os.open(os.devnull, os.O_RDWR)
    try:
      # dup2() closes the target descriptor atomically, so 0/1/2 need not be
      # closed first (closing them early would leave stdout dead if a later
      # step failed and the error message below could not be seen).
      os.dup2(devnull, 0)
      os.dup2(logfile.fileno(), 1)
      os.dup2(logfile.fileno(), 2)
    finally:
      # Drop the extra /dev/null descriptor unless it landed on 0-2 itself
      # (possible if a standard descriptor was already closed).
      if devnull not in (0, 1, 2):
        os.close(devnull)
  except (IOError, OSError):
    print(" [Unable to write to log file (%s). Exiting...] " % filename),
    sys.exit(3)

#=============================================================================
# Log a message.
#=============================================================================
def log(msg):
  """Emit one diagnostic line.

  In foreground mode the message is printed unchanged.  Otherwise it is
  prefixed with "[<time.asctime()>] ", written to the log file opened by
  openlog() and flushed immediately.  If openlog() has not run yet
  (`logfile` is still None), which happens when daemonize() or a signal
  handler logs early, the timestamped line goes to stderr instead of
  raising AttributeError.
  """
  if foreground:
    print(msg)
    return

  line = "[" + time.asctime() + "] " + msg + "\n"
  out = logfile if logfile is not None else sys.stderr
  out.write(line)
  out.flush()

#=============================================================================
# Create database connection.
#=============================================================================
def db_connect():
  """Open a MySQL connection using the mysql_* settings.

  Makes up to `mysql_connect_attempts` attempts, one second apart, logging
  each failure.  Returns the connection object, or None if every attempt
  failed (after logging that the connection could not be made).
  """
  for attempt in range(mysql_connect_attempts):
    try:
      return MySQLdb.connect(host=mysql_host, user=mysql_user,
                             passwd=mysql_password, db=mysql_db)
    except Exception as err:
      # Catch Exception rather than everything: SystemExit raised by
      # sighandler (and KeyboardInterrupt) must still stop the daemon
      # while it is waiting for the database.
      log("Exception in db_connect: %s" % err)
    # No point sleeping after the final failed attempt.
    if attempt + 1 < mysql_connect_attempts:
      time.sleep(1)
  log("Unable to connect to database after %d attempts" % mysql_connect_attempts)
  return None

#=============================================================================
# MySQL helpers
#=============================================================================
def mysql(c, cmd):
  """Execute `cmd` on cursor `c` and return all result rows.

  If the cursor reports warnings, a notice and the SQL are logged.  On a
  MySQL error (or any other Exception) the error and the SQL are logged and
  an empty tuple is returned, so callers iterating the result see no rows.
  SystemExit/KeyboardInterrupt are not caught, so signals still terminate
  the daemon mid-query.
  """
  try:
    c.execute(cmd)
    # _warnings is a private MySQLdb cursor attribute; tolerate its absence.
    if getattr(c, "_warnings", 0):
      log("MySQL warning")
      log("SQL was: %s" % cmd)
    return c.fetchall()
  except MySQLdb.MySQLError as err:
    # Server errors carry (errno, message), but client-side errors may carry
    # a single argument; unpacking a fixed 2-tuple would raise here.
    if len(err.args) >= 2:
      log("%s" % (err.args[1],))
    else:
      log("%s" % err)
    log("SQL was: %s" % cmd)
    return ()
  except Exception:
    log("Some other MySQL exception.")
    log("SQL was: %s" % cmd)
    return ()

def mysql_act(c, cmd):
  """Execute a statement whose result rows are not needed.

  Returns "" on success, or the MySQL error message if execution raises a
  MySQLdb error.  Other exceptions propagate to the caller.
  """
  try:
    c.execute(cmd)
  except MySQLdb.MySQLError as err:
    # Server errors carry (errno, message); client-side errors may not.
    if len(err.args) >= 2:
      return err.args[1]
    return str(err)
  return ""


#=============================================================================
# Check pidfile
#=============================================================================
def checkpidfile(path=None):
  """Decide whether a new daemon instance may start.

  Returns True when the pidfile cannot be opened, or when it names a pid
  that has no /proc entry.  Returns False when the pidfile cannot be read
  or parsed, holds a pid <= 1, or names a live process (in which case a
  notice is printed).

  `path` defaults to the module-level `pidfile`.
  """
  if path is None:
    path = pidfile

  try:
    f = open(path, "r")
  except (IOError, OSError):
    # No pidfile: nothing is known to be running.
    return True

  with f:
    try:
      pid = int(f.read())
    except (IOError, OSError, ValueError):
      return False

  if pid <= 1:
    return False

  if os.path.exists("/proc/%s" % pid):
    print(" [expire-sessions (pid %s) already running] " % pid),
    return False
  return True

#=============================================================================
# Remove pidfile
#=============================================================================
def rmpidfile(path=None):
  """Delete the pidfile.

  `path` defaults to the module-level `pidfile`.  Returns True if the file
  was removed, False if removal failed (e.g. it does not exist).
  """
  if path is None:
    path = pidfile
  try:
    os.remove(path)
  except OSError:
    return False
  return True

#=============================================================================
# Daemonize.
#=============================================================================
def daemonize():
  """Detach from the launching process with a double fork.

  The original process and the intermediate child both leave immediately
  via os._exit(0).  The surviving grandchild runs in a new session with
  cwd "/" and records its pid in `pidfile`.  Failure to write the pidfile
  is reported on stderr but is not fatal.
  """
  if os.fork():
    os._exit(0)
  os.setsid()
  os.chdir("/")
  if os.fork():
    os._exit(0)
  try:
    with open(pidfile, "w") as f:
      f.write("%d\n" % os.getpid())
  except (IOError, OSError):
    # openlog() has not run yet, so log() has no log file to write to;
    # report on stderr, which is still the inherited descriptor here.
    sys.stderr.write("Unable to write pid (%d) to %s\n" % (os.getpid(), pidfile))

#=============================================================================
# Change to different user.
#=============================================================================
def change_www():
  """Drop root privileges to the distribution's web-server account.

  Does nothing unless running as root.  The account is 'www-data' when
  /etc/debian_version exists and 'apache' when /etc/redhat-release exists.
  Otherwise, or if that account does not exist, the program exits with
  status 1.  Supplementary groups are reset to the web user's own groups so
  root's group memberships are not inherited.  If the group/uid switch
  fails, a warning goes to stderr and execution continues.
  """
  if os.getuid() != 0:
    return

  if os.path.exists("/etc/debian_version"):
    wwwuser = 'www-data'
  elif os.path.exists("/etc/redhat-release"):
    wwwuser = 'apache'
  else:
    print("unknown web user, cannot determine distribution")
    sys.exit(1)

  try:
    pw = pwd.getpwnam(wwwuser)
  except KeyError:
    print("web user %s does not exist" % wwwuser)
    sys.exit(1)

  try:
    # Replace root's supplementary groups before giving up root.
    os.initgroups(wwwuser, pw.pw_gid)
    os.setregid(pw.pw_gid, pw.pw_gid)
    os.setreuid(pw.pw_uid, pw.pw_uid)
  except OSError as err:
    sys.stderr.write("Unable to change to user %s: %s\n" % (wwwuser, err))

#=============================================================================
# Handle signals.
#=============================================================================
def sighandler(sig, frame):
  """Handle a termination signal: log it, remove the pidfile, exit with 2.

  `frame` is part of the signal.signal() handler signature and is unused.
  """
  message = "Caught signal %d.  Exiting." % (sig,)
  log(message)
  rmpidfile()
  sys.exit(2)

#=============================================================================
# Do the work of terminating a session.
#=============================================================================
def terminate(sessnum, username, appname):
  """Stop session `sessnum` by running `<mwprog> stop reason=timeout`.

  Logs the termination; if no maxwell program was found, or the command
  fails, an error is logged instead.  `sessnum` must be numeric: it is
  formatted with %d, which also keeps the shell command free of injected
  text.
  """
  log("Terminating %d (%s,%s)" % (sessnum, username, appname))
  if not mwprog:
    # Without this guard the shell would run "stop ..." as a command.
    log("Error terminating session: maxwell program not found")
    return
  status = os.system("%s stop reason=timeout sessnum=%d" % (mwprog, sessnum))
  if status != 0:
    # os.system() returns an encoded wait status; report the command's
    # exit code when it exited normally rather than the raw value.
    if status > 0 and os.WIFEXITED(status):
      status = os.WEXITSTATUS(status)
    log("Error terminating session: %d" % status)

#=============================================================================
# Main program
#=============================================================================

# Each of these signals is routed to sighandler (log, remove pidfile, exit 2).
signal.signal(signal.SIGHUP, sighandler)
signal.signal(signal.SIGINT, sighandler)
signal.signal(signal.SIGQUIT, sighandler)
signal.signal(signal.SIGTERM, sighandler)

# Refuse to start when checkpidfile() returns a false value (another live
# instance, or an unusable pidfile).
if not checkpidfile():
  print "checkpidfile failed"
  sys.exit(1)

# Background mode: detach, switch to the web-server user when running as
# root, then redirect stdio to the log file.
if not foreground:
  daemonize()
  change_www()
  openlog(logfile_name)

  log("Server is ready.")

# Open the initial database connection and cursor.
# NOTE(review): db_connect() can return None after repeated failures, in
# which case db.cursor() raises AttributeError.
db = db_connect()
c = db.cursor()

while 1:
  # Recheck, at the longest, every "expire_interval" seconds.
  sleeptime = expire_interval
  # delete any orphaned views (heartbeat older than one day), which shouldn't
  # exist unless there was a restart or crash
  mysql(c, """DELETE FROM view WHERE heartbeat < DATE_ADD(NOW(), INTERVAL -1 DAY)""")
  
  # Anonymous sessions: terminate every 'anonymous' session with no view.
  if ANONYMOUS_SESSIONS:
    anon_arr = mysql(c, """
             SELECT session.sessnum, appname
             FROM session LEFT JOIN view ON session.sessnum = view.sessnum
             WHERE session.username = 'anonymous' 
             AND viewid IS NULL""")
    for sessnum, appname in anon_arr:
      # NOTE(review): `session` is never used; terminate() gets the raw sessnum.
      session = int(sessnum)
      terminate(sessnum, 'anonymous', appname)
      time.sleep(1) # Do not kill things off too quickly.
      # Account for the second just spent sleeping.
      sleeptime -= 1
  
  # Regular sessions
  # Find sessions that have been accessed more than "timeout" seconds ago
  # and that have no views, and that could expire during our next sleep (expire_interval)
  # (exechost 'submit' sessions are excluded).
  # skew: keep track of sleep time between each expired session
  skew = 0 
  arr = mysql(c, """
    SELECT session.sessnum, session.username, appname,
      timeout - TIMESTAMPDIFF(SECOND, accesstime, NOW()) AS remaining
    FROM session LEFT JOIN view ON session.sessnum = view.sessnum
    WHERE viewid IS NULL 
    AND exechost != 'submit'
    AND timeout - TIMESTAMPDIFF(SECOND, accesstime, NOW()) < %d
    ORDER BY remaining""" % (expire_interval))

  # Check if those sessions have jobs running
  # there can be many jobs per session.  Expire session only if *all* jobs have had no heartbeat.
  # Complication:  submit --local jobs that run as long as the session is open
  # "submit --local" will have no heartbeat, whereas all other submit jobs will
  # To handle this, expire session only if all job heartbeats are old.
  for sessnum, username, appname, session_remaining in arr:
    # Find the greatest remaining time for this session to stay alive
    max_remaining = int(session_remaining)
    # short circuit other checks if this session will stay alive longer than interval to next check
    # (the query already limits remaining to < expire_interval, so this only
    # triggers once sleeptime has been lowered earlier in this pass)
    if max_remaining > sleeptime:
      continue
    # heartbeat is considered "stale" after 24 hours
    # remaining = heartbeat - NOW() + 86400
    job_arr = mysql(c, """SELECT 86400-TIMESTAMPDIFF(SECOND, heartbeat, NOW()) AS remaining
      FROM job WHERE sessnum = "%s"
      AND 86400-TIMESTAMPDIFF(SECOND, heartbeat, NOW()) > 0""" % sessnum)
    if job_arr:
      # Find the greatest reason this session should stay alive
      for row in job_arr:
        remaining = row[0]
        if int(remaining)  > max_remaining:
          max_remaining = int(remaining)
      if max_remaining > sleeptime:
        continue

    # How long should someone have to check the results of a simulation?  
    # people have 24 hours to check results after the end of the last job.  
    # If a job started at 9AM took 12 hours to run then a user would have until 9PM 
    # the next day to check results.  It's always possible to check the results before 
    # the session ends. 
    # if any joblog entry is for a job that ended less than a day ago, pass
    # remaining = start + walltime - NOW() + 86400
    # walltime < UNIX_TIMESTAMP(start) is because it's possible for walltime to be
    # calculated from epoch; such rows are excluded
    # job > 0 because the middleware makes entries with job=0, we don't want those, not real jobs
    job_arr = mysql(c, """SELECT 86400 + walltime - TIMESTAMPDIFF(SECOND, start, NOW()) AS remaining
      FROM joblog WHERE sessnum = "%s"
      AND walltime < UNIX_TIMESTAMP(start)
      AND job > 0
      AND 86400 + walltime - TIMESTAMPDIFF(SECOND, start, NOW()) > 0""" % sessnum)
    if job_arr:
      for row in job_arr:
        remaining = row[0]
        if int(remaining)  > max_remaining:
          max_remaining = int(remaining)
      # log("%d of %s has a recently ended job %s" % (sessnum, username, job_arr))

    if DEBUG:
      log("%d of %s to expire in %d seconds" % (sessnum, username, max_remaining))
    # Expire now if the time left is used up, counting the `skew` seconds
    # already slept after earlier terminations in this pass; otherwise
    # shorten the next sleep so this session is rechecked when it is due.
    if max_remaining < 1 + skew:
      terminate(sessnum, username, appname)
      time.sleep(1) # Do not kill things off too quickly.
      skew += 1
    elif max_remaining < sleeptime:
      sleeptime = max_remaining      

  # Discount time already spent sleeping between terminations.  The DB
  # connection is closed across the sleep and reopened afterwards; if no
  # sleep is left, the next pass starts at once on the same connection.
  sleeptime -= skew
  if (sleeptime > 0):
    db.close()
    time.sleep(sleeptime)
    db = db_connect()
    c = db.cursor()
