#!/usr/bin/python
#
# @package      hubzero-expire-sessions
# @file         expire-sessions
# @author       Rick Kennell <kennell@purdue.edu>
# @author	    John Rosheck <jrosheck@purdue.edu>
# @copyright    Copyright (c) 2005-2011 Purdue University. All rights reserved.
# @license      http://www.gnu.org/licenses/lgpl-3.0.html LGPLv3
#
# Copyright (c) 2005-2011 Purdue University
# All rights reserved.
#
# This file is part of: The HUBzero(R) Platform for Scientific Collaboration
#
# The HUBzero(R) Platform for Scientific Collaboration (HUBzero) is free
# software: you can redistribute it and/or modify it under the terms of
# the GNU Lesser General Public License as published by the Free Software
# Foundation, either version 3 of the License, or (at your option) any
# later version.
#
# HUBzero is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# HUBzero is a registered trademark of Purdue University.

import os,pwd,sys,time,MySQLdb,signal

# Run from the filesystem root rather than from wherever we were started.
os.chdir("/")

# When non-zero, the daemonizing/log-redirection steps of the main program
# are skipped.
foreground = 0

# Seconds.  Used both as the look-ahead window of the expiry query and as
# the longest pause between two checks.
expire_interval = 15

logfile_name = "/var/log/expire-sessions/expire-sessions.log"
logfile = None    # Replaced by an open file object in openlog().
pidfile = "/var/run/expire-sessions.pid"

# Locate the maxwell program: the first candidate that exists wins, and an
# empty string means that none was found.
mwprog = ""
for _candidate in ("/usr/lib/mw/bin/maxwell",
                   "/usr/lib/hubzero/bin/maxwell",
                   "/usr/bin/maxwell"):
  if os.path.exists(_candidate):
    mwprog = _candidate
    break

# Locate the configuration file using the same first-match rule.
config_file = ""
for _candidate in ("/etc/mw-www/maxwell.conf",
                   "/etc/hubzero/maxwell.conf",
                   "/etc/mw/maxwell.conf"):
  if os.path.exists(_candidate):
    config_file = _candidate
    break

# Keep the loop variable out of the namespace the configuration file sees.
del _candidate

#==============================================================================
# Default values.
#==============================================================================

# Connection settings start out empty; the configuration file loaded below
# is expected to supply the real values.
mysql_host = mysql_user = mysql_password = mysql_db = ""

# Number of connection attempts db_connect() makes (one second apart)
# before it gives up.
mysql_connect_attempts = 120

#=============================================================================
# Load the configuration and override the variables above.
#=============================================================================
try:
  # execfile() runs the configuration file as Python source in this module's
  # namespace, so plain assignments in it replace the defaults set above.
  # NOTE(review): because the file is executed as code, it must only be
  # writable by trusted users.
  execfile(config_file)
except IOError:
  # No usable configuration file (config_file is "" when none of the
  # candidate paths exists): keep the built-in defaults.
  pass

#=============================================================================
# Set up errors to go to the log file.
#=============================================================================
def openlog(filename):
  """Open the log file and redirect the standard streams for daemon use.

  Does nothing when running in the foreground.  Otherwise `filename` is
  opened for appending and stored in the global `logfile`, stdin is
  redirected to /dev/null, and stdout and stderr are redirected to the log
  file.

  Failures are reported on the current stderr instead of being silently
  ignored.  If the log file itself could not be opened, `logfile` falls back
  to sys.stderr so that later log() calls still have somewhere to write.
  """
  global logfile

  if foreground:
    return

  try:
    logfile = open(filename, "a+")
    # "r+" is the read/write mode; the former "rw" is not a valid mode
    # string and its effect depended on the platform.
    devnull = open("/dev/null", "r+")
    try:
      # dup2() closes the target descriptor itself, so descriptors 0-2 are
      # never left closed if one of the steps fails.
      os.dup2(devnull.fileno(), 0)
      os.dup2(logfile.fileno(), 1)
      os.dup2(logfile.fileno(), 2)
    finally:
      # Do not close /dev/null if it happens to *be* a standard descriptor.
      if devnull.fileno() > 2:
        devnull.close()
  except (IOError, OSError):
    # sys.exc_info() keeps this working on every Python version the script
    # may run under.
    err = sys.exc_info()[1]
    if logfile is None:
      logfile = sys.stderr
    sys.stderr.write("Unable to set up log file %s: %s\n" % (filename, err))

