# @package      hubzero-submit-distributor
# @file         TunnelsInfo.py
# @author       Steven Clark <clarks@purdue.edu>
# @copyright    Copyright (c) 2012-2015 HUBzero Foundation, LLC.
# @license      http://opensource.org/licenses/MIT MIT
#
# Copyright (c) 2012-2015 HUBzero Foundation, LLC.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
# HUBzero is a registered trademark of HUBzero Foundation, LLC.
#
import os
import re
import glob
import traceback
import logging

from hubzero.submit.LogMessage import getLogJobIdMessage as getLogMessage

class TunnelsInfo:
   """Load SSH tunnel definitions from configuration files and build tunnel commands.

   A configuration file consists of sections of the form::

      [tunnelName]
      venue           = <host>   # destination host of the forwarded connection
      venuePort       = <port>   # destination port on venue
      sshOptions      = <opts>   # extra options added to the ssh command line
      gatewayHost     = <host>   # host the ssh connection is made to
      gatewayUser     = <user>   # login name on gatewayHost
      localPortOffset = <int>    # added to localPortMinimum to pick the local port
      state           = enabled  # any other value drops the tunnel after loading

   Everything from '#' to the end of a line is ignored. All tunnel values are
   kept as strings. Only tunnels whose state is 'enabled' remain in
   self.tunnels after construction.
   """

   # A '[name]' line opens (or re-opens, resetting to defaults) a tunnel section.
   _tunnelPattern   = re.compile(r'(\s*\[)([^\s]*)(]\s*)')
   # A 'key = value' line; the captured value excludes trailing whitespace.
   _keyValuePattern = re.compile(r'( *)(\w*)( *= *)(.*[^\s$])( *)')
   # '#' and everything after it (plus preceding whitespace) is a comment.
   _commentPattern  = re.compile(r'\s*#.*')

   # Attributes of a newly opened tunnel section. The type of each default
   # decides how a configured value is converted (see _convertValue).
   _tunnelDefaults  = {'venue':"",
                       'venuePort':"",
                       'sshOptions':"",
                       'gatewayHost':"",
                       'gatewayUser':"",
                       'localPortOffset':"",
                       'state':'enabled'
                      }

   def __init__(self,
                tunnelsPath,
                localPortMinimum=2000,
                localPortMaximum=2048):
      """Load tunnel definitions.

      tunnelsPath      -- a directory (every entry in it is read) or a glob
                          pattern naming the configuration file(s).
      localPortMinimum -- base local port; a tunnel's port is this value plus
                          its localPortOffset.
      localPortMaximum -- largest local port accepted.
      """
      self.logger = logging.getLogger(__name__)

      self.tunnels          = {}
      self.localPortMinimum = localPortMinimum
      self.localPortMaximum = localPortMaximum

      if os.path.isdir(tunnelsPath):
         tunnelsPattern = os.path.join(tunnelsPath,'*')
      else:
         tunnelsPattern = tunnelsPath
      for tunnelsInfoPath in glob.iglob(tunnelsPattern):
         self.readTunnelsInfoFile(tunnelsInfoPath)

      # Only enabled tunnels are made available to callers.
      self.tunnels = {tunnelName:tunnel for tunnelName,tunnel in self.tunnels.items()
                                        if tunnel['state'] == 'enabled'}


   def readTunnelsInfoFile(self,
                           tunnelsInfoPath):
      """Parse one tunnels configuration file and merge its sections into self.tunnels.

      Missing, unopenable or unreadable files are logged as errors rather than
      raised. Unknown keys, keys that appear before any section header, and
      values that cannot be converted are logged as warnings and ignored.
      """
      if not os.path.exists(tunnelsInfoPath):
         self.logger.log(logging.ERROR,getLogMessage("Tunnels configuration file %s is missing" % \
                                                                                 (tunnelsInfoPath)))
         return

      try:
         fpInfo = open(tunnelsInfoPath,'r')
      except (IOError,OSError):
         self.logger.log(logging.ERROR,getLogMessage("Tunnels configuration file %s could not be opened" % \
                                                                                             (tunnelsInfoPath)))
         return

      with fpInfo:
         try:
            # None means no '[section]' header has been seen yet in this file.
            tunnelName = None
            for record in fpInfo:
               tunnelName = self._parseRecord(record,tunnelName)
         except (IOError,OSError):
            self.logger.log(logging.ERROR,getLogMessage("Tunnels configuration file %s could not be read" % \
                                                                                          (tunnelsInfoPath)))


   def _parseRecord(self,
                    record,
                    tunnelName):
      """Apply one configuration line to self.tunnels and return the current section name."""
      record = self._commentPattern.sub("",record)

      tunnelMatch = self._tunnelPattern.match(record)
      if tunnelMatch:
         tunnelName = tunnelMatch.group(2)
         self.tunnels[tunnelName] = dict(self._tunnelDefaults)
         return tunnelName

      keyValueMatch = self._keyValuePattern.match(record)
      if keyValueMatch:
         key,value = keyValueMatch.group(2,4)
         if tunnelName is None:
            message = "key = value pair %s = %s found before any tunnel section" % (key,value)
            self.logger.log(logging.WARNING,getLogMessage(message))
         elif key in self.tunnels[tunnelName]:
            try:
               self.tunnels[tunnelName][key] = self._convertValue(self.tunnels[tunnelName][key],value)
            except ValueError:
               message = "Invalid value %s for key %s of tunnel %s" % (value,key,tunnelName)
               self.logger.log(logging.WARNING,getLogMessage(message))
         else:
            message = "Undefined key = value pair %s = %s for tunnel %s" % (key,value,tunnelName)
            self.logger.log(logging.WARNING,getLogMessage(message))

      return tunnelName


   @staticmethod
   def _convertValue(currentValue,
                     value):
      """Convert the text value to the type of currentValue and return it.

      list  -> comma-separated items, each stripped
      bool  -> True only for 'true' (case-insensitive)
      float -> float(value); int -> int(value)
      dict  -> comma-separated 'key:value' pairs whose keys/values are converted
               to the types of an existing entry of currentValue, if any
      other -> value unchanged

      Raises ValueError when value cannot be converted.
      """
      # bool must be tested before int because bool is a subclass of int.
      if   isinstance(currentValue,list):
         return [e.strip() for e in value.split(',')]
      elif isinstance(currentValue,bool):
         return value.lower() == 'true'
      elif isinstance(currentValue,float):
         return float(value)
      elif isinstance(currentValue,int):
         return int(value)
      elif isinstance(currentValue,dict):
         if currentValue:
            sampleKey   = next(iter(currentValue))
            sampleValue = currentValue[sampleKey]
         else:
            sampleKey   = "key"
            sampleValue = "value"
         converted = {}
         for e in value.split(','):
            dictKey,dictValue = e.split(':')
            if isinstance(sampleKey,int):
               dictKey = int(dictKey)
            if   isinstance(sampleValue,bool):
               dictValue = dictValue.lower() == 'true'
            elif isinstance(sampleValue,int):
               dictValue = int(dictValue)
            elif isinstance(sampleValue,float):
               dictValue = float(dictValue)
            converted[dictKey] = dictValue
         return converted
      return value


   def _getLocalEndpoint(self,
                         tunnelName):
      """Return (localAddress, localPort) for tunnelName.

      localPort is localPortMinimum + int(localPortOffset). localAddress is
      '127.0.<offset>.0' when localPort lies in (localPortMinimum, localPortMaximum],
      otherwise ''. Raises KeyError for an unknown tunnel and ValueError for a
      non-integer offset.
      """
      localPortOffset = int(self.tunnels[tunnelName]['localPortOffset'])
      localPort       = self.localPortMinimum + localPortOffset
      if self.localPortMinimum < localPort <= self.localPortMaximum:
         localAddress = "127.0.%d.0" % (localPortOffset)
      else:
         localAddress = ""

      return(localAddress,localPort)


   def getSSHTunnelAddressPort(self,
                               tunnelName):
      """Return (localAddress, localPort-as-string) for tunnelName.

      localAddress is '' when the port is out of range; both are '' when the
      tunnel or its offset is invalid (the failure is logged).
      """
      address = ""
      port    = ""
      try:
         address,port = self._getLocalEndpoint(tunnelName)
      except Exception:
         self.logger.log(logging.ERROR,getLogMessage(traceback.format_exc()))
         self.logger.log(logging.ERROR,getLogMessage("Get sshTunnelAddressPort failed for %s" % (tunnelName)))

      return(address,str(port))


   def getSSHTunnelHosts(self,
                         tunnelName):
      """Return (gatewayHost, localAddress) for tunnelName; '' for anything unavailable."""
      gatewayHost = ""
      localHost   = ""
      try:
         gatewayHost = self.tunnels[tunnelName]['gatewayHost']
         localHost   = self._getLocalEndpoint(tunnelName)[0]
      except Exception:
         self.logger.log(logging.ERROR,getLogMessage(traceback.format_exc()))
         self.logger.log(logging.ERROR,getLogMessage("Build sshTunnelHosts failed for %s" % (tunnelName)))

      return(gatewayHost,localHost)


   def getSSHTunnelCommand(self,
                           tunnelName,
                           sshIdentity):
      """Return the ssh command line that opens tunnelName in the background.

      The command forwards localAddress:localPort to venue:venuePort through
      gatewayUser@gatewayHost using the identity file sshIdentity. Returns ''
      when the local port is out of range or the tunnel is invalid.
      """
      sshTunnelCommand = ""
      try:
         tunnel = self.tunnels[tunnelName]
         localAddress,localPort = self._getLocalEndpoint(tunnelName)
         if localAddress:
            commandWords = ["ssh","-T","-a","-x","-N","-f"]
            # Extra options must precede '-i', whose argument is the identity file.
            if tunnel['sshOptions']:
               commandWords.append(tunnel['sshOptions'])
            commandWords += ["-i",sshIdentity,
                             "-L","%s:%d:%s:%s" % (localAddress,localPort,tunnel['venue'],tunnel['venuePort']),
                             "%s@%s" % (tunnel['gatewayUser'],tunnel['gatewayHost'])]
            sshTunnelCommand = " ".join(commandWords)
      except Exception:
         self.logger.log(logging.ERROR,getLogMessage(traceback.format_exc()))
         self.logger.log(logging.ERROR,getLogMessage("Build sshTunnelCommand failed for %s" % (tunnelName)))

      return(sshTunnelCommand)


   def getSSHTunnelPidCommand(self,
                              tunnelName):
      """Return a pgrep command matching the -L forward spec of tunnelName's ssh command.

      Returns '' when the local port is out of range or the tunnel is invalid.
      """
      sshTunnelPidCommand = ""
      try:
         tunnel = self.tunnels[tunnelName]
         localAddress,localPort = self._getLocalEndpoint(tunnelName)
         if localAddress:
            # Same forward spec text as used in getSSHTunnelCommand's -L argument.
            sshTunnelPidCommand = "pgrep -u %d -f %s:%d:%s:%s" % (os.getuid(),localAddress,localPort,
                                                                  tunnel['venue'],tunnel['venuePort'])
      except Exception:
         self.logger.log(logging.ERROR,getLogMessage(traceback.format_exc()))
         self.logger.log(logging.ERROR,getLogMessage("Build sshTunnelPidCommand failed for %s" % (tunnelName)))

      return(sshTunnelPidCommand)


