# @package      hubzero-submit-common
# @file         MessageCore.py
# @author       Steven Clark <clarks@purdue.edu>
# @copyright    Copyright (c) 2012-2015 HUBzero Foundation, LLC.
# @license      http://opensource.org/licenses/MIT MIT
#
# Copyright (c) 2012-2015 HUBzero Foundation, LLC.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
# HUBzero is a registered trademark of HUBzero Foundation, LLC.
#
import sys
import os
import socket
import time
import math
import traceback
import logging
from errno import EAGAIN,EWOULDBLOCK,EPIPE,ECONNRESET,ECONNREFUSED

from OpenSSL import SSL

from hubzero.submit.LogMessage import getLogIDMessage as getLogMessage

class MessageCore:
   """Socket transport used by submit clients and servers.

   An instance plays one of two roles, chosen by the constructor arguments:

   * server - bindPort > 0 (AF_INET) or bindFile (AF_UNIX): the listening
     socket is created, bound and put into listen mode by __init__; use
     acceptConnection() to obtain channels.
   * client - otherwise; listenerPort > 0 (AF_INET) or listenerFile (AF_UNIX)
     select the peer that openListenerChannel() connects to.

   With protocol == 'tls' every socket is wrapped in a pyOpenSSL
   SSL.Connection.  Messages are text on the Python side and travel UTF-8
   encoded on the wire; every length exchanged with the socket layer
   (messageLength, fixedBufferSize, send return values) counts bytes.
   """

   def __init__(self,
                protocol='tcp',
                bindHost="",
                bindPort=0,
                bindFile="",
                bindLabel="",
                sslKeyPath="",
                sslCertPath="",
                sslCACertPath="",
                reuseAddress=False,
                blocking=True,
                defaultBufferSize=None,
                listenerHost="",
                listenerPort=0,
                listenerFile="",
                repeatDelay=5):
      """Configure the endpoint and, in server mode, bind the listening socket.

      protocol          - 'tcp' (default) or 'tls'; forced to 'file' for a
                          server bound to a UNIX-domain socket file.
      bindHost/bindPort - server TCP address; bindPort > 0 selects server mode.
      bindFile          - server UNIX-domain socket path (used when bindPort <= 0).
      bindLabel         - name used in connection log messages.
      sslKeyPath,
      sslCertPath,
      sslCACertPath     - server: all three must be readable to build a TLS
                          context; client: only the CA file is used, to
                          verify the server certificate.
      reuseAddress      - set SO_REUSEADDR on the listening socket.
      blocking          - False puts the listening socket in non-blocking mode.
      defaultBufferSize - buffer size a server advertises in acceptHandshake();
                          on a client it is replaced by the value returned by
                          the server during initiateHandshake().
      listenerHost/
      listenerPort      - TCP address a client connects to; under TLS the
                          server certificate CN must equal listenerHost.
      listenerFile      - UNIX-domain socket path a client connects to.
      repeatDelay       - seconds to wait between the (up to 10) bind attempts.

      If binding never succeeds the error is logged, the socket is closed and
      isBound() returns False.
      """
      self.logger            = logging.getLogger(__name__)
      self.protocol          = protocol
      self.family            = None     # AF_INET / AF_UNIX, or None if no endpoint was given
      self.bindHost          = bindHost
      self.bindPort          = bindPort
      self.bindFile          = bindFile
      self.bindLabel         = bindLabel
      self.bindSocket        = None     # listening socket (server mode only)
      self.listenerHost      = listenerHost
      self.listenerPort      = listenerPort
      self.listenerFile      = listenerFile
      self.repeatDelay       = repeatDelay
      self.sslContext        = None
      self.defaultBufferSize = defaultBufferSize

      if   bindPort > 0:
         self.family = socket.AF_INET
      elif bindFile:
         self.family   = socket.AF_UNIX
         self.protocol = 'file'

      if self.family:
         # Server side.
         if sslKeyPath and sslCertPath and sslCACertPath:
            if all(os.access(path,os.R_OK) for path in (sslKeyPath,sslCertPath,sslCACertPath)):
               self.sslContext = self.__createSslContext()
               self.sslContext.use_privatekey_file(sslKeyPath)
               self.sslContext.use_certificate_file(sslCertPath)
               self.sslContext.load_verify_locations(sslCACertPath)

         self.bindSocket = self.getMessageSocket(reuseAddress)

         maximumBindAttempts = 10
         bindError = None
         bound = False
         for attempt in range(1,maximumBindAttempts+1):
            try:
               if   self.family == socket.AF_INET:
                  self.bindSocket.bind((bindHost,bindPort))
               elif self.family == socket.AF_UNIX:
                  self.bindSocket.bind(bindFile)
               self.bindSocket.listen(512)
               if not blocking:
                  self.bindSocket.setblocking(0)
            except Exception as e:
               # Keep the exception itself: sys.exc_info() is cleared once the
               # handler exits (Python 3), so it cannot be read after the loop.
               bindError = e
               if attempt < maximumBindAttempts:
                  time.sleep(repeatDelay)
            else:
               bound = True
               break

         if not bound:
            # Release the descriptor instead of leaking it.
            self.bindSocket.close()
            self.bindSocket = None
            if   self.family == socket.AF_INET:
               message = "Can't bind to port %d: %s %s" % (bindPort,type(bindError),bindError)
               self.logger.log(logging.ERROR,getLogMessage(message))
            elif self.family == socket.AF_UNIX:
               message = "Can't bind to file %s: %s %s" % (bindFile,type(bindError),bindError)
               self.logger.log(logging.ERROR,getLogMessage(message))
      else:
         # Client side.
         if   listenerPort > 0:
            self.family = socket.AF_INET
         elif listenerFile:
            self.family = socket.AF_UNIX

         if sslCACertPath and os.access(sslCACertPath,os.R_OK):
            self.sslContext = self.__createSslContext()
            self.sslContext.set_verify(SSL.VERIFY_PEER,self.__verifyCert) # Demand a certificate
            self.sslContext.load_verify_locations(sslCACertPath)


   def __createSslContext(self):
      """Return a new SSL.Context with partial-write, moving-write-buffer and auto-retry modes requested."""
      # NOTE(review): TLSv1_METHOD restricts connections to TLS 1.0.  It is kept
      # because both ends of a connection use this class and changing it would
      # break mixed-version deployments; consider a version-flexible method.
      sslContext = SSL.Context(SSL.TLSv1_METHOD)
      try:
