# @package      hubzero-submit-common
# @file         UnboundConnection.py
# @author       Steven Clark <clarks@purdue.edu>
# @copyright    Copyright (c) 2012-2014 HUBzero Foundation, LLC.
# @license      http://www.gnu.org/licenses/lgpl-3.0.html LGPLv3
#
# Copyright (c) 2012-2014 HUBzero Foundation, LLC.
#
# This file is part of: The HUBzero(R) Platform for Scientific Collaboration
#
# The HUBzero(R) Platform for Scientific Collaboration (HUBzero) is free
# software: you can redistribute it and/or modify it under the terms of
# the GNU Lesser General Public License as published by the Free Software
# Foundation, either version 3 of the License, or (at your option) any
# later version.
#
# HUBzero is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# HUBzero is a registered trademark of HUBzero Foundation, LLC.
#

import os.path
import time
import logging

from hubzero.submit.LogMessage  import getLogIDMessage as getLogMessage
from hubzero.submit.MessageCore import MessageCore

class UnboundConnection:
   """Client side of a framed message connection to a submit listener.

   The constructor tries the supplied listener URIs (filtered and ordered
   according to the TLS requirement) until one opens a channel and passes a
   "Hello." echo handshake, or until the allowed number of connection passes
   is used up.  Afterwards the object holds an inbound buffer (fromBuffer)
   and an outbound buffer (toBuffer).  The caller drives I/O through
   getInputObject/getOutputObject (lists holding the active channel --
   presumably for use with select(); confirm against callers) and
   receiveMessage/sendMessage.

   Accepted URI forms (see __parseURI):
      tls://host:port   tcp://host:port   file://path
   """
   # URI selection policies for the constructor's tlsRequirement argument:
   TLSREQUIREMENTNEVER       = 1 << 0   # only tcp URIs
   TLSREQUIREMENTNONE        = 1 << 1   # every parseable URI, in given order
   TLSREQUIREMENTIFAVAILABLE = 1 << 2   # tls URIs first, then tcp URIs
   TLSREQUIREMENTALWAYS      = 1 << 3   # only tls URIs


   def __init__(self,
                tlsRequirement,
                listenURIs=None,
                maximumConnectionPasses=15,
                submitSSLCA="",
                retryDelay=10.):
      """Connect to the first usable listener in listenURIs.

      tlsRequirement          - one of the TLSREQUIREMENT* class constants;
                                selects and orders the candidate URIs.
      listenURIs              - sequence of listener URI strings (None or
                                empty means nothing to connect to).
      maximumConnectionPasses - number of passes over the candidate list
                                before giving up.
      submitSSLCA             - passed to MessageCore as sslCACertPath.
      retryDelay              - seconds slept between connection passes.

      Failures are logged rather than raised; use isConnected() to find out
      whether a connection was established.
      """
      self.logger        = logging.getLogger(__name__)
      self.connection    = None
      self.activeChannel = None
      self.bufferSize    = 1024
      self.fromBuffer    = ""
      self.toBuffer      = ""

      candidates = self.__filterURIs(tlsRequirement,listenURIs or [])
      if candidates:
         self.__connect(candidates,maximumConnectionPasses,submitSSLCA,retryDelay)
      else:
         self.logger.log(logging.ERROR,getLogMessage("No connections to be configured"))


   def __filterURIs(self,
                    tlsRequirement,
                    listenURIs):
      """Return the [(uri,(protocol,host,port,file)),...] to try, in order.

      Each URI is parsed exactly once.  URIs that fail to parse are dropped
      here (__parseURI has already logged them).  An unrecognised
      tlsRequirement yields an empty list.
      """
      parsedURIs = []
      for listenURI in listenURIs:
         parsedURI = self.__parseURI(listenURI)
         if parsedURI[0]:
            parsedURIs.append((listenURI,parsedURI))

      def withProtocol(protocol):
         return [entry for entry in parsedURIs if entry[1][0] == protocol]

      if   tlsRequirement == self.TLSREQUIREMENTNEVER:
         return withProtocol('tcp')
      elif tlsRequirement == self.TLSREQUIREMENTNONE:
         return parsedURIs
      elif tlsRequirement == self.TLSREQUIREMENTIFAVAILABLE:
         return withProtocol('tls') + withProtocol('tcp')
      elif tlsRequirement == self.TLSREQUIREMENTALWAYS:
         return withProtocol('tls')
      return []


   def __connect(self,
                 candidates,
                 maximumConnectionPasses,
                 submitSSLCA,
                 retryDelay):
      """Make up to maximumConnectionPasses passes over candidates.

      Returns after the first candidate that opens a channel and passes the
      handshake (self.connection/self.activeChannel are then set), or when
      SystemExit is raised while opening or handshaking.  That case is
      treated as a user abort and leaves the object unconnected.
      """
      delay             = 0.
      nConnectionPasses = 0
      while not self.activeChannel and nConnectionPasses < maximumConnectionPasses:
         time.sleep(delay)
         nConnectionPasses += 1
         for listenURI,(protocol,serverHost,serverPort,serverFile) in candidates:
            self.connection = self.__createConnection(listenURI,protocol,serverHost,
                                                      serverPort,serverFile,submitSSLCA)
            if not self.connection:
               continue
            try:
               self.activeChannel = self.connection.openListenerChannel(True)
               if self.activeChannel:
                  if self.handshake():
                     return
                  # Channel opened but peer did not echo the handshake.
                  self.connection.closeListenerChannel(self.activeChannel)
                  self.activeChannel = None
               self.connection = None
            except SystemExit:
               self.connection    = None
               self.activeChannel = None
               return
         delay = retryDelay


   def __createConnection(self,
                          listenURI,
                          protocol,
                          serverHost,
                          serverPort,
                          serverFile,
                          submitSSLCA):
      """Build a MessageCore for one candidate URI; return None if not possible."""
      connection = None
      if   protocol in ('tls','tcp'):
         message = "Connecting: protocol='%s', host='%s', port=%d" % (protocol,serverHost,serverPort)
         self.logger.log(logging.INFO,getLogMessage(message))
         connection = MessageCore(protocol=protocol,bindLabel=listenURI,
                                  sslCACertPath=submitSSLCA,
                                  listenerHost=serverHost,listenerPort=serverPort)
      elif protocol == 'file':
         message = "Connecting: protocol='%s', file='%s'" % (protocol,serverFile)
         self.logger.log(logging.INFO,getLogMessage(message))
         if os.path.exists(serverFile):
            connection = MessageCore(protocol='file',listenerFile=serverFile,
                                     bindLabel="UD:%s" % (serverFile))
         else:
            self.logger.log(logging.ERROR,getLogMessage("Connecting file %s missing" % (serverFile)))
      else:
         self.logger.log(logging.ERROR,getLogMessage("Unknown protocol: %s" % (protocol)))

      return(connection)


   def __parseURI(self,
                  uri):
      """Split a listener URI into (protocol,host,port,filePath).

      'proto://host:port' -> (proto,host,int(port),"")
      'proto://path'      -> (proto,"",0,path)
      The protocol is lower-cased.  On a malformed URI (wrong number of ':'
      separated parts, non-integer port, empty protocol, non-string) an
      error is logged and ("", "", 0, "") is returned.
      """
      protocol = ""
      host     = ""
      port     = 0
      filePath = ""
      try:
         parts = uri.split(':')
         if   len(parts) == 3:
            protocol = parts[0].lower()
            host     = parts[1].lstrip('/')
            port     = int(parts[2])
         elif len(parts) == 2:
            protocol = parts[0].lower()
            filePath = parts[1]
            # Strip only the leading '//' of 'file://path'; slashes inside
            # the path itself are significant and must be kept.
            if filePath.startswith('//'):
               filePath = filePath[2:]
      except (AttributeError,ValueError):
         protocol = ""

      if not protocol:
         host     = ""
         port     = 0
         filePath = ""
         self.logger.log(logging.ERROR,getLogMessage("Improper network specification: %s" % (uri)))

      return(protocol,host,port,filePath)


   def __handshake(self):
      """Send "Hello.\\n" and return True if exactly the same text is echoed back.

      Exceptions from the exchange are logged and reported as failure.
      """
      valid   = False
      message = "Hello.\n"
      reply   = ""

      try:
         nSent = self.connection.sendMessage(self.activeChannel,message)
         if nSent > 0:
            # Expect the same message back, exactly nSent characters long.
            reply = self.connection.receiveMessage(self.activeChannel,nSent,nSent)
            valid = (reply == message)
      except Exception as err:
         self.logger.log(logging.ERROR,getLogMessage("ERROR: Connection handshake failed.  Protocol mismatch?"))
         self.logger.log(logging.ERROR,getLogMessage("handshake(%s): %s" % (message.strip(),(reply or "").strip())))
         self.logger.log(logging.ERROR,getLogMessage("err = %s" % (str(err))))

      return(valid)


   def handshake(self):
      """Run the echo handshake; on success switch the channel to non-blocking."""
      valid = self.__handshake()
      if valid:
         self.activeChannel.setblocking(0)

      return(valid)


   def isConnected(self):
      """Return True while an active channel is held."""
      return(self.activeChannel is not None)


   def closeConnection(self):
      """Close and forget the active channel, if any."""
      if self.activeChannel:
         self.activeChannel.close()
         self.activeChannel = None


   def getInputObject(self):
      """Return [activeChannel] when connected, else []."""
      if self.activeChannel:
         return([self.activeChannel])
      return([])


   def getOutputObject(self):
      """Return [activeChannel] when connected and output is pending, else []."""
      if self.activeChannel and self.toBuffer != "":
         return([self.activeChannel])
      return([])


   def sendMessage(self):
      """Transmit as much of toBuffer as the connection accepts.

      Sent characters are removed from toBuffer; a negative result from the
      underlying send closes the connection.
      """
      transmittedLength = self.connection.sendMessage(self.activeChannel,self.toBuffer)
      if   transmittedLength > 0:
         self.toBuffer = self.toBuffer[transmittedLength:]
      elif transmittedLength < 0:
         self.closeConnection()


   def receiveMessage(self):
      """Read up to bufferSize characters into fromBuffer; None closes the connection."""
      message = self.connection.receiveMessage(self.activeChannel,0,self.bufferSize)
      if message is None:
         self.closeConnection()
      else:
         self.fromBuffer += message


   def pullMessage(self,
                   messageLength):
      """Remove and return a message from the front of fromBuffer.

      messageLength == 0 : the next newline-terminated line (newline dropped),
                           or "" if no complete line is buffered.
      messageLength <  0 : up to -messageLength characters (whatever is there).
      messageLength >  0 : exactly messageLength characters, or "" if fewer
                           are buffered (nothing is consumed in that case).
      """
      if   messageLength == 0:
         nl = self.fromBuffer.find('\n')
         if nl < 0:
            return("")
         message = self.fromBuffer[:nl]
         self.fromBuffer = self.fromBuffer[nl+1:]
      elif messageLength < 0:
         ml = min(len(self.fromBuffer),-messageLength)
         message = self.fromBuffer[:ml]
         self.fromBuffer = self.fromBuffer[ml:]
      else:
         if len(self.fromBuffer) < messageLength:
            return("")
         message = self.fromBuffer[:messageLength]
         self.fromBuffer = self.fromBuffer[messageLength:]

      return(message)


   def pushMessage(self,
                   message):
      """Put message back at the front of fromBuffer."""
      self.fromBuffer = message + self.fromBuffer


   def postMessage(self,
                   message):
      """Append raw text to the outbound buffer."""
      self.toBuffer += message


   def postMessageBySize(self,
                         command,
                         message):
      """Queue 'command <len>\\n' followed by message; empty messages are skipped."""
      if len(message) > 0:
         self.toBuffer += command + " %d\n" % (len(message))
         self.toBuffer += message


   def postMessagesBySize(self,
                          command,
                          messages):
      """Queue 'command <len1> <len2> ...\\n' followed by the concatenated messages.

      Nothing is queued when the concatenated text is empty.  messages may be
      any iterable (it is materialised once, so generators work).
      """
      messages = list(messages)
      text = "".join(messages)
      if len(text) > 0:
         sizes = "".join([" %d" % (len(message)) for message in messages])
         self.toBuffer += command + sizes + "\n" + text


   def isMessagePending(self):
      """Return True if outbound data is waiting in toBuffer."""
      return(len(self.toBuffer) > 0)


   def logPendingMessage(self,
                         logMessage):
      """Log the pending outbound buffer, tagged with logMessage, if non-empty."""
      if self.isMessagePending():
         self.logger.log(logging.INFO,getLogMessage("logPendingMessage(%s): %s" % (logMessage,self.toBuffer)))


