# @package      hubzero-submit-monitors
# @file         TunnelMonitor.py
# @author       Steven Clark <clarks@purdue.edu>
# @copyright    Copyright (c) 2012-2013 HUBzero Foundation, LLC.
# @license      http://www.gnu.org/licenses/lgpl-3.0.html LGPLv3
#
# Copyright (c) 2012-2013 HUBzero Foundation, LLC.
#
# This file is part of: The HUBzero(R) Platform for Scientific Collaboration
#
# The HUBzero(R) Platform for Scientific Collaboration (HUBzero) is free
# software: you can redistribute it and/or modify it under the terms of
# the GNU Lesser General Public License as published by the Free Software
# Foundation, either version 3 of the License, or (at your option) any
# later version.
#
# HUBzero is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# HUBzero is a registered trademark of HUBzero Foundation, LLC.
#

import os
import re
import subprocess
import shlex
import select
import signal

from hubzero.submit.LogMessage  import logID as log
from hubzero.submit.MessageCore import MessageCore

class TunnelMonitor(MessageCore):
   """Serve tunnel requests received through MessageCore channels.

   Requests are short messages of the form "<type>:<tunnelName>" (see
   processRequest).  Running tunnels are tracked by the caller-supplied
   activeTunnels dict, mapping tunnelName -> (pid string, use count string).
   """

   # Seconds select() waits for tunnel command output before polling the child
   # again.  The tunnel command may leave a background process holding its
   # stdout/stderr open, so EOF cannot be relied on to signal completion;
   # a non-zero wait avoids spinning the CPU while the command runs.
   tunnelPollInterval = 0.1

   def __init__(self,
                host,
                port,
                repeatDelay=5,
                fixedBufferSize=64,
                updateknownHostsCommand="update-known-hosts"):
      """Pass host/port to MessageCore as the bind address and store settings.

      host                    - bind host; "localhost" is used in the label when empty
      port                    - bind port (formatted with %d, so an int)
      repeatDelay             - forwarded to MessageCore
      fixedBufferSize         - size used for fixed-length request/reply messages
      updateknownHostsCommand - command run with a quoted host name before a
                                tunnel is started
      """
      bindLabel = "%s:%d" % (host if host else "localhost",port)
      MessageCore.__init__(self,bindHost=host,bindPort=port,bindLabel=bindLabel,repeatDelay=repeatDelay)
      self.fixedBufferSize         = fixedBufferSize
      self.updateknownHostsCommand = updateknownHostsCommand
      self.bufferSize              = 4096


   def __collectOutput(self,
                       command,
                       untilExit):
      """Run command (split with shlex, no shell) and gather stdout/stderr.

      untilExit False - read until both streams reach EOF, then reap the child.
      untilExit True  - additionally stop as soon as the child exits, even if
                        its streams are still held open by a descendant.

      Returns (returnCode,stdout,stderr).  returnCode follows the subprocess
      convention: 0 on success, the exit status on failure, or -N when the
      child was killed by signal N.  Failures are logged.
      """
      child = subprocess.Popen(shlex.split(command),bufsize=self.bufferSize,
                               stdout=subprocess.PIPE,
                               stderr=subprocess.PIPE,
                               close_fds=True)
      outFd   = child.stdout.fileno()
      errFd   = child.stderr.fileno()
      chunks  = {outFd:[],errFd:[]}
      openFds = [outFd,errFd]

      def readReady(timeout):
         # Read one chunk from every ready descriptor; forget those at EOF.
         try:
            ready = select.select(openFds,[],[],timeout)[0]
         except select.error:
            # interrupted (e.g. by a signal): nothing read, caller retries
            return
         for fd in ready:
            chunk = os.read(fd,self.bufferSize)
            if chunk == '':
               openFds.remove(fd)
            else:
               chunks[fd].append(chunk)

      waitTimeout = self.tunnelPollInterval if untilExit else None
      returnCode  = None
      try:
         while openFds:
            readReady(waitTimeout)
            if untilExit and openFds:
               returnCode = child.poll()
               if returnCode is not None:
                  # pick up output written just before the child exited
                  readReady(0.)
                  break
         if returnCode is None:
            returnCode = child.wait()
      finally:
         child.stdout.close()
         child.stderr.close()

      outText = "".join(chunks[outFd])
      errText = "".join(chunks[errFd])
      if returnCode != 0:
         if returnCode < 0:
            log("%s failed w/ signal %d" % (command,-returnCode))
         else:
            log("%s failed w/ exit code %d" % (command,returnCode))
         log("%s" % (errText))

      return(returnCode,outText,errText)


   def __executeCommand(self,
                        command):
      """Run command to completion; return (returnCode,stdout,stderr)."""
      return(self.__collectOutput(command,False))


   def __executeTunnelCommand(self,
                              command):
      """Run a tunnel command, returning once it exits even if a background
      process keeps its output streams open; return (returnCode,stdout,stderr)."""
      return(self.__collectOutput(command,True))


   def __startTunnel(self,
                     tunnelName,
                     tunnelsInfo,
                     sshIdentity,
                     activeTunnels):
      """Start tunnelName and record it in activeTunnels with use count "1".

      Both tunnel hosts are first passed to the known-hosts update command.
      Returns "1" on success and "?" if any step fails.
      """
      gatewayHost,localHost = tunnelsInfo.getSSHTunnelHosts(tunnelName)
      if gatewayHost == "" or localHost == "":
         return("?")

      for host in (gatewayHost,localHost):
         # NOTE(review): a host name containing a single quote would break the
         # shlex parsing of this command line.
         updateCommand = self.updateknownHostsCommand + " \'" + host + "\'"
         if self.__executeCommand(updateCommand)[0] != 0:
            return("?")

      sshTunnelCommand = tunnelsInfo.getSSHTunnelCommand(tunnelName,sshIdentity)
      log(sshTunnelCommand)
      exitStatus,stdOutput,stdError = self.__executeTunnelCommand(sshTunnelCommand)
      if exitStatus != 0:
         log(stdOutput)
         log(stdError)
         return("?")

      sshTunnelPidCommand = tunnelsInfo.getSSHTunnelPidCommand(tunnelName)
      log(sshTunnelPidCommand)
      exitStatus,tunnelPid,stdPidError = self.__executeCommand(sshTunnelPidCommand)
      tunnelPid = tunnelPid.strip()
      if exitStatus != 0:
         log(tunnelPid)
         log(stdPidError)
         return("?")

      useCount = "1"
      activeTunnels[tunnelName] = (tunnelPid,useCount)
      log("%s tunnel started, pid = %s" % (tunnelName,tunnelPid))
      return(useCount)


   def __stopTunnel(self,
                    tunnelName,
                    tunnelPid):
      """Send SIGTERM to the recorded tunnel pid and log the outcome.

      A pid that is not an integer or no longer exists is logged rather than
      raised, so one dead tunnel cannot take down the monitor.
      """
      try:
         os.kill(int(tunnelPid),signal.SIGTERM)
      except (ValueError,OSError) as err:
         log("%s tunnel pid %s could not be signaled: %s" % (tunnelName,tunnelPid,err))
      log("%s tunnel stopped, pid = %s" % (tunnelName,tunnelPid))


   def processRequest(self,
                      channel,
                      tunnelsInfo,
                      sshIdentity,
                      activeTunnels):
      """Handle one request read from channel.

      A falsy message (nothing received) closes the channel.  Otherwise the
      message should be "<type>:<tunnelName>" with <type> one of
         A - reply "<address> <port>" for tunnelName
         I - increment the use count, starting the tunnel if it is not
             active; reply the new count, or "?" if the start failed
         D - decrement the use count, stopping the tunnel when it reaches
             zero; reply the new count ("0" if not active)
         T - stop the tunnel; reply its previous use count ("0" if not active)
         R - reply the report length, then (if non-empty) a report of
             "<name> <count>" entries joined by " : "; all active tunnels
             when tunnelName is empty
      Malformed messages are logged and get no reply.
      activeTunnels is updated in place.

      Returns True only when the channel was closed by this call.
      """
      channelClosed = False

      message = self.receiveMessage(channel,self.fixedBufferSize)
      if not message:
         try:
            channel.close()
            channelClosed = True
         except Exception:
            log("close channel failed")
         return(channelClosed)

      messageType = ""
      tunnelName  = ""
      if re.match(r'[AIDTR]:',message):
         try:
            messageType,tunnelName = message.split(':')
         except ValueError:
            # e.g. more than one ':' - ignore the request instead of crashing
            messageType = ""
            log("Failed AIDTR message request: " + message)
         else:
            tunnelName = tunnelName.strip()
      else:
         log("Failed message request: " + message)

      if   messageType == 'A':                        # get tunnel Address and port
         address,port = tunnelsInfo.getSSHTunnelAddressPort(tunnelName)
         self.sendMessage(channel,address + " " + port,self.fixedBufferSize)
      elif messageType == 'I':                        # Increment tunnel usage
         if tunnelName in activeTunnels:
            tunnelPid,useCount = activeTunnels[tunnelName]
            useCount = str(int(useCount) + 1)
            activeTunnels[tunnelName] = (tunnelPid,useCount)
         else:
            useCount = self.__startTunnel(tunnelName,tunnelsInfo,sshIdentity,activeTunnels)
         self.sendMessage(channel,useCount,self.fixedBufferSize)
      elif messageType == 'D':                        # Decrement tunnel usage
         if tunnelName in activeTunnels:
            tunnelPid,useCount = activeTunnels[tunnelName]
            useCount = str(int(useCount) - 1)
            if int(useCount) == 0:
               self.__stopTunnel(tunnelName,tunnelPid)
               del activeTunnels[tunnelName]
            else:
               activeTunnels[tunnelName] = (tunnelPid,useCount)
         else:
            useCount = "0"
         self.sendMessage(channel,useCount,self.fixedBufferSize)
      elif messageType == 'T':                        # Terminate tunnel
         if tunnelName in activeTunnels:
            tunnelPid,useCount = activeTunnels[tunnelName]
            self.__stopTunnel(tunnelName,tunnelPid)
            del activeTunnels[tunnelName]
         else:
            useCount = "0"
         self.sendMessage(channel,useCount,self.fixedBufferSize)
      elif messageType == 'R':                        # Report active tunnels
         if tunnelName == "":
            report = " : ".join([activeTunnel + " " + activeTunnels[activeTunnel][1]
                                 for activeTunnel in activeTunnels])
         else:
            if tunnelName in activeTunnels:
               useCount = activeTunnels[tunnelName][1]
            else:
               useCount = "?"
            report = tunnelName + " " + useCount

         reportLength = len(report)
         # length goes first in a fixed-size message; the report follows only
         # if that send succeeded and there is something to send
         if self.sendMessage(channel,str(reportLength),self.fixedBufferSize) > 0:
            if reportLength > 0:
               self.sendMessage(channel,report)

      return(channelClosed)


