#
# Copyright (c) 2004-2011 Purdue University All rights reserved.
#
# Developed by: HUBzero Technology Group, Purdue University
#               http://hubzero.org
#
# HUBzero is free software: you can redistribute it and/or modify it under the terms of the
# GNU Lesser General Public License as published by the Free Software Foundation, either
# version 3 of the License, or (at your option) any later version.
#
# HUBzero is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU Lesser General Public License for more details.  You should have received a
# copy of the GNU Lesser General Public License along with HUBzero.
# If not, see <http://www.gnu.org/licenses/>.
#
# GNU LESSER GENERAL PUBLIC LICENSE
# Version 3, 29 June 2007
# Copyright (C) 2007 Free Software Foundation, Inc. <http://fsf.org/>
#
import os.path
import copy

from hubzero.submit.LogMessage  import logJobId as log
from hubzero.submit.MessageCore import MessageCore

# TLS credentials used by 'tls' listeners.  Joining from os.sep makes the
# directory an absolute path at the filesystem root (/etc/submit on POSIX).
CERTIFICATEDIRECTORY = os.path.join(os.sep,'etc','submit')
# Server private key, server certificate and CA certificate, in that order.
SSLKEYPATH,SSLCERTPATH,SSLCACERTPATH = tuple(
   os.path.join(CERTIFICATEDIRECTORY,credentialFile)
   for credentialFile in ('submit_server.key','submit_server.crt','submit_server_ca.crt'))

