#!/usr/bin/python
# @package      hubzero-mw2-exec-proxy
# @file         exec-proxy.py
# @author       Pascal Meunier <pmeunier@purdue.edu>
# @copyright    Copyright (c) 2016-2017 HUBzero Foundation, LLC.
# @license      http://opensource.org/licenses/MIT MIT
#
# Packaging of original work by Richard L. Kennell
#
# Copyright (c) 2016-2017 HUBzero Foundation, LLC.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
# HUBzero is a registered trademark of HUBzero Foundation, LLC.
#


#==============================================================================
# Exec proxy
# This proxy forks a child process for each connection, interprets the
# HTTP header, and forwards the connection into a container.
#==============================================================================

import socket
import time
import sys
import os
import errno
import traceback
import fcntl
from OpenSSL import SSL
import select
import signal
from hubzero.mw.log import setup_log, log

#
# This config file can override any of the following global constants.
# Both config files below are run with execfile() in this module's global
# namespace, so any assignment they make replaces the default given here.
#
EXEC_PROXY_CONFIG_FILE = '/etc/mw-proxy/exec-proxy.conf'
EXEC_CONFIGDIR = '/etc/mw-proxy'

EXEC_LISTEN_HOST = None    # IP address to bind to
EXEC_LISTEN_PORT = None    # port number to bind to
EXEC_LISTEN_SSL = None     # Use SSL for listening socket? (True/False)
EXEC_SERVER_KEY = None     # SSL key
EXEC_SERVER_CERT = None    # SSL certificate
EXEC_SERVER_CACERT = None  # SSL CA certificate
EXEC_SERVER_DHPARAM = None # Diffie-Hellman param
EXEC_SERVER_CIPHERS = None # SSL ciphers to use/exclude
PIDFILE = '/var/run/exec-proxy.pid'
PROXY_LOG = '/var/log/exec-proxy/exec-proxy.log'
HEADER_TIMEOUT = 20        # time allowed to receive a request header (presumably seconds)

FORWARD_PORT = 8000        # port connected to on the container's internal address
RUN_UID = 'hz-exec-proxy'  # user the server switches to after startup
RUN_GID = 'hz-exec-proxy'  # group the server switches to after startup
VERBOSITY = 0              # initial verbosity; each -v on the command line adds one

from hubzero.mw.constants import MYSQL_CONN_ATTEMPTS, EXEC_CONFIG_FILE
# The shared execution-host config must exist; the proxy-specific one is
# optional.  Both are executed as Python code with this process's privileges,
# so they must only be writable by trusted administrators.
execfile(EXEC_CONFIG_FILE)
if os.path.isfile(EXEC_PROXY_CONFIG_FILE):
  execfile(EXEC_PROXY_CONFIG_FILE)
else:
  print("configuration file not found, proceeding with defaults")

# Working verbosity level (the config may have supplied VERBOSITY as a string).
verbosity = int(VERBOSITY)

#==============================================================================
# Various printing functions.
#==============================================================================
def verbose(str):
  """Accept a debug message and deliberately discard it.

  Verbose output is currently switched off; route the message to log()
  here to turn it back on.
  """
  return None

def fatal(str):
  """Log *str* prefixed with 'FATAL: ' and terminate with exit status 1.

  Termination happens by raising SystemExit(1), so an enclosing
  ``except SystemExit`` can still intercept it.
  """
  message = 'FATAL: ' + str
  log(message)
  raise SystemExit(1)

def log_exception(title):
  """Log the exception currently being handled.

  Writes '<title> <ExceptionName>:', the exception text, then up to 10
  traceback frames (one log call per line) between two dashed separators.
  Must be called from inside an ``except`` block.
  """
  exc_type, exc_value, exc_tb = sys.exc_info()
  separator = '-' * 50
  log(title + ' ' + exc_type.__name__ + ':')
  log(str(exc_value))
  log(separator)
  for frame_text in traceback.format_tb(exc_tb, 10):
    for line in frame_text.strip().split('\n'):
      log(line.strip())
  log(separator)