#=============================================================================
# Log a message.
#=============================================================================
def log(msg):
  """Write `msg` as a single line to the log.

  In daemon mode the line is prefixed with a "[timestamp] " marker; in
  foreground mode it is written as-is.  When no log file has been opened
  (foreground mode, or a call made before openlog()), the message goes to
  sys.stderr instead of failing on a None log file.
  """
  if not foreground:
    timestamp = "[" + time.asctime() + "] "
  else:
    timestamp = ""

  out = logfile
  if out is None:
    out = sys.stderr
  out.write(timestamp + msg + "\n")
  out.flush()

#=============================================================================
# Create database connection.
#=============================================================================
def db_connect():
  """Connect to the MySQL database, retrying once per second.

  Up to `mysql_connect_attempts` attempts are made using the module-level
  mysql_* settings.

  Returns the open connection, or None when every attempt failed.
  """
  for attempt in range(0,mysql_connect_attempts):
    try:
      return MySQLdb.connect(host=mysql_host, user=mysql_user,
                             passwd=mysql_password, db=mysql_db)
    except (KeyboardInterrupt, SystemExit):
      # sighandler() exits by raising SystemExit; swallowing it here would
      # keep the daemon alive after it was told to stop.
      raise
    except Exception:
      log("Exception in db_connect")
      log("%s" % (sys.exc_info()[1],))
    time.sleep(1)

  log("db_connect giving up after %d attempts" % mysql_connect_attempts)
  return None

#=============================================================================
# MySQL helpers
#=============================================================================
def mysql(c,cmd):
  """Run the SQL statement `cmd` on cursor `c` and return all result rows.

  Errors are logged together with the offending SQL and reported to the
  caller as an empty tuple, so the caller can treat a failed query like one
  with no rows.  KeyboardInterrupt and SystemExit are passed through so that
  a termination signal is not lost.
  """
  try:
    c.execute(cmd)
    # _warnings is a private cursor attribute and may be missing; treat a
    # missing or empty value as "no warnings" instead of failing the query.
    if (getattr(c, "_warnings", 0) or 0) > 0:
      log("MySQL warning")
      log("SQL was: %s" % cmd)
    return c.fetchall()
  except (KeyboardInterrupt, SystemExit):
    raise
  except MySQLdb.MySQLError:
    # sys.exc_info() works on every Python version the script may run under.
    err = sys.exc_info()[1]
    details = getattr(err, "args", ())
    # MySQL errors normally carry (number, explanation); fall back to the
    # whole exception when they do not, rather than failing to unpack.
    if len(details) == 2:
      expl = details[1]
    else:
      expl = err
    log("%s" % (expl,))
    log("SQL was: %s" % cmd)
    return ()
  except Exception:
    log("Some other MySQL exception.")
    log("SQL was: %s" % cmd)
    return ()

def mysql_act(c,cmd):
  """Run the SQL statement `cmd` on cursor `c` for its effect only.

  Returns "" on success, or a non-empty description of the MySQL error on
  failure.  Exceptions other than MySQLdb.MySQLError propagate.
  """
  try:
    c.execute(cmd)
    return ""
  except MySQLdb.MySQLError:
    # sys.exc_info() works on every Python version the script may run under.
    err = sys.exc_info()[1]
    details = getattr(err, "args", ())
    # MySQL errors normally carry (number, explanation); anything else is
    # rendered as text rather than failing to unpack.
    if len(details) == 2:
      return details[1]
    return str(err) or repr(err)

#=============================================================================
# Daemonize.
#=============================================================================
def daemonize():
  """Detach from the caller and record the daemon's pid in `pidfile`.

  Forks, makes the child a session leader, changes to "/", and forks again;
  each parent exits immediately so only the final child carries on.

  A failure to write the pid file is reported but is not fatal.
  """
  #log("Backgrounding process.")
  if os.fork():
    os._exit(0)
  os.setsid()
  os.chdir("/")
  if os.fork():
    os._exit(0)

  try:
    f = open(pidfile,"w")
    try:
      f.write("%d\n" % os.getpid())
    finally:
      f.close()
  except (IOError, OSError):
    msg = "Unable to write pid (%d) to %s" % (os.getpid(),pidfile)
    # The main program daemonizes before it opens the log file, so the log
    # may not exist yet; report on stderr in that case.
    if logfile is not None:
      log(msg)
    else:
      sys.stderr.write(msg + "\n")

