# @package      hubzero-submit-server
# @file         BoundConnections.py
# @author       Steven Clark <clarks@purdue.edu>
# @copyright    Copyright (c) 2012-2015 HUBzero Foundation, LLC.
# @license      http://opensource.org/licenses/MIT MIT
#
# Copyright (c) 2012-2015 HUBzero Foundation, LLC.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
# HUBzero is a registered trademark of HUBzero Foundation, LLC.
#
import time
import logging
import json
import math
import socket

from hubzero.submit.LogMessage  import getLogIDMessage as getLogMessage
from hubzero.submit.MessageCore import MessageCore

class BoundConnections:
   """Listen on one URI and manage the connections accepted on it.

   All network work is delegated to a MessageCore listener.  Per-connection
   state -- remote address, last-checked timestamp, and the inbound
   (fromBuffer) and outbound (toBuffer) text buffers -- is held in
   dictionaries keyed by the channel's fileno().
   """
   def __init__(self,
                listenURI,
                logConnection=True,
                submitSSLcert="",
                submitSSLkey="",
                submitSSLCA="",
                bufferSize=128):
      """Parse listenURI and try to bind a listener to it.

      listenURI takes the form 'tls://host:port', 'tcp://host:port' or
      'file://path'.  The SSL arguments are only used for 'tls'.
      bufferSize is handed to MessageCore and reused by the receive,
      pull and padded-send helpers below.  If nothing can be bound an
      error is logged and isListening() returns False.
      """
      self.logger                = logging.getLogger(__name__)
      self.listener              = None
      self.boundSocket           = None
      self.activeChannels        = []
      self.remoteIP              = {}
      self.remoteHost            = {}
      self.connectionCheckedTime = {}
      self.bufferSize            = bufferSize
      self.fromBuffer            = {}
      self.toBuffer              = {}
      self.logConnection         = logConnection

      # Must be bound on every path: a tls/tcp URI without a positive port,
      # or a file URI without a path, used to leave this name undefined.
      listener = None
      protocol,bindHost,bindPort,bindFile = self.__parseURI(listenURI)
      if   protocol == 'tls':
         if bindPort > 0:
            self.logger.log(logging.INFO,getLogMessage("Listening: protocol='%s', host='%s', port=%d" % \
                                                                       (protocol,bindHost,bindPort)))
            listener = MessageCore(protocol='tls',
                                   bindHost=bindHost,bindPort=bindPort,bindLabel=listenURI,
                                   sslKeyPath=submitSSLkey,sslCertPath=submitSSLcert,
                                   sslCACertPath=submitSSLCA,
                                   reuseAddress=True,blocking=False,
                                   defaultBufferSize=self.bufferSize)
      elif protocol == 'tcp':
         if bindPort > 0:
            self.logger.log(logging.INFO,getLogMessage("Listening: protocol='%s', host='%s', port=%d" % \
                                                                       (protocol,bindHost,bindPort)))
            listener = MessageCore(protocol='tcp',
                                   bindHost=bindHost,bindPort=bindPort,bindLabel=listenURI,
                                   reuseAddress=True,blocking=False,
                                   defaultBufferSize=self.bufferSize)
      elif protocol == 'file':
         if bindFile:
            self.logger.log(logging.INFO,getLogMessage("Listening: '%s'" % (bindFile)))
            listener = MessageCore(protocol='file',
                                   bindFile=bindFile,
                                   bindLabel="UD:%s" % (bindFile),
                                   defaultBufferSize=self.bufferSize)
      else:
         self.logger.log(logging.ERROR,getLogMessage("Unknown protocol: %s" % (protocol)))

      # Only keep a listener that actually bound; report every other outcome.
      if listener and listener.isBound():
         self.boundSocket = listener.boundSocket()
         self.listener    = listener
      else:
         self.logger.log(logging.ERROR,getLogMessage("Could not bind connection to: %s" % (listenURI)))


   def __parseURI(self,
                  uri):
      """Split uri into (protocol, host, port, filePath).

      'proto://host:port' yields a lower-cased protocol, the host with
      leading slashes removed and an integer port.  'proto://path' yields
      the path with its leading '//' removed.  Unparseable input yields
      protocol '' and port 0, and logs an error.
      """
      protocol = ""
      host     = ""
      port     = 0
      filePath = ""
      try:
         parts = uri.split(':')
         if   len(parts) == 3:
            protocol,host,port = parts
            protocol = protocol.lower()
            host     = host.lstrip('/')
            port     = int(port)
         elif len(parts) == 2:
            protocol,filePath = parts
            protocol = protocol.lower()
            # Drop only the '//' separator.  The previous replace('/','',2)
            # also removed slashes inside paths such as 'file:/tmp/sock'.
            if filePath.startswith('//'):
               filePath = filePath[2:]
      except (AttributeError,TypeError,ValueError):
         # AttributeError/TypeError: uri is not a str; ValueError: port is not an integer.
         protocol = ""
         port     = 0
         self.logger.log(logging.ERROR,getLogMessage("Improper network specification: %s" % (uri)))

      return(protocol,host,port,filePath)


   @staticmethod
   def __channelKey(channel):
      """Return the dictionary key for channel: its fileno(), or channel itself if it is already an id."""
      if hasattr(channel,'fileno'):
         return(channel.fileno())
      return(channel)


   def acceptHandshake(self,
                       channel):
      """Run the listener's handshake on channel; on success stamp its checked time.

      Returns the listener's verdict unchanged.
      """
      valid = self.listener.acceptHandshake(channel)
      if valid:
         channelId = channel.fileno()
         self.connectionCheckedTime[channelId] = time.time()

      return(valid)


   def isListening(self):
      """Return True when a listener was successfully bound in __init__."""
      return(self.listener is not None)


   def isConnected(self,
                   channel):
      """Return True if channel is active; refresh its checked time when it is."""
      channelIsConnected = channel in self.activeChannels
      if channelIsConnected:
         channelId = channel.fileno()
         self.connectionCheckedTime[channelId] = time.time()

      return(channelIsConnected)


   def closeConnection(self,
                       channel):
      """Close an active channel and forget all state recorded for it.

      Channels that are not active are ignored.
      """
      if channel in self.activeChannels:
         # Read the id before close(); the descriptor is not usable afterwards.
         channelId = channel.fileno()
         channel.close()
         self.activeChannels.remove(channel)
         # pop() tolerates a table lacking an entry (e.g. a channel added via
         # setChannelAndBuffers) instead of aborting half-way with KeyError.
         for table in (self.connectionCheckedTime,self.remoteIP,self.remoteHost,
                       self.fromBuffer,self.toBuffer):
            table.pop(channelId,None)


   def setChannelAndBuffers(self,
                            oldChannel,
                            newChannel,
                            newFromBuffer,
                            newToBuffer):
      """Register newChannel with the given buffers and return oldChannel's buffers.

      oldChannel may be a channel object or its integer fileno.  Returns
      (oldChannel, fromBuffer, toBuffer), where the buffers are those held
      for oldChannel before newChannel was registered.
      """
      oldChannelId = self.__channelKey(oldChannel)
      fromBuffer   = self.fromBuffer[oldChannelId]
      toBuffer     = self.toBuffer[oldChannelId]

      newChannelId = newChannel.fileno()
      self.activeChannels.append(newChannel)
      # Key by fileno() like every other method; keying by the channel
      # object made these buffers invisible to send/receive/pull.
      self.fromBuffer[newChannelId] = newFromBuffer
      self.toBuffer[newChannelId]   = newToBuffer
      self.connectionCheckedTime[newChannelId] = time.time()
      # Keep any address already recorded under this id; otherwise make the
      # getters return None rather than raise KeyError.
      self.remoteIP.setdefault(newChannelId,None)
      self.remoteHost.setdefault(newChannelId,None)

      return(oldChannel,fromBuffer,toBuffer)


   def closeListeningConnection(self):
      """Close the listening socket; already-accepted channels are left open."""
      if self.boundSocket:
         self.listener.close()
         self.boundSocket = None


   def acceptConnection(self):
      """Accept one pending connection and register it if its handshake succeeds.

      Returns True when a channel was added to activeChannels.  A channel
      whose handshake fails is logged and closed so it is not leaked.
      """
      if self.listener is None:
         return(False)

      if self.listener.getProtocol() == 'file':
         # The 'file' listener returns only the channel; no remote details.
         channel    = self.listener.acceptConnection(logConnection=self.logConnection)
         remoteIP   = None
         remoteHost = None
      else:
         channel,remoteIP,_,remoteHost = self.listener.acceptConnection(logConnection=self.logConnection,
                                                                        determineDetails=True)
      if not channel:
         return(False)

      if not self.acceptHandshake(channel):
         self.logger.log(logging.ERROR,getLogMessage("acceptConnection(): %d acceptHandshake failed" % (channel.fileno())))
         # Nothing tracks this channel, so release it explicitly.
         channel.close()
         return(False)

      self.activeChannels.append(channel)
      channelId = channel.fileno()
      self.connectionCheckedTime[channelId] = time.time()
      self.remoteIP[channelId]              = remoteIP
      self.remoteHost[channelId]            = remoteHost
      self.fromBuffer[channelId]            = ""
      self.toBuffer[channelId]              = ""

      return(True)


   def getConnectionCheckedTime(self,
                                channel):
      """Return the last time.time() stamp recorded for channel."""
      channelId = channel.fileno()

      return(self.connectionCheckedTime[channelId])


   def getRemoteIP(self,
                   channel):
      """Return the remote IP recorded when channel was accepted (None for 'file')."""
      channelId = channel.fileno()

      return(self.remoteIP[channelId])


   def getRemoteHost(self,
                     channel):
      """Return the remote host recorded when channel was accepted (None for 'file')."""
      channelId = channel.fileno()

      return(self.remoteHost[channelId])


   def getInputObjects(self):
      """Return (listeningSockets, activeReaders) lists suitable for select().

      Every active channel's checked time is refreshed.  activeReaders is
      a copy, so callers may close channels while iterating it.
      """
      listeningSocket = [self.boundSocket] if self.boundSocket else []

      now = time.time()
      for channel in self.activeChannels:
         self.connectionCheckedTime[channel.fileno()] = now

      return(listeningSocket,list(self.activeChannels))


   def getOutputObjects(self):
      """Return active channels with pending outbound data, refreshing their checked time."""
      activeWriters = []
      for channel in self.activeChannels:
         channelId = channel.fileno()
         if self.toBuffer[channelId] != "":
            activeWriters.append(channel)
            self.connectionCheckedTime[channelId] = time.time()

      return(activeWriters)


   def sendMessage(self,
                   channel):
      """Send as much of channel's outbound buffer as the listener accepts.

      Sent text is dropped from the buffer; a negative send result closes
      the channel.
      """
      channelId = channel.fileno()
      transmittedLength = self.listener.sendMessage(channel,self.toBuffer[channelId])
      if   transmittedLength > 0:
         self.toBuffer[channelId] = self.toBuffer[channelId][transmittedLength:]
      elif transmittedLength < 0:
         self.closeConnection(channel)


   def receiveMessage(self,
                      channel):
      """Append received text to channel's inbound buffer; close it if receive returns None."""
      channelId = channel.fileno()
      message = self.listener.receiveMessage(channel,self.bufferSize,self.bufferSize)
      if message is None:
         self.closeConnection(channel)
      else:
         self.fromBuffer[channelId] += message


   def pullMessage(self,
                   channel,
                   messageLength):
      """Remove and return a message from the front of channel's inbound buffer.

      messageLength == 0: the next newline-terminated line, without the
      newline ("" if no full line is buffered).  On a blocking channel 0
      means bufferSize instead.
      messageLength < 0: up to -messageLength characters, whatever is there.
      messageLength > 0: exactly messageLength characters, or "" (buffer
      untouched) if fewer are available.
      Returns "" for a channel with no buffer (e.g. already closed).
      """
      channelId = channel.fileno()
      if messageLength == 0:
         if self.listener.isBlockingChannel(channel):
            messageLength = self.bufferSize

      buffered = self.fromBuffer.get(channelId)
      if buffered is None:
         return("")

      if   messageLength == 0:
         nl = buffered.find('\n')
         if nl < 0:
            return("")
         message,remainder = buffered[:nl],buffered[nl+1:]
      elif messageLength < 0:
         count = -messageLength
         message,remainder = buffered[:count],buffered[count:]
      elif len(buffered) >= messageLength:
         message,remainder = buffered[:messageLength],buffered[messageLength:]
      else:
         return("")

      self.fromBuffer[channelId] = remainder
      return(message)


   def pushMessage(self,
                   channel,
                   message):
      """Put message back at the front of channel's inbound buffer."""
      channelId = channel.fileno()
      self.fromBuffer[channelId] = message + self.fromBuffer[channelId]


   def postMessage(self,
                   channel,
                   message):
      """Append message verbatim to channel's outbound buffer."""
      channelId = channel.fileno()
      self.toBuffer[channelId] += message


   def postMessageBySize(self,
                         channel,
                         command,
                         message):
      """Queue 'command <size>\\n' followed by message; empty messages are skipped.

      On a blocking channel the header is space-padded to bufferSize and
      the body to the next multiple of bufferSize, and <size> is that
      padded length.
      """
      if len(message) > 0:
         try:
            channelId = channel.fileno()
         except socket.error:
            self.logger.log(logging.ERROR,getLogMessage("postMessageBySize() failed: %s" % (message)))
         else:
            if self.listener.isBlockingChannel(channel):
               messageLength = self.bufferSize * int(math.ceil(float(len(message)) / float(self.bufferSize)))
               messageHeader = "%s %d\n" % (command,messageLength)
               self.toBuffer[channelId] += messageHeader.ljust(self.bufferSize,' ')
               self.toBuffer[channelId] += message.ljust(messageLength,' ')
            else:
               self.toBuffer[channelId] += command + " %d\n" % (len(message))
               self.toBuffer[channelId] += message


   def postMessagesBySize(self,
                          channel,
                          command,
                          messages):
      """Queue 'command <size1> <size2> ...\\n' followed by all messages concatenated.

      Nothing is queued when the messages are all empty.  messages may be
      any iterable of strings, including a generator.
      """
      # Materialize once: the sizes and the text both need a full pass.
      messages = list(messages)
      text = "".join(messages)
      if len(text) > 0:
         channelId = channel.fileno()
         sizes = "".join([" %d" % (len(message)) for message in messages])
         self.toBuffer[channelId] += command + sizes + "\n" + text


   def postJsonMessage(self,
                       channel,
                       jsonObject):
      """JSON-encode jsonObject and queue it as a 'json' sized message.

      Falsy objects are skipped; objects json.dumps cannot encode are
      logged and dropped.
      """
      if jsonObject:
         try:
            message = json.dumps(jsonObject)
         except (TypeError,ValueError):
            # TypeError: unserializable type; ValueError: e.g. circular reference.
            self.logger.log(logging.ERROR,getLogMessage("JSON object %s could not be encoded" % (jsonObject)))
         else:
            if len(message) > 0:
               self.postMessageBySize(channel,'json',message)


   def isMessagePending(self,
                        channel):
      """Return True if channel's outbound buffer is non-empty."""
      channelId = channel.fileno()

      return(len(self.toBuffer[channelId]) > 0)


   def logPendingMessage(self,
                         channel,
                         logMessage):
      """Log channel's pending outbound text at DEBUG level, tagged with logMessage."""
      if self.isMessagePending(channel):
         channelId = channel.fileno()
         self.logger.log(logging.DEBUG,getLogMessage("logPendingMessage(%s): %s" % (logMessage,self.toBuffer[channelId])))