def print_exception(title):
  """Print *title* and the exception currently being handled to stdout.

  Mirrors log_exception(): up to 10 traceback frames are printed between
  dashed separators, followed by '<ExceptionName>: <message>' so that the
  actual cause of the failure is visible, not only where it happened.
  Intended to be called from inside an ``except`` block.
  """
  cla, exc, trbk = sys.exc_info()
  print(title)
  print('-'*50)
  for entry in traceback.format_tb(trbk, 10):
    for line in entry.strip().split('\n'):
      print(line.strip())
  if cla is not None:
    # Previously the exception name was computed but never shown, and the
    # message (e.g. why bind() failed) was lost entirely.
    print(cla.__name__ + ': ' + str(exc))
  print('-'*50)

#==============================================================================
# Enable/disable blocking for a SSL, socket, or file descriptor.
#==============================================================================
def setblocking(x, value):
  """Turn blocking I/O on (*value* true) or off for *x*.

  *x* may be any object whose type name contains 'SSL' (its own
  setblocking() is called), any object whose type name contains 'socket'
  (its file descriptor is used), or a plain integer file descriptor, whose
  O_NONBLOCK flag is updated with fcntl.  Any other type is fatal.
  """
  type_name = str(type(x))
  if 'SSL' in type_name:
    x.setblocking(value)
    return
  fd = x.fileno() if 'socket' in type_name else x
  if type(fd) != type(1):
    fatal("Don't know how to set blocking of type %s" % str(type(fd)))
  flags = fcntl.fcntl(fd, fcntl.F_GETFL)
  if value:
    flags &= ~os.O_NONBLOCK   # block on I/O
  else:
    flags |= os.O_NONBLOCK    # do not block on I/O
  fcntl.fcntl(fd, fcntl.F_SETFL, flags)

#==============================================================================
# Create a socket to listen on.  Optionally wrap it with SSL.
#==============================================================================
def create_listener(host,port):
  """Return a TCP socket bound to (host, port) and listening (backlog 5).

  SO_REUSEADDR is set so the server can restart while old connections
  linger.  When EXEC_LISTEN_SSL is true the socket is wrapped in a
  pyOpenSSL Connection configured from the EXEC_SERVER_* settings; file
  names are taken relative to EXEC_CONFIGDIR.  SSLv2 and SSLv3 are refused;
  any TLS version supported by the linked OpenSSL may be negotiated.
  """
  sock = socket.socket()
  sock.setsockopt(socket.SOL_SOCKET,socket.SO_REUSEADDR,1)
  if EXEC_LISTEN_SSL:
    # TLSv1_METHOD would restrict every connection to TLS 1.0 only.  The
    # version-flexible method negotiates the highest common version, and the
    # OP_NO_SSLv2/OP_NO_SSLv3 options below keep the obsolete protocols off.
    # Newer pyOpenSSL calls it TLS_METHOD; older releases only SSLv23_METHOD.
    method = getattr(SSL, 'TLS_METHOD', None)
    if method is None:
      method = SSL.SSLv23_METHOD
    ctx = SSL.Context(method)
    ctx.use_privatekey_file (os.path.join(EXEC_CONFIGDIR,EXEC_SERVER_KEY))
    ctx.use_certificate_file(os.path.join(EXEC_CONFIGDIR,EXEC_SERVER_CERT))
    if EXEC_SERVER_CACERT is not None:
      ctx.load_verify_locations(os.path.join(EXEC_CONFIGDIR,EXEC_SERVER_CACERT))
    if EXEC_SERVER_DHPARAM is not None:
      ctx.load_tmp_dh(os.path.join(EXEC_CONFIGDIR,EXEC_SERVER_DHPARAM))
    if EXEC_SERVER_CIPHERS is not None:
      ctx.set_cipher_list(EXEC_SERVER_CIPHERS)
    ctx.set_options(SSL.OP_NO_SSLv2)
    ctx.set_options(SSL.OP_NO_SSLv3)
    ctx.set_options(SSL.OP_SINGLE_DH_USE)
    sock = SSL.Connection(ctx, sock)
  sock.bind((host, port))
  sock.listen(5)
  return sock

#==============================================================================
# Exceptions for reading/writing.
#==============================================================================
class ReaderClose(StandardError):
  """Raised when a socket or SSL connection can no longer be read from."""