class ClientListener:
   """Listen for submit client connections and buffer traffic for one client.

   A MessageCore listener is created for every listen URI whose protocol is
   'tls' or 'tcp', whose port is positive, and which reports isBound().  After
   acceptClientConnection(), data read from the client accumulates in
   fromClientBuffer and data queued for the client accumulates in
   toClientBuffer.  The caller moves data through the sockets with
   receiveClientMessage() / sendClientMessage(), presumably from a select()
   loop over getInputObjects() / getOutputObjects() (verify against caller).
   """

   def __init__(self,
                listenURIs):
      """Create and bind a listener for each usable entry of listenURIs.

      listenURIs -- iterable of 'protocol://host:port' strings.  protocol is
                    'tls' (uses SSLKEYPATH/SSLCERTPATH/SSLCACERTPATH) or 'tcp'.
      Entries that fail to parse, have a port <= 0, use an unknown protocol, or
      do not bind are skipped.  Use isListening() to see whether any bound.
      """
      # Bound listening socket -> MessageCore.  The same MessageCore is later
      # used for I/O on the client channel accepted through that socket.
      self.clientListeners       = {}
      # Listening sockets in creation order; copied out by getInputObjects().
      self.clientListenerSockets = []
      self.remoteIP              = None
      self.clientChannel         = None
      # Listening socket the current client connected through.
      self.clientSocket          = None
      self.fromClientBuffer      = ""
      self.toClientBuffer        = ""
      # Size limit passed to receiveMessage() by receiveClientMessage().
      self.bufferSize            = 1024

      for listenURI in listenURIs:
         protocol,bindHost,bindPort = self.__parseURL(listenURI)
         if bindPort > 0:
            log("Listening: protocol='%s', host='%s', port=%d" % (protocol,bindHost,bindPort))
            if   protocol == 'tls':
               clientListener = MessageCore(bindHost=bindHost,bindPort=bindPort,bindLabel=listenURI, \
                                            sslKeyPath=SSLKEYPATH,sslCertPath=SSLCERTPATH, \
                                            sslCACertPath=SSLCACERTPATH, \
                                            reuseAddress=True,blocking=False)
            elif protocol == 'tcp':
               clientListener = MessageCore(bindHost=bindHost,bindPort=bindPort,bindLabel=listenURI, \
                                            reuseAddress=True,blocking=False)
            else:
               clientListener = None
               log("Unknown protocol: %s" % (protocol))

            if clientListener and clientListener.isBound():
               boundSocket = clientListener.boundSocket()
               self.clientListeners[boundSocket] = clientListener
               self.clientListenerSockets.append(boundSocket)

      if not self.clientListeners:
         log("No listening devices configured")


   def isListening(self):
      """Return True if at least one listener was successfully bound.

      Note: closeListeningConnections() does not empty clientListeners, so this
      keeps reporting True after the listening sockets have been closed.
      """
      return(len(self.clientListeners) > 0)


   def __parseURL(self,
                  url):
      """Split a listen URI of the form 'protocol://host:port'.

      Returns (protocol, host, port): protocol lower-cased, leading '/' chars
      stripped from host, port converted to int.  If the URI does not split
      into exactly three ':'-separated fields, or the port is not an integer,
      the problem is logged and ("", "", 0) is returned.  Callers skip ports
      <= 0, so a bad URI is simply ignored.
      """
      try:
         protocol,host,port = url.split(':')
         port = int(port)
      except (AttributeError,TypeError,ValueError):
         # Return nothing from a partial parse.  Previously a bad port left the
         # unconverted string in 'port'.  Under Python 2 that compared > 0, and
         # __init__ then failed formatting it with '%d'.
         log("Improper network specification: %s" % (url))
         return("","",0)

      return(protocol.lower(),host.lstrip('/'),port)


   def __handshake(self,
                   listeningSocket):
      """Send "Hello.\\n" on clientChannel and expect the same bytes echoed.

      Returns True only if the echoed reply matches exactly.  Any exception
      raised while sending or receiving is logged and reported as False.
      """
      valid = False
      message = "Hello.\n"
      reply = ""

      try:
         clientListener = self.clientListeners[listeningSocket]
         # Write the message.
         nSent = clientListener.sendMessage(self.clientChannel,message)

         # Expect the same message back; ask for exactly what was sent.
         reply = clientListener.receiveMessage(self.clientChannel,nSent,nSent)
         if reply == message:
            valid = True
      except Exception as err:
         log("ERROR: Connection handshake failed.  Protocol mismatch?")
         log("handshake(%s): %s" % (message.strip(),reply.strip()))
         log("err = %s" % (str(err)))

      return(valid)


   def handshake(self,
                 listeningSocket):
      """Perform the greeting exchange with the accepted client.

      On success the client channel is switched to non-blocking mode.
      Returns True on success, False otherwise.
      """
      valid = self.__handshake(listeningSocket)
      if valid:
         self.clientChannel.setblocking(0)

      return(valid)


   def getInputObjects(self):
      """Return (copy of listening sockets, [clientChannel] or [])."""
      listeningSockets = list(self.clientListenerSockets)
      if self.clientChannel:
         clientReader = [self.clientChannel]
      else:
         clientReader = []

      return(listeningSockets,clientReader)


   def getOutputObjects(self):
      """Return [clientChannel] when connected and output is queued, else []."""
      clientWriter = []
      if self.clientChannel and self.toClientBuffer != "":
         clientWriter = [self.clientChannel]

      return(clientWriter)


   def closeListeningConnections(self):
      """Close every listening socket; safe to call more than once.

      The MessageCore objects stay in clientListeners because
      receiveClientMessage() and sendClientMessage() still use them for I/O
      on an already-accepted client channel.
      """
      for boundSocket,clientListener in self.clientListeners.items():
         # Only close listeners that are still open.  A repeated call would
         # otherwise raise ValueError from list.remove().
         if boundSocket in self.clientListenerSockets:
            self.clientListenerSockets.remove(boundSocket)
            clientListener.close()


   def acceptClientConnection(self,
                              listeningSocket):
      """Accept a pending connection on listeningSocket and make it current."""
      clientChannel,remoteIP,remotePort = self.clientListeners[listeningSocket].acceptConnectionDetailed(True)
      self.clientSocket  = listeningSocket
      self.clientChannel = clientChannel
      self.remoteIP      = remoteIP


   def isClientConnected(self):
      """Return True while a client channel is open."""
      return(self.clientChannel is not None)


   def getRemoteIP(self):
      """Return the remote IP recorded by the last acceptClientConnection()."""
      return(self.remoteIP)


   def receiveClientMessage(self):
      """Read available client data into fromClientBuffer.

      A None or empty read is treated as the client going away, and the
      client connection is closed.
      """
      clientMessage = self.clientListeners[self.clientSocket].receiveMessage(self.clientChannel,0,self.bufferSize)
      if not clientMessage:
         self.closeClientConnection()
      else:
         self.fromClientBuffer += clientMessage


   def pullClientMessage(self,
                         messageLength):
      """Remove and return one message from fromClientBuffer.

      messageLength == 0 -- return the next newline-terminated line, without
                            the '\\n'; "" if no complete line is buffered.
      messageLength  > 0 -- return exactly messageLength characters; "" if
                            fewer are buffered.
      The buffer is left untouched whenever "" is returned.
      """
      if messageLength == 0:
         nl = self.fromClientBuffer.find('\n')
         if nl < 0:
            return("")
         message = self.fromClientBuffer[0:nl]
         self.fromClientBuffer = self.fromClientBuffer[nl+1:]
      else:
         if len(self.fromClientBuffer) < messageLength:
            return("")
         message = self.fromClientBuffer[0:messageLength]
         self.fromClientBuffer = self.fromClientBuffer[messageLength:]

      return(message)


   def pushClientMessage(self,
                         message):
      """Put message back at the front of fromClientBuffer (undo a pull)."""
      self.fromClientBuffer = message + self.fromClientBuffer


   def postClientMessage(self,
                         message):
      """Queue message verbatim for transmission to the client."""
      self.toClientBuffer += message


   def postClientMessageBySize(self,
                               command,
                               message):
      """Queue "<command> <len>\\n<message>"; nothing is queued for an empty message."""
      if len(message) > 0:
         self.toClientBuffer += command + " %d\n" % (len(message))
         self.toClientBuffer += message


   def postClientMessagesBySize(self,
                                command,
                                messages):
      """Queue "<command> <len1> <len2> ...\\n<msg1><msg2>...".

      messages may be any iterable of strings.  It is consumed once, so
      generators work.  Nothing is queued if the combined text is empty.
      """
      messages = list(messages)
      text = "".join(messages)
      if len(text) > 0:
         sizes = "".join([" %d" % (len(message)) for message in messages])
         self.toClientBuffer += command + sizes + "\n" + text


   def sendClientMessage(self):
      """Send as much of toClientBuffer as possible and drop what was sent."""
      transmittedLength = self.clientListeners[self.clientSocket].sendMessage(self.clientChannel,self.toClientBuffer)
      self.toClientBuffer = self.toClientBuffer[transmittedLength:]


   def isClientMessagePending(self):
      """Return True if output for the client is still queued."""
      return(len(self.toClientBuffer) > 0)


   def closeClientConnection(self):
      """Close the client channel, if open, and mark the client disconnected."""
      if self.clientChannel:
         self.clientChannel.close()
         self.clientChannel = None


