# @package      hubzero-submit-common
# @file         MessageConnection.py
# @author       Steven Clark <clarks@purdue.edu>
# @copyright    Copyright (c) 2012-2015 HUBzero Foundation, LLC.
# @license      http://opensource.org/licenses/MIT MIT
#
# Copyright (c) 2012-2015 HUBzero Foundation, LLC.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
# HUBzero is a registered trademark of HUBzero Foundation, LLC.
#
import os.path
import time
import math
import json
import select
import logging
from errno import EINTR

from hubzero.submit.LogMessage  import getLogIDMessage as getLogMessage
from hubzero.submit.MessageCore import MessageCore

class MessageConnection:
   """Buffered message exchange with a listener reached through MessageCore.

   The listener is named by a URI of the form ``tls://host:port``,
   ``tcp://host:port`` or ``file://path``.  Incoming data accumulates in
   ``fromBuffer`` and outgoing data in ``toBuffer``; ``channel`` is the
   currently open channel or None.  If the URI cannot be used,
   ``connection`` stays None and requestJsonExchange() reports that the
   connection is not configured.
   """

   def __init__(self,
                listenURI,
                bufferSize=256,
                maximumConnectionPasses=None,
                submitSSLCA=""):
      """Parse listenURI and build the matching MessageCore.

      listenURI               -- listener location, see the class docstring
      bufferSize              -- read size passed to MessageCore.receiveMessage
      maximumConnectionPasses -- connection attempts made by
                                 requestJsonExchange(); a false value
                                 (None or 0) means retry without limit
      submitSSLCA             -- CA certificate path handed to MessageCore
                                 for the tls and tcp protocols
      """
      self.logger                  = logging.getLogger(__name__)
      self.connection              = None
      self.maximumConnectionPasses = maximumConnectionPasses
      self.channel                 = None
      self.bufferSize              = bufferSize
      self.fromBuffer              = ""
      self.toBuffer                = ""

      protocol,serverHost,serverPort,serverFile = self.__parseURI(listenURI)
      if   protocol in ('tls','tcp'):
         self.connection = MessageCore(protocol=protocol,bindLabel=listenURI,
                                       sslCACertPath=submitSSLCA,
                                       listenerHost=serverHost,listenerPort=serverPort)
      elif protocol == 'file':
         if os.path.exists(serverFile):
            self.connection = MessageCore(protocol='file',listenerFile=serverFile,
                                          bindLabel="UD:%s" % (serverFile))
         else:
            self.logger.log(logging.ERROR,getLogMessage("Connecting file %s missing" % (serverFile)))
      else:
         self.logger.log(logging.ERROR,getLogMessage("Unknown protocol: %s" % (protocol)))


   def __parseURI(self,
                  uri):
      """Split uri into (protocol,host,port,filePath).

      ``proto://host:port`` fills protocol, host and port;
      ``proto://path`` fills protocol and filePath.  Any other shape
      returns an empty protocol.  A uri that cannot be split or whose port
      is not an integer is logged and also returns an empty protocol.
      """
      protocol = ""
      host     = ""
      port     = 0
      filePath = ""
      try:
         parts = uri.split(':')
         if   len(parts) == 3:
            protocol,host,port = parts
            protocol = protocol.lower()
            host     = host.lstrip('/')
            port     = int(port)
         elif len(parts) == 2:
            protocol,filePath = parts
            protocol = protocol.lower()
            # NOTE(review): removes the first two '/' wherever they occur,
            # which is only the "//" prefix when the URI is proto://path.
            filePath = filePath.replace('/','',2)
      except (AttributeError,TypeError,ValueError):
         protocol = ""
         self.logger.log(logging.ERROR,getLogMessage("Improper network specification: %s" % (uri)))

      return(protocol,host,port,filePath)


   def isConnected(self):
      """Return True while a channel is open."""
      return(self.channel is not None)


   def closeConnection(self):
      """Close the open channel, if any, and forget it."""
      if self.channel:
         self.channel.close()
         self.channel = None


   def getInputObject(self):
      """Return the readers to hand to select(): the channel, or nothing."""
      activeReader = []
      if self.channel:
         activeReader = [self.channel]

      return(activeReader)


   def getOutputObject(self):
      """Return the writers to hand to select(): the channel when output is pending."""
      activeWriter = []
      if self.channel and self.toBuffer != "":
         activeWriter = [self.channel]

      return(activeWriter)


   def sendMessage(self):
      """Send pending output; drop what was sent, close on a negative result."""
      transmittedLength = self.connection.sendMessage(self.channel,self.toBuffer)
      if   transmittedLength > 0:
         self.toBuffer = self.toBuffer[transmittedLength:]
      elif transmittedLength < 0:
         self.closeConnection()


   def receiveMessage(self):
      """Append newly received data to fromBuffer; close the channel on None."""
      message = self.connection.receiveMessage(self.channel,0,self.bufferSize)
      if message is None:
         self.closeConnection()
      else:
         self.fromBuffer += message


   def pullMessage(self,
                   messageLength):
      """Remove and return data from the front of fromBuffer.

      messageLength == 0 -- return the text up to the next newline (the
                            newline is consumed but not returned), or ""
                            if no complete line is buffered
      messageLength <  0 -- return at most -messageLength characters
      messageLength >  0 -- return exactly messageLength characters, or ""
                            (buffer untouched) if fewer are buffered
      """
      if   messageLength == 0:
         nl = self.fromBuffer.find('\n')
         if nl < 0:
            message = ""
         else:
            message = self.fromBuffer[0:nl]
            self.fromBuffer = self.fromBuffer[nl+1:]
      elif messageLength < 0:
         ml = min(len(self.fromBuffer),-messageLength)
         message = self.fromBuffer[0:ml]
         self.fromBuffer = self.fromBuffer[ml:]
      else:
         if len(self.fromBuffer) >= messageLength:
            message = self.fromBuffer[0:messageLength]
            self.fromBuffer = self.fromBuffer[messageLength:]
         else:
            message = ""

      return(message)


   def pushMessage(self,
                   message):
      """Put message back at the front of fromBuffer."""
      self.fromBuffer = message + self.fromBuffer


   def postMessage(self,
                   message):
      """Queue message for sending."""
      self.toBuffer += message


   def postMessageBySize(self,
                         command,
                         message):
      """Queue "<command> <length>\\n" followed by message; empty messages are skipped."""
      if len(message) > 0:
         self.toBuffer += command + " %d\n" % (len(message))
         self.toBuffer += message


   def postMessagesBySize(self,
                          command,
                          messages):
      """Queue "<command> <len1> <len2> ...\\n" followed by the concatenated messages.

      Nothing is queued when the concatenation is empty.
      """
      text = "".join(messages)
      if len(text) > 0:
         lengths = "".join([" %d" % (len(message)) for message in messages])
         self.toBuffer += command + lengths + "\n" + text


   def isMessagePending(self):
      """Return True if output is waiting to be sent."""
      return(len(self.toBuffer) > 0)


   def __decodeJsonMessage(self,
                           jsonMessage,
                           fallback=None):
      """Decode jsonMessage; on a decoding error log it and return fallback."""
      try:
         decoded = json.loads(jsonMessage)
      except ValueError:
         self.logger.log(logging.ERROR,getLogMessage("JSON object %s could not be decoded" % \
                                                                                    (jsonMessage)))
         decoded = fallback

      return(decoded)


   def __selectReady(self,
                     timeout):
      """Poll the channel; return (readyReaders,readyWriters) or None on a select failure.

      An interrupted call (EINTR) is reported as nothing ready.
      """
      try:
         readyReaders,readyWriters,readyExceptions = select.select(self.getInputObject(),
                                                                   self.getOutputObject(),
                                                                   [],
                                                                   timeout)
      except select.error as err:
         # args[0] carries the errno for both the Python 2 select.error
         # and the Python 3 OSError it aliases.
         if err.args and err.args[0] == EINTR:
            readyReaders = []
            readyWriters = []
         else:
            return(None)

      return(readyReaders,readyWriters)


   def __pullJsonResponse(self,
                          jsonObject):
      """Consume buffered lines looking for a "json <length>" header plus payload.

      Returns (messageReturned,jsonObject).  jsonObject is the decoded
      payload, or the value passed in when nothing was decoded.  A header
      whose payload has not fully arrived is pushed back for a later call.
      Any other line is logged and discarded.
      """
      messageReturned = False
      message = self.pullMessage(0)
      while message:
         args = message.split()
         jsonMessageLength = None
         if args and args[0] == 'json':
            try:
               jsonMessageLength = int(args[1])
            except (IndexError,ValueError):
               jsonMessageLength = None

         if jsonMessageLength is None:
            self.logger.log(logging.ERROR,getLogMessage("Discarded message: %s" % (message)))
         else:
            jsonMessage = self.pullMessage(jsonMessageLength)
            if len(jsonMessage) > 0:
               jsonObject = self.__decodeJsonMessage(jsonMessage,jsonObject)
               messageReturned = True
            else:
               self.pushMessage(message + '\n')
            break

         message = self.pullMessage(0)

      return(messageReturned,jsonObject)


   def __blockingJsonExchange(self,
                              jsonMessage):
      """Send jsonMessage and read the reply with fixed-length transfers.

      The payload length is rounded up to a multiple of bufferSize before
      being announced.  Returns (messageReturned,jsonObject).
      """
      lengthJsonMessage = self.bufferSize * int(math.ceil(float(len(jsonMessage)) / float(self.bufferSize)))
      if not self.connection.sendMessage(self.channel,"json %d\n" % (lengthJsonMessage),self.bufferSize) > 0:
         return(False,None)
      if not self.connection.sendMessage(self.channel,jsonMessage,lengthJsonMessage) > 0:
         return(False,None)

      responseHeader = self.connection.receiveMessage(self.channel,self.bufferSize,self.bufferSize)
      if not responseHeader:
         # covers "" and None; receiveMessage() in this class treats None as failure
         return(False,None)
      try:
         responseLength = int(responseHeader.split()[-1])
      except (IndexError,ValueError):
         self.logger.log(logging.ERROR,getLogMessage("Discarded message: %s" % (responseHeader)))
         return(False,None)
      if responseLength <= 0:
         return(False,None)

      jsonResponse = self.connection.receiveMessage(self.channel,responseLength,responseLength)
      if not jsonResponse:
         return(False,None)

      return(True,self.__decodeJsonMessage(jsonResponse))


   def requestJsonExchange(self,
                           messageObject,
                           blocking=False):
      """Send messageObject as JSON and wait for a JSON reply.

      Each pass opens a channel, performs the handshake, exchanges the
      messages and closes the channel.  Passes repeat, 10 seconds apart,
      until a reply arrives, SystemExit is raised, or
      maximumConnectionPasses is reached (no limit when that is false).

      messageObject -- object to encode with json.dumps
      blocking      -- True to use fixed-length blocking transfers, False
                       to poll the channel with select

      Returns (nConnectionPasses,jsonObject); jsonObject is None when no
      reply was decoded.
      """
      jsonObject = None
      nConnectionPasses = 0
      if not self.connection:
         self.logger.log(logging.ERROR,getLogMessage("Connection is not configured"))
         return(nConnectionPasses,jsonObject)

      messageReturned            = False
      delay                      = 0.
      userAborted                = False
      logOpenChannelErrorMessage = True
      if self.maximumConnectionPasses:
         maximumConnectionPasses = self.maximumConnectionPasses
      else:
         maximumConnectionPasses = nConnectionPasses+1
      while (not userAborted) and (not messageReturned) and (nConnectionPasses < maximumConnectionPasses):
         time.sleep(delay)
         nConnectionPasses += 1
         try:
            self.channel = self.connection.openListenerChannel(logErrorMessage=logOpenChannelErrorMessage)
            if self.channel:
               if self.connection.initiateHandshake(self.channel,blocking):
                  self.bufferSize = self.connection.getDefaultBufferSize()
                  try:
                     jsonMessage = json.dumps(messageObject)
                  except TypeError:
                     self.logger.log(logging.ERROR,getLogMessage("JSON object %s could not be encoded" % (messageObject)))
                  else:
                     if blocking:
                        messageReturned,jsonObject = self.__blockingJsonExchange(jsonMessage)
                     else:
                        self.postMessageBySize('json',jsonMessage)
                        # NOTE(review): a zero timeout makes select a pure
                        # poll, so this loop spins while waiting.
                        activityUpdateInterval = 0.
                        while True:
                           if not self.isConnected():
                              # channel was closed by a failed send or
                              # receive; nothing more can arrive
                              break
                           ready = self.__selectReady(activityUpdateInterval)
                           if ready is None:
                              break
                           readyReaders,readyWriters = ready

                           if messageReturned and (not readyReaders) and (not readyWriters):
                              break
                           if readyReaders:
                              self.receiveMessage()

                           received,jsonObject = self.__pullJsonResponse(jsonObject)
                           if received:
                              messageReturned = True

                           # the receive above may have closed the channel
                           if readyWriters and self.isConnected():
                              self.sendMessage()

               self.closeConnection()
            else:
               if logOpenChannelErrorMessage:
                  self.logger.log(logging.ERROR,getLogMessage("Channel could not be allocated"))
         except SystemExit:
            self.closeConnection()
            userAborted = True

         delay = 10.
         logOpenChannelErrorMessage = False
         if not self.maximumConnectionPasses:
            # no limit configured: keep the bound one step ahead
            maximumConnectionPasses = nConnectionPasses+1

      return(nConnectionPasses,jsonObject)