class WriterClose(StandardError):
  """Raised when a socket or SSL connection can no longer be written to."""

#==============================================================================
# Read a string from a socket.
#==============================================================================
def read_chunk(s,maxlen=65536):
  """Read up to *maxlen* bytes from socket or SSL connection *s*.

  Returns the data read, or '' when nothing can be read right now (EAGAIN,
  or an SSL want-read / want-write / want-X509-lookup condition).
  Raises ReaderClose when the peer has closed the connection (an empty
  read or an SSL zero-return) and on any other socket, SSL or OS error.
  The except clauses are ordered from most to least specific.
  """
  try:
    data = s.recv(maxlen)
  except SSL.ZeroReturnError:
    raise ReaderClose()
  except socket.error as err:
    if err.errno != errno.EAGAIN:
      log("read_chunk socket.error" + str(err))
      raise ReaderClose()
    return ''
  except (SSL.WantReadError, SSL.WantWriteError, SSL.WantX509LookupError):
    return ''
  except SSL.Error as err:
    log("read_chunk SSL.Error " + str(err))
    raise ReaderClose()
  except OSError as err:
    if err.errno != errno.EAGAIN:
      log("Error %d in read_chunk" % err.errno)
      raise ReaderClose()
    log('Caught EAGAIN for %s' % str(type(s)))
    return ''

  if data == '':
    raise ReaderClose()
  return data

#==============================================================================
# Write a string to a socket.  Return the unwritten part of the string.
#==============================================================================
def _wait_writable(s, reason):
  """Block until *s* is writable; call fatal() if that takes over 60 seconds.

  *reason* names the condition that forced the wait and only appears in
  the fatal message.
  """
  # select.select(rlist, wlist, xlist[, timeout])
  _, ready, _ = select.select([], [s], [], 60)
  if not ready:
    fatal("write_chunk: socket not ready to write after " + reason)

def write_chunk(s,msg):
  """Send *msg* through socket or SSL connection *s*; return the unsent rest.

  One successful send() is performed and msg[n:] is returned, so the
  caller loops until '' is returned.  When the send cannot proceed yet
  (EAGAIN/EWOULDBLOCK, SSL.WantWriteError) this waits for writability and
  retries with the very same buffer.  On SSL.WantReadError or
  SSL.WantX509LookupError the whole message is returned for a later retry.
  Any other error is fatal().
  """
  # Upon upgrade from Python 2.7.3 to Python 2.7.9 installation, new errors required adding the "while True" loop
  # Without while loop, this program would log "SSL.Error [('SSL routines', 'ssl3_write_pending', 'bad write retry')]"
  # because it retried a write using a different buffer (even if same contents).  It would then abort and transferred files were truncated
  # Without while loop and without SSL, it would log "FATAL: write_chunk: socket.error [Errno 11] Resource temporarily unavailable"
  # because EAGAIN errors were retried using a different buffer
  # the while loop prevents the buffer from being changed
  # "When you retry a write, you must retry with the exact same buffer"
  # "the same contents are not sufficient and, of course, different contents is absolutely prohibited"
  retry_errnos = (errno.EAGAIN, errno.EWOULDBLOCK)
  while True:
    try:
      n = s.send(msg)
      return msg[n:]
    except OSError as error:
      if error.errno in retry_errnos:
        # Wait instead of busy-looping on the retry.
        _wait_writable(s, "EAGAIN")
        continue
      fatal("Error %s in write_chunk" % error.errno)
    except SSL.WantWriteError:
      # Call select to wait efficiently (instead of busy waiting) until ready for writing
      _wait_writable(s, "SSL.WantWriteError")
      continue
    except (SSL.WantReadError, SSL.WantX509LookupError):
      return msg
    except SSL.ZeroReturnError:
      fatal("write_chunk: Zero Return")
    except SSL.Error as errors:
      fatal("write_chunk: SSL.Error " + str(errors))
    except socket.error as errors:
      # On Python 2 socket.error derives from IOError, not OSError, so a
      # plain non-blocking socket reports EAGAIN here.  It must be retried,
      # not treated as fatal.
      if errors.errno in retry_errnos:
        _wait_writable(s, "EAGAIN")
        continue
      fatal("write_chunk: socket.error " + str(errors))