#https://github.com/pyca/pyopenssl/issues/190  -  python3 addition
         sslContext.set_mode(SSL._lib.SSL_MODE_ENABLE_PARTIAL_WRITE |
                             SSL._lib.SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER |
                             SSL._lib.SSL_MODE_AUTO_RETRY)
      except Exception:
         # Best effort: presumably guards pyOpenSSL builds whose private
         # SSL._lib lacks these constants.
         pass

      return(sslContext)


   def __verifyCert(self,
                    conn,
                    cert,
                    errnum,
                    depth,
                    ok):
      """pyOpenSSL verify callback installed on client contexts.

      Rejects any certificate OpenSSL flagged (errnum != 0).  Issuer
      certificates (depth > 0) are otherwise accepted; the server's own
      certificate (depth 0) must carry a commonName equal to listenerHost.
      Example chain as seen here:
         depth 1: /C=US/ST=Indiana/O=HUBzero/OU=nanoHUB/CN=Certificate Authority
         depth 0: /C=US/ST=Indiana/O=HUBzero/OU=nanoHUB/CN=Submit Server
      """
      if errnum != 0:
         return(False)
      if depth != 0:
         return(True)
      return(cert.get_subject().commonName == self.listenerHost)


   def getDefaultBufferSize(self):
      return(self.defaultBufferSize)


   def getProtocol(self):
      return(self.protocol)


   def getMessageSocket(self,
                        reuseAddress=False):
      """Create a stream socket of self.family, wrapped in SSL.Connection when protocol is 'tls'."""
      sock = socket.socket(self.family,socket.SOCK_STREAM)
      if reuseAddress:
         sock.setsockopt(socket.SOL_SOCKET,socket.SO_REUSEADDR,1)

      if self.protocol == 'tls':
         messageSocket = SSL.Connection(self.sslContext,sock)
      else:
         messageSocket = sock

      return(messageSocket)


   def isBound(self):
      return(self.bindSocket is not None)


   def close(self):
      """Close the listening socket, if any."""
      if self.bindSocket:
         self.bindSocket.close()
         self.bindSocket = None


   def boundSocket(self):
      return(self.bindSocket)


   def boundFileDescriptor(self):
      return(self.bindSocket.fileno())


   def acceptConnection(self,
                        logConnection=False,
                        determineDetails=False):
      """Accept one pending connection on the listening socket.

      Returns the new channel, or None when a non-blocking listener has no
      pending connection (EAGAIN/EWOULDBLOCK).  With determineDetails=True the
      result is (channel,remoteIP,remotePort,remoteHost); the address fields
      are only filled in for AF_INET and stay None when unavailable.
      Any other socket error is logged and re-raised.
      """
      remoteIP   = None
      remotePort = None
      remoteHost = None
      try:
         channel,details = self.bindSocket.accept()
      except socket.error as e:
         if e.args and e.args[0] in (EAGAIN,EWOULDBLOCK):
            # Nothing pending on a non-blocking listener - not an error.
            channel = None
         else:
            self.logger.log(logging.ERROR,getLogMessage(traceback.format_exc()))
            raise
      else:
         if   self.family == socket.AF_INET:
            if determineDetails:
               try:
                  remoteIP,remotePort = details
                  remoteHost = socket.gethostbyaddr(remoteIP)[0]
                  if remoteHost:
                     remoteHost = socket.getfqdn(remoteHost)
               except Exception:
                  # Reverse lookup is best effort; keep whatever was resolved.
                  pass
            if logConnection:
               self.logger.log(logging.INFO,getLogMessage("===================================================="))
               if determineDetails:
                  self.logger.log(logging.INFO,getLogMessage("Connection to %s from %s %s" % (self.bindLabel,details,remoteHost)))
               else:
                  self.logger.log(logging.INFO,getLogMessage("Connection to %s from %s" % (self.bindLabel,details)))
         elif self.family == socket.AF_UNIX:
            if logConnection:
               self.logger.log(logging.INFO,getLogMessage("===================================================="))
               self.logger.log(logging.INFO,getLogMessage("Connection to %s" % (self.bindLabel)))

      if determineDetails:
         return(channel,remoteIP,remotePort,remoteHost)
      else:
         return(channel)


   def openListenerChannel(self,
                           logErrorMessage=True,
                           recordTraceback=False):
      """Connect to the configured listener and return the channel, or None on failure.

      logErrorMessage - log a one-line message for name-resolution failures
                        and refused connections.
      recordTraceback - log a traceback for any other failure.
      SystemExit, KeyboardInterrupt and other non-Exception errors close the
      socket and propagate.
      """
      listenerChannel = None   # bound up front so every handler can clean up safely
      try:
         listenerChannel = self.getMessageSocket()
         if   self.family == socket.AF_INET:
            listenerChannel.connect((self.listenerHost,self.listenerPort))
         elif self.family == socket.AF_UNIX:
            listenerChannel.connect(self.listenerFile)
      except socket.gaierror as e:
         if logErrorMessage:
            self.logger.log(logging.ERROR,getLogMessage("openListenerChannel: %s" % (e.args[-1] if e.args else e)))
         self.closeListenerChannel(listenerChannel)
         listenerChannel = None
      except Exception as e:
         if e.args and e.args[0] == ECONNREFUSED:
            if logErrorMessage:
               self.logger.log(logging.ERROR,getLogMessage("openListenerChannel: %s" % (e.args[-1])))
         elif recordTraceback:
            self.logger.log(logging.ERROR,getLogMessage(traceback.format_exc()))
         self.closeListenerChannel(listenerChannel)
         listenerChannel = None
      except BaseException:
         self.closeListenerChannel(listenerChannel)
         raise

      return(listenerChannel)


   def closeListenerChannel(self,
                            listenerChannel):
      if listenerChannel:
         listenerChannel.close()


   def isBlockingChannel(self,
                         channel):
      """Return True when channel has no timeout (plain blocking mode)."""
      return(channel.gettimeout() is None)


   def __encodeMessage(self,
                       message):
      """Return message as bytes: text is UTF-8 encoded, bytes-like input is passed through."""
      if isinstance(message,(bytes,bytearray)):
         return(bytes(message))
      return(message.encode('utf-8'))


   def __decodeMessage(self,
                       rawMessage,
                       functionName):
      """Decode received bytes as UTF-8; invalid sequences are logged and replaced with U+FFFD."""
      try:
         return(rawMessage.decode('utf-8'))
      except UnicodeDecodeError:
         self.logger.log(logging.ERROR,getLogMessage("Invalid UTF-8 data received in %s()" % (functionName)))
         return(rawMessage.decode('utf-8','replace'))


   def __receiveNonBlockingMessage(self,
                                   channel,
                                   bufferSize):
      """Drain whatever is currently readable from a non-blocking channel.

      Returns the decoded text read so far, "" when nothing was available, or
      None when the peer closed/reset the connection (or an unexpected error
      occurred) before any data was read.
      NOTE(review): a UTF-8 sequence split across two separate calls cannot be
      reassembled here (no per-channel state); it is decoded with replacement.
      """
      rawMessage       = bytearray()
      connectionClosed = False

      while True:
         try:
            messageChunk = channel.recv(bufferSize)
         except socket.error as e:
            errorNumber = e.args[0] if e.args else None
            if   errorNumber == ECONNRESET:
               # Connection reset by peer
               connectionClosed = True
            elif errorNumber not in (EAGAIN,EWOULDBLOCK):
               self.logger.log(logging.ERROR,getLogMessage("Unexpected error in receiveNonBlockingMessage()"))
               self.logger.log(logging.ERROR,getLogMessage(traceback.format_exc()))
            # EAGAIN/EWOULDBLOCK: nothing (more) to read right now.
            break
         except (SSL.WantReadError, SSL.WantWriteError, SSL.WantX509LookupError):
            # Non-blocking TLS socket with nothing to read.
            break
         except SSL.ZeroReturnError:
            # Peer closed the TLS session cleanly.
            connectionClosed = True
            break
         except SSL.SysCallError as e:
            # -1 is an unexpected EOF; ECONNRESET is a reset.  Either way the
            # connection is gone - always leave the loop.
            if not (e.args and e.args[0] in (-1,ECONNRESET)):
               self.logger.log(logging.ERROR,getLogMessage("Unexpected error in receiveNonBlockingMessage()"))
               self.logger.log(logging.ERROR,getLogMessage(traceback.format_exc()))
            connectionClosed = True
            break
         except Exception:
            self.logger.log(logging.ERROR,getLogMessage("Unexpected error in receiveNonBlockingMessage() " + rawMessage.decode('utf-8','replace')))
            self.logger.log(logging.ERROR,getLogMessage(traceback.format_exc()))
            connectionClosed = True
            break
         else:
            if not messageChunk:
               # Orderly EOF.
               connectionClosed = True
               break
            rawMessage += messageChunk

      if not rawMessage:
         return(None if connectionClosed else "")
      return(self.__decodeMessage(rawMessage,'receiveNonBlockingMessage'))


   def __receiveBlockingMessage(self,
                                channel,
                                messageLength,
                                bufferSize):
      """Read exactly messageLength bytes from a blocking channel and decode them.

      Each recv() asks for at most min(bufferSize, bytes still missing), so no
      data belonging to the next message is consumed.  Returns the decoded
      text (a short read is logged when the peer closes mid-message), "" for a
      non-positive messageLength, or None if the connection closed or failed
      before any byte arrived.
      """
      if messageLength <= 0:
         self.logger.log(logging.ERROR,getLogMessage("Error in receiveBlockingMessage() %d message length" % (messageLength)))
      rawMessage       = bytearray()
      bytesRemaining   = messageLength
      connectionClosed = False
      receiveFailed    = False

      try:
         while bytesRemaining > 0:
            messageChunk = channel.recv(min(bufferSize,bytesRemaining))
            if not messageChunk:
               connectionClosed = True
               break
            rawMessage     += messageChunk
            bytesRemaining -= len(messageChunk)
      except SSL.ZeroReturnError:
         # Peer closed the TLS session cleanly.
         connectionClosed = True
      except Exception:
         self.logger.log(logging.ERROR,getLogMessage("Unexpected error in receiveBlockingMessage() " + rawMessage.decode('utf-8','replace')))
         self.logger.log(logging.ERROR,getLogMessage(traceback.format_exc()))
         receiveFailed = True

      if connectionClosed and rawMessage:
         self.logger.log(logging.ERROR,getLogMessage("socket connection broken in receiveBlockingMessage()"))
      if not rawMessage and (connectionClosed or receiveFailed):
         return(None)
      return(self.__decodeMessage(rawMessage,'receiveBlockingMessage'))


   def receiveMessage(self,
                      channel,
                      messageLength,
                      bufferSize=128):
      """Receive text from channel, dispatching on the channel's blocking mode.

      Blocking channels read exactly messageLength bytes in recv() calls of at
      most bufferSize bytes; non-blocking channels return whatever is
      currently readable and ignore messageLength.  Returns None when channel
      is falsy or the connection ended before any data arrived.
      """
      if not channel:
         return(None)

      if self.isBlockingChannel(channel):
         return(self.__receiveBlockingMessage(channel,messageLength,bufferSize))
      return(self.__receiveNonBlockingMessage(channel,bufferSize))


   def __logSendError(self,
                      functionName,
                      error,
                      message):
      """Log a failed send: a one-liner for EPIPE, message prefix plus traceback otherwise.

      Must be called from inside an except clause (uses traceback.format_exc()).
      """
      if error.args and error.args[0] == EPIPE:
         self.logger.log(logging.ERROR,getLogMessage("%s: Broken pipe" % (functionName)))
      else:
         self.logger.log(logging.ERROR,getLogMessage("Unexpected error in %s(%s)" % (functionName,message[0:20])))
         self.logger.log(logging.ERROR,getLogMessage(traceback.format_exc()))


   def __sendNonBlockingMessage(self,
                                channel,
                                message):
      """Make a single send() attempt on a non-blocking channel.

      Returns the number of bytes accepted (possibly fewer than the encoded
      message), 0 when nothing was sent (empty message, the socket/TLS layer
      would block, or an SSL error), or -1 when the connection is broken.
      NOTE(review): the count is in bytes - callers slicing a str message by it
      must account for multi-byte characters.
      """
      try:
         encodedMessage = self.__encodeMessage(message)
         if not encodedMessage:
            return(0)
         transmittedLength = channel.send(encodedMessage)
         if transmittedLength == 0:
            self.logger.log(logging.ERROR,getLogMessage("socket connection broken in sendNonBlockingMessage()"))
            transmittedLength = -1
      except (SSL.WantReadError, SSL.WantWriteError, SSL.WantX509LookupError):
         transmittedLength = 0
      except SSL.Error:
         self.logger.log(logging.ERROR,getLogMessage("Unexpected SSL error in sendNonBlockingMessage(%s)" % (message[0:20])))
         self.logger.log(logging.ERROR,getLogMessage(traceback.format_exc()))
         transmittedLength = 0
      except Exception as e:
         if e.args and e.args[0] in (EAGAIN,EWOULDBLOCK):
            # Send buffer full on a plain non-blocking socket: retry later.
            transmittedLength = 0
         else:
            self.__logSendError('sendNonBlockingMessage',e,message)
            transmittedLength = -1

      return(transmittedLength)


   def __sendBlockingMessage(self,
                             channel,
                             message,
                             fixedBufferSize):
      """Send the whole message on a blocking channel.

      With fixedBufferSize > 0 the encoded message is space-padded to exactly
      fixedBufferSize bytes; a message that does not fit is rejected so the
      receiver's fixed-size framing stays in sync.  Returns the total number
      of bytes sent (0 for an empty message) or -1 on failure.
      """
      totalTransmitted = 0
      try:
         encodedMessage = self.__encodeMessage(message)
         if fixedBufferSize > 0:
            if len(encodedMessage) > fixedBufferSize:
               self.logger.log(logging.ERROR,getLogMessage("Message exceeds fixed buffer size %d in sendBlockingMessage()" % (fixedBufferSize)))
               return(-1)
            encodedMessage = encodedMessage.ljust(fixedBufferSize,b' ')

         while totalTransmitted < len(encodedMessage):
            transmittedLength = channel.send(encodedMessage[totalTransmitted:])
            if transmittedLength == 0:
               self.logger.log(logging.ERROR,getLogMessage("socket connection broken in sendBlockingMessage()"))
               return(-1)
            totalTransmitted += transmittedLength
      except Exception as e:
         # SSL.Error is an Exception subclass and was handled identically.
         self.__logSendError('sendBlockingMessage',e,message)
         return(-1)

      return(totalTransmitted)


   def sendMessage(self,
                   channel,
                   message,
                   fixedBufferSize=0):
      """Send message (text or bytes) on channel, dispatching on its blocking mode.

      Returns the number of bytes sent, 0 when channel is falsy or nothing
      could be sent, or -1 on failure.  fixedBufferSize only applies to
      blocking channels.
      """
      if not channel:
         return(0)

      if self.isBlockingChannel(channel):
         return(self.__sendBlockingMessage(channel,message,fixedBufferSize))
      return(self.__sendNonBlockingMessage(channel,message))


   def __acceptHandshake(self,
                         channel,
                         bufferSize):
      """Server half of the handshake.

      Sends "SUBMIT <bufferSize>" padded to 32 characters and expects a reply
      of the same length starting with "SUBMIT ".  A reply ending in '1' puts
      the channel in blocking mode, anything else in non-blocking mode.
      Returns True when the handshake succeeded.
      """
      if bufferSize is None:
         self.logger.log(logging.ERROR,getLogMessage("Connection acceptHandshake failed.  No buffer size configured."))
         return(False)

      valid = False
      message = ("SUBMIT %d" % (bufferSize)).ljust(32,' ')
      reply = ""

      try:
         nSent = self.sendMessage(channel,message)
         if nSent > 0:
            # Expect a reply of the same length.
            reply = self.receiveMessage(channel,nSent,nSent)
            if not reply:
               self.logger.log(logging.ERROR,getLogMessage("Connection acceptHandshake failed.  No response."))
            elif len(reply) == len(message) and reply[0:7] == message[0:7]:
               channel.setblocking(1 if reply.strip()[-1] == '1' else 0)
               valid = True
      except Exception:
         self.logger.log(logging.ERROR,getLogMessage("Connection acceptHandshake failed.  Protocol mismatch?"))
         self.logger.log(logging.ERROR,getLogMessage("acceptHandshake(%s): %s" % (message.strip(),str(reply).strip())))
         self.logger.log(logging.ERROR,getLogMessage(traceback.format_exc()))

      return(valid)


   def acceptHandshake(self,
                       channel):
      return(self.__acceptHandshake(channel,self.defaultBufferSize))


   def __initiateHandshake(self,
                           channel,
                           blocking):
      """Client half of the handshake.

      Sends "SUBMIT 1" (blocking) or "SUBMIT 0" padded to 32 characters and
      expects "SUBMIT <bufferSize>" back.  On success the channel's blocking
      mode is set and defaultBufferSize takes the server's value.  Returns
      True only if the reply was complete and its buffer size parsed.
      """
      valid = False
      message = ("SUBMIT 1" if blocking else "SUBMIT 0").ljust(32,' ')
      reply = ""

      try:
         nSent = self.sendMessage(channel,message)
         if nSent > 0:
            # Expect a reply of the same length.
            reply = self.receiveMessage(channel,nSent,nSent)
            if not reply:
               self.logger.log(logging.ERROR,getLogMessage("Connection initiateHandshake failed.  No response."))
            elif len(reply) == len(message) and reply[0:7] == message[0:7]:
               # Parse first so a malformed reply cannot leave valid == True.
               serverBufferSize = int(reply.split()[-1])
               channel.setblocking(1 if blocking else 0)
               self.defaultBufferSize = serverBufferSize
               valid = True
      except Exception:
         self.logger.log(logging.ERROR,getLogMessage("Connection initiateHandshake failed.  Protocol mismatch?"))
         self.logger.log(logging.ERROR,getLogMessage("initiateHandshake(%s): %s" % (message.strip(),str(reply).strip())))
         self.logger.log(logging.ERROR,getLogMessage(traceback.format_exc()))

      return(valid)


   def initiateHandshake(self,
                         channel,
                         blocking=False):
      return(self.__initiateHandshake(channel,blocking))

