#
# Copyright (c) 2004-2010 Purdue University All rights reserved.
# 
# Developed by: HUBzero Technology Group, Purdue University
#               http://hubzero.org
# 
# HUBzero is free software: you can redistribute it and/or modify it under the terms of the
# GNU Lesser General Public License as published by the Free Software Foundation, either
# version 3 of the License, or (at your option) any later version.
# 
# HUBzero is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU Lesser General Public License for more details.  You should have received a
# copy of the GNU Lesser General Public License along with HUBzero.
# If not, see <http://www.gnu.org/licenses/>.
# 
# GNU LESSER GENERAL PUBLIC LICENSE
# Version 3, 29 June 2007
# Copyright (C) 2007 Free Software Foundation, Inc. <http://fsf.org/>
#
import os
import re
import popen2
import select
import signal

from LogMessage  import logID as log
from MessageCore import MessageCore

class TunnelMonitor(MessageCore):
   """Tunnel-management server built on MessageCore.

   Each request read from a channel is a message "<type>:<tunnelName>" where <type> is
      A - reply with the tunnel's "address port"
      I - increment the tunnel's use count, starting its ssh tunnel if it is not active
      D - decrement the use count, stopping the tunnel when the count reaches zero
      T - stop the tunnel regardless of its use count
      R - report "name count" for one tunnel, or for every active tunnel when the
          name is empty
   """

   _BUFSIZ = 4096   # bytes requested per os.read() on a child pipe

   def __init__(self,
                host,
                port,
                repeatDelay=5,
                fixedBufferSize=64,
                updateknownHostsCommand="./update-known-hosts"):
      """Bind the monitor to host:port.

      repeatDelay             - forwarded unchanged to MessageCore.__init__.
      fixedBufferSize         - buffer size used for request/reply messages.
      updateknownHostsCommand - command run with a single-quoted host name before a tunnel
                                is started (presumably refreshes ssh known_hosts).
      """
      MessageCore.__init__(self,bindHost=host,bindPort=port,repeatDelay=repeatDelay)
      self.fixedBufferSize         = fixedBufferSize
      self.updateknownHostsCommand = updateknownHostsCommand


   def __readAvailable(self,
                       openFds,
                       chunksByFd,
                       timeout):
      """Wait up to timeout seconds (None = forever) for output, then read each ready fd once.

      Data read from fd is appended to chunksByFd[fd]; an fd that reports EOF is removed
      from openFds so it is not selected again.
      """
      ready = select.select(openFds,[],[],timeout)[0]
      for fd in ready:
         chunk = os.read(fd,self._BUFSIZ)
         if chunk:
            chunksByFd[fd].append(chunk)
         else:
            openFds.remove(fd)


   def __checkExitStatus(self,
                         command,
                         err,
                         errText):
      """Log a failed command together with its stderr, and return its status.

      Returns 0 on success, the exit code when the child exited normally with a non-zero
      code, or the raw wait status when the child was terminated by a signal.
      """
      if err != 0:
         if os.WIFSIGNALED(err):
            log("%s failed w/ exit code %d signal %d" % (command,os.WEXITSTATUS(err),os.WTERMSIG(err)))
         else:
            err = os.WEXITSTATUS(err)
            log("%s failed w/ exit code %d" % (command,err))
         log(errText)
      return(err)


   def __runCommand(self,
                    command,
                    pollInterval=None):
      """Run command via popen2.Popen3 and capture its stdout and stderr.

      pollInterval None: read until both pipes reach EOF, then wait for the child.
      Otherwise each select() waits at most pollInterval seconds and the child is polled
      after every read; reading stops as soon as the child has exited, even if its pipes
      are still held open by some other process (e.g. one it left running in the
      background).

      Returns (exitStatus,stdout,stderr); exitStatus is normalized by __checkExitStatus.
      """
      child = popen2.Popen3(command,1)
      child.tochild.close() # don't need to talk to child
      outFd = child.fromchild.fileno()
      errFd = child.childerr.fileno()
      outData = []
      errData = []
      chunksByFd = {outFd:outData,errFd:errData}
      openFds = [outFd,errFd]

      err = -1
      try:
         while openFds:
            self.__readAvailable(openFds,chunksByFd,pollInterval)
            if pollInterval is not None:
               err = child.poll()
               if err > -1:
                  # The child has exited: collect whatever it already wrote (its error
                  # message, typically) without blocking on pipes another process may
                  # still hold open.
                  if openFds:
                     self.__readAvailable(openFds,chunksByFd,0.)
                  break
         if err == -1:
            err = child.wait()
      finally:
         child.fromchild.close()
         child.childerr.close()

      errText = "".join(errData)
      err = self.__checkExitStatus(command,err,errText)
      return(err,"".join(outData),errText)


   def __executeCommand(self,
                        command):
      """Run command to completion; return (exitStatus,stdout,stderr)."""
      return(self.__runCommand(command))


   def __executeTunnelCommand(self,
                              command,
                              pollInterval=0.1):
      """Run a tunnel-starting command; return (exitStatus,stdout,stderr).

      Returns as soon as the command itself exits, even if a process it started keeps its
      output pipes open.  pollInterval bounds each select() wait between exit checks; a
      positive value avoids busy-waiting while the command runs.
      """
      return(self.__runCommand(command,pollInterval))


   def __startTunnel(self,
                     tunnelName,
                     tunnelsInfo,
                     IDENTITY):
      """Refresh known hosts and launch tunnelName's ssh tunnel.

      Returns the tunnel's pid text (stdout of the pid command, stripped), or None if any
      step fails or the tunnel has no gateway/local host configured.
      """
      gatewayHost,localHost = tunnelsInfo.getSSHTunnelHosts(tunnelName)
      if gatewayHost == "" or localHost == "":
         return(None)

      # Gateway first, then local host; stop at the first failure.
      # NOTE(review): the host is only wrapped in single quotes for the shell; this is safe
      # only as long as tunnelsInfo host names never contain a quote.
      for host in (gatewayHost,localHost):
         updateCommand = self.updateknownHostsCommand + " \'" + host + "\'"
         if self.__executeCommand(updateCommand)[0] != 0:
            return(None)

      sshTunnelCommand = tunnelsInfo.getSSHTunnelCommand(tunnelName,IDENTITY)
      log(sshTunnelCommand)
      if self.__executeTunnelCommand(sshTunnelCommand)[0] != 0:
         return(None)

      sshTunnelPidCommand = tunnelsInfo.getSSHTunnelPidCommand(tunnelName)
      log(sshTunnelPidCommand)
      exitStatus,tunnelPid,stdPidError = self.__executeCommand(sshTunnelPidCommand)
      if exitStatus != 0:
         return(None)
      return(tunnelPid.strip())


   def __stopTunnel(self,
                    tunnelName,
                    activeTunnels):
      """Remove tunnelName from activeTunnels and send SIGTERM to its recorded pid.

      The entry is removed even when the kill fails (process already gone, or pid text
      that is not a single integer), so a stale entry cannot break later requests.
      """
      tunnelPid = activeTunnels.pop(tunnelName)[0]
      try:
         os.kill(int(tunnelPid),signal.SIGTERM)
      except (OSError,ValueError):
         log("%s tunnel kill failed, pid = %s" % (tunnelName,tunnelPid))
      else:
         log("%s tunnel stopped, pid = %s" % (tunnelName,tunnelPid))


   def processRequest(self,
                      channel,
                      tunnelsInfo,
                      IDENTITY,
                      activeTunnels):
      """Read one request from channel, act on it and send the reply.

      channel       - connection used with receiveMessage/sendMessage.
      tunnelsInfo   - supplies per-tunnel address/port, hosts, ssh command and pid command.
      IDENTITY      - passed through to tunnelsInfo.getSSHTunnelCommand.
      activeTunnels - dict tunnelName -> (tunnelPid,useCount), both strings; updated in place.

      An empty message (presumably a disconnected peer) closes the channel.
      Returns True only if the channel was closed here.
      """
      message = self.receiveMessage(channel,self.fixedBufferSize)
      if message == "":
         try:
            channel.close()
         except Exception:
            log("close channel failed")
            return(False)
         return(True)

      # Defaults keep a malformed request (e.g. "A:x:y") from leaving these unbound.
      messageType = ""
      tunnelName  = ""
      if re.match("[AIDTR]:",message):
         try:
            messageType,tunnelName = message.split(':')
         except ValueError:
            log("Failed AIDTR message request: " + message)
         else:
            tunnelName = tunnelName.strip()
      else:
         log("Failed message request: " + message)

      if   messageType == "A":                        # get tunnel Address and port
         address,port = tunnelsInfo.getSSHTunnelAddressPort(tunnelName)
         self.sendMessage(channel,address + " " + port,self.fixedBufferSize)
      elif messageType == "I":                        # Increment tunnel usage
         if tunnelName in activeTunnels:
            tunnelPid,useCount = activeTunnels[tunnelName]
            useCount = str(int(useCount) + 1)
            activeTunnels[tunnelName] = (tunnelPid,useCount)
         else:
            useCount = "?"                            # reply when the tunnel cannot be started
            tunnelPid = self.__startTunnel(tunnelName,tunnelsInfo,IDENTITY)
            if tunnelPid is not None:
               useCount = "1"
               activeTunnels[tunnelName] = (tunnelPid,useCount)
               log("%s tunnel started, pid = %s" % (tunnelName,tunnelPid))
         self.sendMessage(channel,useCount,self.fixedBufferSize)
      elif messageType == "D":                        # Decrement tunnel usage
         if tunnelName in activeTunnels:
            tunnelPid,useCount = activeTunnels[tunnelName]
            useCount = str(int(useCount) - 1)
            if int(useCount) <= 0:
               self.__stopTunnel(tunnelName,activeTunnels)
            else:
               activeTunnels[tunnelName] = (tunnelPid,useCount)
         else:
            useCount = "0"
         self.sendMessage(channel,useCount,self.fixedBufferSize)
      elif messageType == "T":                        # Terminate tunnel
         if tunnelName in activeTunnels:
            useCount = activeTunnels[tunnelName][1]   # reply with the count it had
            self.__stopTunnel(tunnelName,activeTunnels)
         else:
            useCount = "0"
         self.sendMessage(channel,useCount,self.fixedBufferSize)
      elif messageType == "R":                        # Report active jobs
         if tunnelName == "":
            report = " : ".join([activeTunnel + " " + activeTunnels[activeTunnel][1]
                                 for activeTunnel in activeTunnels])
         else:
            if tunnelName in activeTunnels:
               useCount = activeTunnels[tunnelName][1]
            else:
               useCount = "?"
            report = tunnelName + " " + useCount

         # Length goes first in a fixed-size message; the report body follows only if
         # the length was sent and there is something to send.
         reportLength = len(report)
         if self.sendMessage(channel,str(reportLength),self.fixedBufferSize) > 0:
            if reportLength > 0:
               self.sendMessage(channel,report)

      return(False)