#==============================================================================
# Send a shutdown only for an SSL socket.  Then close it.
#==============================================================================
def shutdown(s):
  """Put *s* back into blocking mode and close it.

  When the type name of *s* contains 'SSL', an SSL shutdown is attempted
  before closing.  Any failure there is ignored because the connection may
  already have been shut down by the peer.
  """
  setblocking(s,True)
  if 'SSL' not in str(type(s)):
    s.close()
    return
  try:
    s.shutdown()
  except: # best effort only: the connection may already be shut down
    pass
  s.close()

#==============================================================================
# Read an HTTP header, return it as an array.
# Any additional data read after the double CRLF is returned as body.
#==============================================================================
debug_header=''
def read_header(ns):
  """Read an HTTP request header from socket or SSL connection *ns*.

  Reading stops once a CRLF CRLF sequence has arrived or the connection
  closes.  The text before the first CR of the initial data is saved in the
  global debug_header for timeout diagnostics.  More than 100000 bytes
  without a complete header is fatal.

  Returns (hdr, body): hdr is the list of header lines with surrounding
  whitespace stripped; body is whatever followed the blank line ('' if
  nothing did).
  """
  global debug_header
  buf = ''
  while '\r\n\r\n' not in buf:
    is_first_read = not buf
    try:
      buf += read_chunk(ns)
    except ReaderClose:
      break

    if is_first_read:
      debug_header = buf.split('\r')[0]
    if len(buf) > 100000:
      fatal('Header is too long')
  head, _, body = buf.partition('\r\n\r\n')
  hdr = [line.strip() for line in head.split('\n')]
  if not hdr:
    fatal('Malformed header1: ' + str(hdr))
  return hdr,body

#==============================================================================
# Print a header.
#==============================================================================
def header_print(hdr):
  """Write the header lines of *hdr* to stdout, one per line."""
  text = '\n'.join(hdr)
  print(text)

#==============================================================================
# Log the full HTTP header.
#==============================================================================
def log_header(hdr,body):
  """Log a separator, each header line of *hdr*, then *body* if it is non-empty."""
  log('================================')
  for line in hdr:
    log(line)
  if len(body):
    log(body)

#==============================================================================
# Send the header (and any body data) through a socket/SSL.
#==============================================================================
def send_header(ns,hdr,body):
  """Write the header lines *hdr* and any early *body* data to *ns*.

  Lines are joined with CRLF and terminated by a blank line, then *body*
  follows.  At verbosity above 2 everything is logged first.  Returns only
  after all data has been handed to write_chunk().
  """
  if verbosity > 2:
    log_header(hdr,body)
  pending = ''.join(['\r\n'.join(hdr), '\r\n\r\n', body])
  while pending:
    pending = write_chunk(ns, pending)

#==============================================================================
# Find an entry in the header.
#==============================================================================
def header_find(hdr, s):
  """Return the index of the first line in *hdr* containing substring *s*, or -1."""
  return next((idx for idx, line in enumerate(hdr) if s in line), -1)

#==============================================================================
# Find a key in the header and replace its value, or create it if nonexistent.
#==============================================================================
def header_set(hdr, key, value):
  """Set header *key* to *value* in the list *hdr*, modifying it in place.

  Every line starting with 'key:' is replaced with 'key: value'; if no
  such line exists, that line is appended.
  """
  prefix = key + ':'
  replaced = False
  for idx, line in enumerate(hdr):
    if line.startswith(prefix):
      hdr[idx] = prefix + ' ' + value
      replaced = True
  if not replaced:
    hdr += [ prefix + ' ' + value ]

#==============================================================================
# Return the content-length field or -1 if not found.
#==============================================================================
def header_content_length(hdr):
  """Return the Content-Length value found in header list *hdr*.

  Returns -1 when no line contains 'Content-Length:'.  Also returns -1, after
  logging the line, when the value is not a non-negative integer.  A negative
  length is invalid in HTTP and would otherwise be treated as a byte count.
  """
  idx = header_find(hdr,'Content-Length:')
  if idx < 0:
    return -1
  # header_find() matches anywhere in the line, so take the text after the
  # match rather than assuming the name starts the line.
  value = hdr[idx].split('Content-Length:', 1)[1]
  try:
    length = int(value)
  except ValueError:
    length = -1
  if length < 0:
    log('Bad Content-Length: ' + hdr[idx])
    return -1
  return length

