# @package      hubzero-submit-common
# @file         TunnelsInfo.py
# @author       Steven Clark <clarks@purdue.edu>
# @copyright    Copyright (c) 2012-2013 HUBzero Foundation, LLC.
# @license      http://www.gnu.org/licenses/lgpl-3.0.html LGPLv3
#
# Copyright (c) 2012-2013 HUBzero Foundation, LLC.
#
# This file is part of: The HUBzero(R) Platform for Scientific Collaboration
#
# The HUBzero(R) Platform for Scientific Collaboration (HUBzero) is free
# software: you can redistribute it and/or modify it under the terms of
# the GNU Lesser General Public License as published by the Free Software
# Foundation, either version 3 of the License, or (at your option) any
# later version.
#
# HUBzero is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# HUBzero is a registered trademark of HUBzero Foundation, LLC.
#

import os
import re
import traceback

from hubzero.submit.LogMessage import logJobId as log

class TunnelsInfo:
   """Parse an SSH tunnel configuration file and build tunnel-related strings.

   The configuration file is an INI-like list of sections:

      [tunnelName]
      venue           = <destination host of the -L forward>
      venuePort       = <destination port of the -L forward>
      gatewayHost     = <host the ssh connection is made to>
      gatewayUser     = <user for the ssh connection>
      localPortOffset = <integer added to localPortMinimum>

   Text from '#' to end of line is discarded as a comment.  Each tunnel's
   local end is address 127.0.<offset>.0, port localPortMinimum + offset;
   a tunnel whose port is outside (localPortMinimum, localPortMaximum] is
   treated as unusable and yields empty addresses/commands.
   """

   # Settings recognised inside a [tunnel] section, with their defaults.  The
   # type of each default selects how the raw text is converted when parsed
   # (list -> comma-separated items, bool -> 'true'/'false', else kept as str).
   _DEFAULT_TUNNEL = {'venue':'',
                      'venuePort':'',
                      'gatewayHost':'',
                      'gatewayUser':'',
                      'localPortOffset':''
                     }

   def __init__(self,
                infoDirectory,
                tunnelsFile,
                localPortMinimum=2000,
                localPortMaximum=2048):
      """Load tunnel definitions from os.path.join(infoDirectory,tunnelsFile).

      A missing file is logged and leaves self.tunnels empty.  Unknown keys,
      and key = value lines that appear before any [tunnel] section, are
      logged and ignored.
      """
      self.tunnels          = {}
      self.localPortMinimum = localPortMinimum
      self.localPortMaximum = localPortMaximum

      # Raw strings: the patterns contain regex escapes (\s, \[, \w).
      tunnelPattern   = re.compile(r'(\s*\[)([^\s]*)(]\s*)')
      keyValuePattern = re.compile(r'( *)(\w*)( *= *)(.*[^\s$])( *)')
      commentPattern  = re.compile(r'\s*#.*')

      tunnelsPath = os.path.join(infoDirectory,tunnelsFile)
      if not os.path.exists(tunnelsPath):
         log("Tunnels configuration file %s is missing" % (tunnelsPath))
         return

      # None (rather than "") so that a "[]" section remains distinguishable
      # from "no section seen yet".
      tunnelName = None
      with open(tunnelsPath,'r') as fpInfo:
         for record in fpInfo:
            record = commentPattern.sub("",record)

            tunnelMatch = tunnelPattern.match(record)
            if tunnelMatch:
               tunnelName = tunnelMatch.group(2)
               self.tunnels[tunnelName] = dict(self._DEFAULT_TUNNEL)
               continue

            keyValueMatch = keyValuePattern.match(record)
            if not keyValueMatch:
               continue

            key,value = keyValueMatch.group(2,4)
            if tunnelName is None:
               log("key = value pair %s = %s appears before any tunnel section" % (key,value))
               continue

            tunnel = self.tunnels[tunnelName]
            if key in tunnel:
               if   isinstance(tunnel[key],list):
                  tunnel[key] = [e.strip() for e in value.split(',')]
               elif isinstance(tunnel[key],bool):
                  tunnel[key] = bool(value.lower() == 'true')
               else:
                  tunnel[key] = value
            else:
               log("Undefined key = value pair %s = %s for tunnel %s" % (key,value,tunnelName))


   def _getLocalEndpoint(self,
                         tunnelName):
      """Return (localAddress,localPort) for the local end of tunnelName.

      localPort is localPortMinimum plus the tunnel's integer localPortOffset.
      localAddress is "127.0.<offset>.0" when localPort lies in
      (localPortMinimum, localPortMaximum], otherwise "".  The offset is
      rendered from its integer value, so "010" gives 127.0.10.0 (matching
      port +10) instead of a dotted quad that inet_aton-style parsers may
      read as octal.

      Raises KeyError for an unknown tunnel and ValueError when
      localPortOffset is not an integer.
      """
      localPortOffset = int(self.tunnels[tunnelName]['localPortOffset'])
      localPort       = self.localPortMinimum + localPortOffset
      if self.localPortMinimum < localPort <= self.localPortMaximum:
         localAddress = "127.0.%d.0" % (localPortOffset)
      else:
         localAddress = ""

      return(localAddress,localPort)


   def getSSHTunnelAddressPort(self,
                               tunnelName):
      """Return (address,port) strings for the local end of tunnelName.

      address is "" when the port is out of range (port is still returned);
      both are "" if the tunnel is unknown or its offset is not an integer.
      """
      address = ""
      port    = ""
      try:
         address,port = self._getLocalEndpoint(tunnelName)
      except Exception:
         log(traceback.format_exc())
         log("Get sshTunnelAddressPort failed for %s" % (tunnelName))

      return(address,str(port))


   def getSSHTunnelHosts(self,
                         tunnelName):
      """Return (gatewayHost,localHost) for tunnelName.

      localHost is "" when the local port is out of range or the offset is
      not an integer; both are "" if the tunnel is unknown.
      """
      gatewayHost = ""
      localHost   = ""
      try:
         gatewayHost = self.tunnels[tunnelName]['gatewayHost']
         localHost   = self._getLocalEndpoint(tunnelName)[0]
      except Exception:
         log(traceback.format_exc())
         log("Build sshTunnelHosts failed for %s" % (tunnelName))

      return(gatewayHost,localHost)


   def getSSHTunnelCommand(self,
                           tunnelName,
                           sshIdentity):
      """Return the ssh command line that opens tunnelName in the background.

      The command forwards localAddress:localPort to venue:venuePort via
      gatewayUser@gatewayHost using identity file sshIdentity.  Returns ""
      when the local port is out of range or on any error.  Values are
      interpolated without shell quoting.
      """
      sshTunnelCommand = ""
      try:
         tunnel = self.tunnels[tunnelName]
         localAddress,localPort = self._getLocalEndpoint(tunnelName)
         if localAddress:
            sshTunnelCommand = "ssh -T -a -x -N -f -i " + sshIdentity + \
                               " -L " + localAddress + ":" + str(localPort) + ":" + \
                               tunnel['venue'] + ":" + tunnel['venuePort'] + " " + \
                               tunnel['gatewayUser'] + "@" + tunnel['gatewayHost']
      except Exception:
         log(traceback.format_exc())
         log("Build sshTunnelCommand failed for %s" % (tunnelName))

      return(sshTunnelCommand)


   def getSSHTunnelPidCommand(self,
                              tunnelName):
      """Return a pgrep command matching this user's ssh process for tunnelName.

      The pattern is the "-L" forwarding spec produced by getSSHTunnelCommand.
      Returns "" when the local port is out of range or on any error.
      """
      sshTunnelPidCommand = ""
      try:
         tunnel = self.tunnels[tunnelName]
         localAddress,localPort = self._getLocalEndpoint(tunnelName)
         if localAddress:
            sshTunnelPidCommand = "pgrep -u " + str(os.getuid()) + " -f " + \
                                  localAddress + ":" + str(localPort) + ":" + \
                                  tunnel['venue'] + ":" + tunnel['venuePort']
      except Exception:
         log(traceback.format_exc())
         log("Build sshTunnelPidCommand failed for %s" % (tunnelName))

      return(sshTunnelPidCommand)


