# @package      hubzero-submit-common
# @file         UnboundConnection.py
# @copyright    Copyright (c) 2012-2020 The Regents of the University of California.
# @license      http://opensource.org/licenses/MIT MIT
#
# Copyright (c) 2012-2020 The Regents of the University of California.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
# HUBzero is a registered trademark of The Regents of the University of California.
#
import os.path
import time
import json
import logging

from hubzero.submit.LogMessage  import getLogIDMessage as getLogMessage
from hubzero.submit.MessageCore import MessageCore

class UnboundConnection:
   """Client-side message connection to the first reachable listener in a URI list.

   The constructor filters the candidate URIs according to the requested TLS
   policy and then repeatedly tries each remaining URI (through MessageCore)
   until a listener channel opens and its handshake succeeds, a SystemExit
   aborts the attempt, or the allowed number of passes is used up.

   Outgoing text is queued in ``toBuffer`` (postMessage*, drained by
   sendMessage()); incoming text is appended to ``fromBuffer`` by
   receiveMessage() and consumed with pullMessage().
   """
   # Accepted values for the constructor's tlsRequirement argument.
   TLSREQUIREMENTNEVER       = 1 << 0   # keep tcp URIs only
   TLSREQUIREMENTNONE        = 1 << 1   # keep every URI, in the order given
   TLSREQUIREMENTIFAVAILABLE = 1 << 2   # keep tls URIs first, then tcp URIs
   TLSREQUIREMENTALWAYS      = 1 << 3   # keep tls URIs only

   def __init__(self,
                tlsRequirement,
                listenURIs=None,
                maximumConnectionPasses=15,
                submitSSLCA=""):
      """Select candidate URIs and try to connect to one of them.

      tlsRequirement          -- one of the TLSREQUIREMENT* class constants;
                                 any other value selects no URIs.
      listenURIs              -- candidate URIs such as 'tls://host:port',
                                 'tcp://host:port' or 'file:///path'
                                 (None is treated as an empty list).
      maximumConnectionPasses -- how many times to cycle through the selected
                                 URIs; every pass after the first is preceded
                                 by a 10 second sleep.
      submitSSLCA             -- CA certificate path handed to MessageCore for
                                 tls/tcp connections.

      Use isConnected() afterwards to find out whether a connection was made.
      """
      self.logger              = logging.getLogger(__name__)
      self.connection          = None
      self.activeChannel       = None
      self.bufferSize          = 1024
      self.fromBuffer          = ""
      self.toBuffer            = ""
      self.connectionReadTime  = 0.
      self.connectionWriteTime = 0.

      # A None default avoids sharing one mutable list between calls.
      filteredURIs = self.__filterURIs(tlsRequirement,listenURIs or [])
      if filteredURIs:
         self.__connect(filteredURIs,maximumConnectionPasses,submitSSLCA)
      else:
         self.logger.log(logging.ERROR,getLogMessage("No connections to be configured"))


   def __filterURIs(self,
                    tlsRequirement,
                    listenURIs):
      """Return the URIs permitted by tlsRequirement, in connection-attempt order.

      'file' URIs are only kept under TLSREQUIREMENTNONE; URIs that fail to
      parse have an empty protocol and are likewise only kept under NONE.
      """
      # Parse each URI once; a list also makes one-shot iterables safe to reuse.
      protocols = [(listenURI,self.__parseURI(listenURI)[0]) for listenURI in listenURIs]

      if   tlsRequirement == self.TLSREQUIREMENTIFAVAILABLE:
         # Prefer encrypted endpoints, but fall back to plain tcp.
         return([uri for uri,protocol in protocols if protocol == 'tls'] +
                [uri for uri,protocol in protocols if protocol == 'tcp'])
      elif tlsRequirement == self.TLSREQUIREMENTNEVER:
         return([uri for uri,protocol in protocols if protocol == 'tcp'])
      elif tlsRequirement == self.TLSREQUIREMENTNONE:
         return([uri for uri,protocol in protocols])
      elif tlsRequirement == self.TLSREQUIREMENTALWAYS:
         return([uri for uri,protocol in protocols if protocol == 'tls'])
      else:
         return([])


   def __connect(self,
                 filteredURIs,
                 maximumConnectionPasses,
                 submitSSLCA):
      """Cycle through filteredURIs until one connects, the attempt is aborted, or passes run out.

      A URI succeeds when MessageCore opens a listener channel and
      initiateHandshake() accepts it; MessageCore's default buffer size and the
      current time are then recorded.  A SystemExit raised while opening or
      handshaking stops all further attempts.
      """
      delay                      = 0.
      userAborted                = False
      nConnectionPasses          = 0
      # Only the first pass reports channel-open errors, to avoid log spam on retries.
      logOpenChannelErrorMessage = True
      while not userAborted and not self.activeChannel and nConnectionPasses < maximumConnectionPasses:
         time.sleep(delay)
         nConnectionPasses += 1
         for listenURI in filteredURIs:
            protocol,serverHost,serverPort,serverFile = self.__parseURI(listenURI)
            if not protocol:
               continue

            if   protocol in ('tls','tcp'):
               message = "Connecting: protocol='%s', host='%s', port=%d" % (protocol,serverHost,serverPort)
               self.logger.log(logging.INFO,getLogMessage(message))
               self.connection = MessageCore(protocol=protocol,bindLabel=listenURI,
                                             sslCACertPath=submitSSLCA,
                                             listenerHost=serverHost,listenerPort=serverPort)
            elif protocol == 'file':
               message = "Connecting: protocol='%s', file='%s'" % (protocol,serverFile)
               self.logger.log(logging.INFO,getLogMessage(message))
               if os.path.exists(serverFile):
                  self.connection = MessageCore(protocol='file',listenerFile=serverFile,
                                                bindLabel="UD:%s" % (serverFile))
               else:
                  self.logger.log(logging.ERROR,getLogMessage("Connecting file %s missing" % (serverFile)))
            else:
               self.logger.log(logging.ERROR,getLogMessage("Unknown protocol: %s" % (protocol)))

            if self.connection:
               try:
                  self.activeChannel = self.connection.openListenerChannel(logErrorMessage=logOpenChannelErrorMessage,
                                                                           recordTraceback=True)
                  if self.activeChannel:
                     if self.connection.initiateHandshake(self.activeChannel):
                        self.bufferSize          = self.connection.getDefaultBufferSize()
                        self.connectionReadTime  = time.time()
                        self.connectionWriteTime = self.connectionReadTime
                        break
                     # Handshake rejected: release the channel and move on to the next URI.
                     self.connection.closeListenerChannel(self.activeChannel)
                     self.connection    = None
                     self.activeChannel = None
                  else:
                     self.connection = None
               except SystemExit:
                  self.connection    = None
                  self.activeChannel = None
                  userAborted        = True
                  break

         delay = 10.
         logOpenChannelErrorMessage = False


   def __parseURI(self,
                  uri):
      """Split a URI into (protocol, host, port, filePath).

      'proto://host:port' yields (lower-cased proto, host, int port, "");
      'proto://path' yields (lower-cased proto, "", 0, path) with the first two
      '/' characters removed.  Any other form, or a non-integer port, is logged
      and returned as ("", "", 0, "").
      """
      protocol,host,port,filePath = "","",0,""
      try:
         parts = uri.split(':')
         if   len(parts) == 3:
            protocol,host,port = parts
            protocol = protocol.lower()
            host     = host.lstrip('/')
            port     = int(port)
         elif len(parts) == 2:
            protocol,filePath = parts
            protocol = protocol.lower()
            filePath = filePath.replace('/','',2)
         else:
            raise ValueError(uri)
      except (AttributeError,TypeError,ValueError):
         # Reset every field so a half-parsed URI is never returned.
         protocol,host,port,filePath = "","",0,""
         self.logger.log(logging.ERROR,getLogMessage("Improper network specification: %s" % (uri,)))

      return(protocol,host,port,filePath)


   def isConnected(self):
      """Return True while a channel is open."""
      return(self.activeChannel is not None)


   def closeConnection(self):
      """Close the active channel, if any, so isConnected() becomes False."""
      if self.activeChannel:
         self.activeChannel.close()
         self.activeChannel = None


   def getConnectionReadTime(self):
      """Return the time.time() of the last successful read (or of connecting)."""
      return(self.connectionReadTime)


   def getConnectionWriteTime(self):
      """Return the time.time() of the last successful write (or of connecting)."""
      return(self.connectionWriteTime)


   def getInputObject(self):
      """Return [activeChannel] when connected, else [] (a readers list)."""
      activeReader = []
      if self.activeChannel:
         activeReader = [self.activeChannel]

      return(activeReader)


   def getOutputObject(self):
      """Return [activeChannel] when connected with queued output, else []."""
      activeWriter = []
      if self.activeChannel and self.toBuffer != "":
         activeWriter = [self.activeChannel]

      return(activeWriter)


   def sendMessage(self):
      """Write as much of toBuffer as the connection accepts.

      Transmitted text is dropped from toBuffer and the write time updated; a
      negative transmitted length closes the connection.  self.connection is
      used unconditionally, so presumably this is only called while connected
      (e.g. when getOutputObject() is non-empty) -- confirm with callers.
      """
      transmittedLength = self.connection.sendMessage(self.activeChannel,self.toBuffer)
      if   transmittedLength > 0:
         self.toBuffer = self.toBuffer[transmittedLength:]
         self.connectionWriteTime = time.time()
      elif transmittedLength < 0:
         self.closeConnection()


   def receiveMessage(self):
      """Append up to bufferSize received characters to fromBuffer.

      A None result from MessageCore.receiveMessage closes the connection.
      """
      message = self.connection.receiveMessage(self.activeChannel,0,self.bufferSize)
      if message is None:
         self.closeConnection()
      else:
         self.fromBuffer += message
         self.connectionReadTime = time.time()


   def pullMessage(self,
                   messageLength):
      """Remove and return text from the front of fromBuffer.

      messageLength == 0 -- the next newline-terminated line (newline consumed,
                            not returned); "" if no complete line is buffered.
      messageLength  < 0 -- up to -messageLength characters, whatever is there.
      messageLength  > 0 -- exactly messageLength characters, or "" (buffer left
                            untouched) if fewer are buffered.
      """
      if   messageLength == 0:
         line,newline,rest = self.fromBuffer.partition('\n')
         if newline:
            message         = line
            self.fromBuffer = rest
         else:
            message = ""
      elif messageLength < 0:
         ml = min(len(self.fromBuffer),-messageLength)
         message = self.fromBuffer[0:ml]
         self.fromBuffer = self.fromBuffer[ml:]
      else:
         if len(self.fromBuffer) >= messageLength:
            message = self.fromBuffer[0:messageLength]
            self.fromBuffer = self.fromBuffer[messageLength:]
         else:
            message = ""

      return(message)


   def pushMessage(self,
                   message):
      """Put message back at the front of fromBuffer (undo a pullMessage)."""
      self.fromBuffer = message + self.fromBuffer


   def postMessage(self,
                   message):
      """Queue message verbatim for sending."""
      self.toBuffer += message


   def postMessageBySize(self,
                         command,
                         message):
      """Queue 'command <len>\\n' followed by message; empty messages are skipped."""
      if len(message) > 0:
         self.toBuffer += command + " %d\n" % (len(message))
         self.toBuffer += message


   def postMessagesBySize(self,
                          command,
                          messages):
      """Queue 'command <len1> <len2> ...\\n' followed by all messages concatenated.

      Nothing is queued when every message is empty.  messages may be any
      iterable of strings, including a one-shot generator.
      """
      # Materialize once: messages is read twice (lengths and payload).
      messages = list(messages)
      text = "".join(messages)
      if text:
         lengths = "".join(" %d" % (len(message)) for message in messages)
         self.toBuffer += command + lengths + "\n" + text


   def postJsonMessage(self,
                       jsonObject):
      """Queue jsonObject JSON-encoded, framed as a 'json <len>' message.

      Falsy objects are skipped.  Objects json cannot encode (unsupported types,
      circular references) are logged and dropped.
      """
      if jsonObject:
         try:
            message = json.dumps(jsonObject)
         except (TypeError,ValueError):
            # Wrap in a 1-tuple so tuple-valued objects format safely.
            self.logger.log(logging.ERROR,getLogMessage("JSON object %s could not be encoded" % (jsonObject,)))
         else:
            if len(message) > 0:
               self.postMessageBySize('json',message)


   def isMessagePending(self):
      """Return True while queued output remains unsent."""
      return(len(self.toBuffer) > 0)


   def logPendingMessage(self,
                         logMessage):
      """Log the unsent output, tagged with logMessage, if any is queued."""
      if self.isMessagePending():
         self.logger.log(logging.INFO,getLogMessage("logPendingMessage(%s): %s" % (logMessage,self.toBuffer)))