#==============================================================================
# This function is used if we want to verify an SSL server we connect to.
#==============================================================================
def verify_cb(conn, crt, errnum, depth, ok):
  """Certificate-verification callback in pyOpenSSL Context.set_verify() form.

  Reports the subject, issuer and chain depth of certificate *crt* through
  verbose(), then returns *ok* unchanged.  The callback therefore accepts
  exactly what OpenSSL's own verification accepted.
  (Not referenced elsewhere in the visible code.)
  """
  # The body used to reference an undefined name 'cert'.  The argument
  # expressions are evaluated even though verbose() discards them, so every
  # call raised NameError.
  verbose('Got cert: %s' % crt.get_subject())
  verbose('Issuer: %s' % crt.get_issuer())
  verbose('Depth: %s' % str(depth))
  return ok

#==============================================================================
# Determine the internal IP address of a container.
#==============================================================================
def address_of_container(ct):
  """Return the internal IP address (a string) of container number *ct*.

  This host's machine number is the fourth dot-separated field of the IPv4
  address that its FQDN resolves to.  The result is
  CONTAINER_CONF['PRIVATE_NET'] formatted with
  ((ct // 255 + 4 * (machine % 64)) % 256, ct % 255).
  NOTE(review): CONTAINER_CONF is not defined in this file's visible code;
  it is presumably set by EXEC_CONFIG_FILE, and PRIVATE_NET presumably holds
  two %d slots for the last two octets.

  Failures are fatal(), with a message that says which step failed.
  """
  try:
    machine_number = int(socket.gethostbyname(socket.getfqdn()).split('.')[3])
  except Exception:
    fatal('Unable to determine machine number')
  try:
    # '//' makes the integer division explicit (identical to '/' on
    # Python 2 ints, and still correct on Python 3).
    digit = ct // 255 + (machine_number % 64) * 4
    return CONTAINER_CONF['PRIVATE_NET'] % (digit % 256, ct % 255)
  except Exception:
    # Previously every failure here was misreported as a machine-number problem.
    fatal('Unable to compute address of container %s' % (ct,))

#==============================================================================
# Interpret the header and open a socket/SSL to the execution host's proxy.
#==============================================================================
def find_target(hdr):
  """Return the internal address of the container a request is aimed at.

  hdr[0] is the HTTP request line ('METHOD URL ...').  The URL path must
  start with /weber/ or /notebook/, and its fourth path component must be
  the integer container number.  Any query string is ignored.  Malformed
  input is fatal().  The URL (without query) is logged at verbosity > 0.
  """
  request_line = hdr[0].split(' ')
  if len(request_line) < 2:
    fatal('Malformed header2: ' + hdr[0])
  # Strip the query string.  Only the first '?' starts it; further '?'
  # characters are legal inside a query (RFC 3986 section 3.4), so they must
  # not be rejected.
  url = request_line[1].split('?', 1)[0]
  comp = url.split('/')[1:]
  if verbosity > 0:
    log(url)
  if len(comp) < 1 or comp[0] not in ('weber', 'notebook'):
    fatal('Malformed URL: (comp0) ' + url)
  try:
    ct = int(comp[3])
  except (IndexError, ValueError):
    fatal('Malformed URL: (comp3) ' + url)

  # The container number determines the internal address to connect to.
  return address_of_container(ct)

#==============================================================================
# Establish a socket to the target address.
#==============================================================================
def socket_to_target(host):
  """Return a TCP socket connected to FORWARD_PORT on *host*.

  Up to 79 attempts are made.  Attempt x is followed by a 0.1*x second sleep
  when it fails, about five minutes in total, presumably to give a container
  that is still starting time to listen.  If every attempt fails, fatal() is
  called; the trailing 'return False' is kept for callers that test the
  result.
  """
  port = FORWARD_PORT
  attempts = 79
  for x in range(1, attempts + 1):
    # After a failed connect() the socket's state is unspecified (POSIX), so
    # every attempt uses a fresh socket and failed ones are closed.
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
      sock.connect((host, port))
    except socket.error:
      sock.close()
      if x < attempts:   # no point sleeping after the last attempt
        time.sleep(0.1 * x)
      continue
    if x > 10:  # Over ~5.5 s of retrying by now: worth noting in the log.
      log("Made %d attempts to connect before success." % x)
    return sock
  fatal("Unable to forward to %s:%d" % (host,port))
  return False