#=============================================================================
# Change to different user.
#=============================================================================
def change_user(username):
  """Drop root privileges by switching to the account `username`.

  Does nothing unless the process is running as root.  The real and
  effective group id are changed first, then the real and effective user
  id, because the group can no longer be changed once root has been given
  up.

  A failure (unknown account, or the id change being refused) is reported
  on stderr; the process then carries on with its current identity.
  """
  if os.getuid() != 0:
    return

  try:
    entry = pwd.getpwnam(username)
    os.setregid(entry.pw_gid, entry.pw_gid)
    os.setreuid(entry.pw_uid, entry.pw_uid)
    # NOTE(review): supplementary groups are left untouched here - confirm
    # whether they should be reset as well.
  except (KeyError, OSError):
    # stderr rather than log(): the main program calls this before the log
    # file has been opened.
    sys.stderr.write("Unable to change to user %s: %s\n"
                     % (username, sys.exc_info()[1]))

#=============================================================================
# Handle signals.
#=============================================================================
def sighandler(sig,frame):
  """Signal handler: log which signal arrived, then exit with status 1.

  `sig` is the signal number; `frame` is not used.
  """
  notice = "Caught signal %d.  Exiting." % sig
  log(notice)
  raise SystemExit(1)

#=============================================================================
# Do the work of terminating a session.
#=============================================================================
def terminate(sessnum,username,appname):
  """Ask maxwell to stop session `sessnum` because it timed out.

  `username` and `appname` are only used for the log message.  A non-zero
  status from the command is logged; nothing is returned.
  """
  log("Terminating %d (%s,%s)" % (sessnum,username,appname))
  if not mwprog:
    # With an empty program name the shell would run whatever command
    # happens to be called "stop" instead of maxwell.
    log("Error terminating session: maxwell program not found")
    return
  status=os.system("%s stop reason=timeout sessnum=%d" % (mwprog,sessnum))
  if status != 0:
    log("Error terminating session: %d" % status)

#=============================================================================
# Main program
#=============================================================================

# Route the usual termination signals through sighandler, which logs the
# signal and exits.
for _signum in (signal.SIGHUP, signal.SIGINT, signal.SIGQUIT, signal.SIGTERM):
  signal.signal(_signum, sighandler)
del _signum

# Daemon start-up.  The order matters: detach first, then give up root, and
# only then open the log file.
if not foreground:
  daemonize()
  change_user("www-data")
  openlog(logfile_name)
  log("Server is ready.")

# Main loop: find the sessions that have no view and no job and whose
# timeout runs out within the next "expire_interval" seconds, terminate the
# ones that are already due, and sleep until the next one is due.
while 1:
  db = db_connect()
  if db is None:
    # db_connect() returns None once it has used up all of its attempts.
    # Keep the daemon alive and try again later instead of crashing on
    # db.cursor().
    log("Unable to connect to the database; will retry.")
    time.sleep(expire_interval)
    continue

  try:
    c = db.cursor()
    arr = mysql(c, """
           SELECT session.sessnum,session.username,appname,
           timeout-TIME_TO_SEC(TIMEDIFF(NOW(), accesstime)) AS remaining
           FROM session
           LEFT JOIN view ON session.sessnum = view.sessnum
           LEFT JOIN job ON session.sessnum = job.sessnum
           WHERE viewid IS NULL AND jobid IS NULL AND
           timeout-TIME_TO_SEC(TIMEDIFF(NOW(), accesstime)) < %d
           ORDER BY REMAINING""" % expire_interval)
  finally:
    # Release the connection even if the cursor or the query failed.
    db.close()

  # Recheck, at the longest, every "expire_interval" seconds.
  shortest = expire_interval
  # Seconds already spent sleeping between terminations in this pass; the
  # "remaining" values from the query are stale by this much.
  skew=0
  for sessnum,username,appname,remaining in arr:
    if foreground:
      log("%d of %s to expire in %d seconds" % (sessnum,username,remaining))
    if remaining < 1 + skew:
      terminate(sessnum,username,appname)
      time.sleep(1) # Do not kill things off too quickly.
      skew += 1
    elif remaining < shortest:
      shortest = remaining

  shortest -= skew
  if (shortest > 0):
    time.sleep(shortest)