#==============================================================================
# Copy a plaintext socket or SSL socket bidirectionally.
#==============================================================================
def bidirectional_copy(a,b):
  """Relay data in both directions between *a* and *b* until both are finished.

  *a* and *b* may be plain sockets or SSL connections.  Both are switched to
  non-blocking mode and multiplexed with select().  Data read from one side
  is buffered until the other side is reported writable.  Once a side can no
  longer be read and everything read from it has been delivered, the
  opposite side is shut down (closed) and dropped from the select sets.
  This happens at most once per side.
  """
  setblocking(a, False)
  setblocking(b, False)
  af = a.fileno()
  bf = b.fileno()
  amsg = ''   # read from a, not yet written to b
  bmsg = ''   # read from b, not yet written to a
  rfds = [af,bf]
  wfds = []
  a_closed = False
  b_closed = False
  while rfds or wfds:
    try:
      rd,wr,_ = select.select(rfds,wfds,[])
    except select.error as err:
      # Retry only a signal interruption.  Any other error (e.g. EBADF) would
      # recur immediately and the old blanket 'continue' spun forever.
      if err.args and err.args[0] == errno.EINTR:
        continue
      raise
    # Writes are attempted only on descriptors select() reported writable,
    # so a non-blocking send() is never tried blindly on a full buffer.
    if af in rd:
      try:
        amsg += read_chunk(a)
        if bf not in wfds:
          wfds.append(bf)
      except ReaderClose:
        rfds.remove(af)
        if bf not in wfds:
          wfds.append(bf)
        continue
    if bf in rd:
      try:
        bmsg += read_chunk(b)
        if af not in wfds:
          wfds.append(af)
      except ReaderClose:
        rfds.remove(bf)
        if af not in wfds:
          wfds.append(af)
        continue
    if bf in wr:
      if amsg:
        amsg = write_chunk(b,amsg)
      if not amsg:
        wfds.remove(bf)
    if af in wr:
      if bmsg:
        bmsg = write_chunk(a,bmsg)
      if not bmsg:
        wfds.remove(af)
    # a is finished and fully delivered: close b once, and stop selecting on
    # its (now closed) descriptor in both sets.
    if not b_closed and af not in rfds and not amsg:
      shutdown(b)
      b_closed = True
      for fds in (rfds, wfds):
        if bf in fds:
          fds.remove(bf)
    # Same for the other direction.
    if not a_closed and bf not in rfds and not bmsg:
      shutdown(a)
      a_closed = True
      for fds in (rfds, wfds):
        if af in fds:
          fds.remove(af)

#==============================================================================
# Forward content of a particular length from a to b.
#==============================================================================
def forward_body(a, b, blen):
  """Copy a request body from *a* to *b* in blocking mode.

  With blen == -1, data is copied until *a* is closed.  Otherwise exactly
  *blen* bytes are copied, unless *a* closes first.  After that, *a* is
  polled once without blocking:
  - any extra data is fatal;
  - a closed *a* propagates ReaderClose to the caller (blocking mode is then
    left off);
  - otherwise blocking mode is restored.
  """
  expected = blen
  setblocking(a, True)
  setblocking(b, True)

  def relay(data):
    # Keep writing until write_chunk() reports nothing left.
    while data:
      data = write_chunk(b, data)

  # Unknown length: copy until the stream closes.
  if blen == -1:
    while True:
      try:
        data = read_chunk(a,4096)
      except ReaderClose:
        return
      relay(data)

  # Known length: copy exactly that many bytes.
  remaining = blen
  while remaining > 0:
    try:
      data = read_chunk(a,min(remaining,4096))
    except ReaderClose:
      return
    remaining -= len(data)
    relay(data)

  # Poll for data beyond the declared length; a closed socket raises
  # ReaderClose here, which is deliberately not caught.
  setblocking(a, False)
  extra = read_chunk(a)
  if extra != '':
    fatal('Read %d excess bytes beyond content length of %d' % (len(extra),expected))
  setblocking(a, True)

#==============================================================================
# Don't let the connection read too long without a full header.
#==============================================================================
debug_header=''
def timeout(signo,extra):
  """SIGALRM handler: give up on a client that has not sent a full header.

  The fatal message includes debug_header, the start of whatever request
  data had arrived.  *signo* and *extra* (signal number and frame) are unused.
  """
  fatal('Timeout waiting for header %s' % str(debug_header))

#==============================================================================
# Handle a new connection.
#==============================================================================
def handle_connection(ns,remoteip,remoteport):
  """Serve one client connection *ns* by forwarding it into a container.

  The request header must arrive within HEADER_TIMEOUT seconds (SIGALRM
  runs timeout()).  find_target() picks the container.  The header and any
  early body bytes are replayed to it, then data is relayed both ways until
  both sides are done.  *remoteip* and *remoteport* are not used here.
  """
  signal.signal(signal.SIGALRM, timeout)
  # Honour the configurable timeout instead of a hard-coded 20.
  signal.alarm(int(HEADER_TIMEOUT))
  try:
    hdr,body = read_header(ns)
  finally:
    # Never leave the alarm armed, even if read_header() raised.
    signal.alarm(0)

  # read_header() returns a list, so the old "hdr == ''" test never matched.
  # An empty first line means the client closed before sending anything;
  # that is not an error worth a fatal "Malformed header" log entry.
  if not hdr or not hdr[0]:
    ns.close()
    return

  host = find_target(hdr)
  newconn = socket_to_target(host)
  if not newconn:
    # TODO: issue a 404 response or something.
    ns.close()
    return
  send_header(newconn,hdr,body)
  bidirectional_copy(ns, newconn)


#==============================================================================
# Accept a new connection, and handle it as a child process.
#==============================================================================
def main_loop(ls):
  """Accept connections on listener *ls* forever, handling each in its own process.

  A double fork is used: the intermediate child forks the handler and exits
  immediately.  The server therefore only waits for that short-lived process,
  and the handler is no longer this server's child.  There is no normal
  return; an exception from accept() or fork() propagates to the caller.
  """
  global PIDFILE  # rebound to None in the handler process below
  while True:
    ns,addr = ls.accept()
    if os.fork() == 0:
      # Intermediate child.
      if os.fork() == 0:
        # Grandchild: serves this single connection.
        PIDFILE = None # Don't let a child delete the pidfile.
        ls.close() # Close the listener in the child
        os.setsid()  # start a new session, detached from the server's
        if verbosity > 1:
          log('Connect from ' + addr[0] + ':' + str(addr[1]))
        try:
          handle_connection(ns,addr[0],addr[1])
        except SystemExit:
          pass  # normally raised by fatal(), which has already logged the reason
        except:
          log_exception("Exception in child:")
        if verbosity > 1:
          log('Closed ' + addr[0] + ':' + str(addr[1]))
      # Both the intermediate child and the grandchild end here.  _exit()
      # skips Python-level cleanup (finally blocks, atexit) inherited from
      # the server.
      os._exit(0)
    else:
      # Parent: reap the intermediate child.  NOTE(review): the second wait()
      # normally fails with ECHILD (swallowed below); presumably it covers a
      # first wait() interrupted by a signal.
      try:
        os.wait()
        os.wait()
      except:
        pass
    ns.close()  # the handler process keeps its own copy of the connection

#==============================================================================
# Run the process in the background.
#==============================================================================
def daemonize():
  """Move to the background: chdir to '/', fork, and start a new session.

  The original process exits immediately with status 0; only the child
  returns from this function.
  """
  os.chdir('/')
  pid = os.fork()
  if pid != 0:
    os._exit(0)
  os.setsid()

#==============================================================================
# Create the PID file.
#==============================================================================
def write_pidfile(name):
  """Write this process's PID followed by a newline to the file *name*.

  An existing file is truncated.  The file is closed even if the write fails.
  """
  with open(name, 'w') as f:
    f.write('%d\n' % os.getpid())

#==============================================================================
# Remove the PID file.
#==============================================================================
def remove_pidfile(name):
  """Delete the pidfile *name*; do nothing when *name* is None.

  The effective uid is set back to root first, since main() may have
  dropped it.  Any failure is logged rather than raised.
  """
  if name is not None:
    try:
      os.seteuid(0)
      os.unlink(name)
    except: # best effort during shutdown
      log('Unable to remove %s' % name)

#==============================================================================
# Clean up on termination.
#==============================================================================
def termination_handler(signo,extra):
  """Signal handler and final cleanup: remove PIDFILE and exit with status 1.

  *signo* is 0 for a direct call (nothing is logged) or the number of the
  signal being handled.  *extra* is unused.
  """
  from_signal = signo != 0
  if from_signal:
    log('Terminating on signal %d' % signo)
  remove_pidfile(PIDFILE)
  # There is no good reason for termination.  Consider it a failure.
  os._exit(1)

#==============================================================================
# The main function creates the main listening socket, sets up logging and
# daemonization, creates the pid file, invokes main_loop(), and handles
# top-level exceptions.
#==============================================================================
def main():
  """Start the exec proxy.

  Steps:
  - create the listening socket while still privileged;
  - handle the -f (stay in foreground) and -v (raise verbosity, repeatable)
    arguments;
  - unless -f was given, redirect stdio to /dev/null and daemonize;
  - write PIDFILE;
  - switch the effective group and user to RUN_GID / RUN_UID;
  - open PROXY_LOG and run main_loop().
  When main_loop() ends, termination_handler() removes the pidfile and exits
  with status 1.
  """
  global verbosity
  import syslog
  # write to syslog until we change uid
  syslog.syslog("exec-proxy starting up")
  # Create the listening socket before backgrounding or writing the pidfile.
  try:
    ls = create_listener(EXEC_LISTEN_HOST, EXEC_LISTEN_PORT)
  except Exception:
    # Format with %s: the port defaults to None, and a '%d' would raise a
    # TypeError here, hiding the real error and skipping os._exit(1).
    print_exception("FATAL: Unable to listen to %s:%s" % (EXEC_LISTEN_HOST, EXEC_LISTEN_PORT))
    os._exit(1)

  signal.signal(signal.SIGINT, termination_handler)
  signal.signal(signal.SIGTERM, termination_handler)
  foreground = '-f' in sys.argv
  verbosity += sys.argv.count('-v')
  if not foreground:
    # Detach stdin/stdout/stderr from the terminal, then go to the background.
    try:
      fd = os.open("/dev/null", os.O_RDWR)
      for target in (0, 1, 2):
        os.dup2(fd, target)
      os.close(fd)
    except OSError:
      print("FATAL: Unable to use /dev/null for input/output")
      os._exit(1)
    daemonize()

  write_pidfile(PIDFILE)
  syslog.syslog("exec-proxy switching to user %s group %s" % (RUN_UID, RUN_GID))

  # Group first: once the effective uid is no longer root, the effective gid
  # can no longer be changed.  Only effective ids change; the real uid stays
  # root, which lets remove_pidfile() call seteuid(0) at exit.
  try:
    import grp
    os.setegid(grp.getgrnam(RUN_GID).gr_gid)
  except Exception:
    fatal('Unable to set gid to %s' % RUN_GID)

  try:
    import pwd
    os.seteuid(pwd.getpwnam(RUN_UID).pw_uid)
  except Exception:
    fatal('Unable to set uid to %s' % RUN_UID)

  # start writing to log only after changing uid, to avoid writing as root in non-root owned directory
  setup_log(PROXY_LOG, None)
  log('Starting up')
  if verbosity > 0:
    log('Verbosity level %d' % verbosity)

  # Invoke the handler loop.  Catch and print any unusual exceptions.
  try:
    main_loop(ls)
  except SystemExit: # Something called fatal()
    pass
  except:
    log_exception("Exception in server:")

  # Once it's done, clean up.  A zero indicates no signal.
  termination_handler(0,0)

#==============================================================================
# Invoke the main function.
#==============================================================================
# Start the proxy.  There is no `if __name__ == '__main__':` guard, so
# loading this file by any means starts the server.
main()

