# @package      hubzero-submit-common
# @file         MessageCore.py
# @author       Steven Clark <clarks@purdue.edu>
# @copyright    Copyright (c) 2012-2014 HUBzero Foundation, LLC.
# @license      http://www.gnu.org/licenses/lgpl-3.0.html LGPLv3
#
# Copyright (c) 2012-2014 HUBzero Foundation, LLC.
#
# This file is part of: The HUBzero(R) Platform for Scientific Collaboration
#
# The HUBzero(R) Platform for Scientific Collaboration (HUBzero) is free
# software: you can redistribute it and/or modify it under the terms of
# the GNU Lesser General Public License as published by the Free Software
# Foundation, either version 3 of the License, or (at your option) any
# later version.
#
# HUBzero is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# HUBzero is a registered trademark of HUBzero Foundation, LLC.
#

import sys
import os
import socket
import time
import traceback
import logging
from errno import EAGAIN,EWOULDBLOCK,EPIPE,ECONNRESET,ECONNREFUSED

from OpenSSL import SSL

from hubzero.submit.LogMessage import getLogIDMessage as getLogMessage

class MessageCore:
   """Stream-socket endpoint used to exchange submit messages.

   The constructor arguments select one of two roles:

   * server - bindPort > 0 selects AF_INET, otherwise a non-empty bindFile
     selects AF_UNIX; a listening socket is kept in self.bindSocket and
     connections are taken with acceptConnection().
   * client - listenerPort > 0 (AF_INET) or listenerFile (AF_UNIX) select
     the endpoint used by openListenerChannel() and the
     requestMessage*Response() methods.

   When self.protocol is 'tls' sockets are wrapped in OpenSSL
   SSL.Connection objects built from self.sslContext.
   """
   def __init__(self,
                protocol='tcp',
                bindHost="",
                bindPort=0,
                bindFile="",
                bindLabel="",
                sslKeyPath="",
                sslCertPath="",
                sslCACertPath="",
                reuseAddress=False,
                blocking=True,
                listenerHost="",
                listenerPort=0,
                listenerFile="",
                repeatDelay=5):
      """Create the endpoint and, in server mode, bind and listen.

      protocol      - transport label; 'tls' wraps sockets with SSL.  It is
                      forced to 'file' when bindFile selects an AF_UNIX server.
      bindHost, bindPort, bindFile
                    - server address (AF_INET when bindPort > 0, otherwise
                      AF_UNIX when bindFile is non-empty).
      bindLabel     - name shown in connection log messages.
      sslKeyPath, sslCertPath, sslCACertPath
                    - server side: all three must be readable for an SSL
                      context to be created.  Client side: only the CA file
                      is used and peer certificates are checked by
                      __verifyCert().
      reuseAddress  - set SO_REUSEADDR on the listening socket.
      blocking      - when False the listening socket is made non-blocking.
      listenerHost, listenerPort, listenerFile
                    - client side: endpoint to connect to.  listenerHost is
                      also the commonName required of the server certificate.
      repeatDelay   - seconds to wait between bind/connect attempts.

      Binding is attempted up to 10 times; if every attempt fails the socket
      is closed, the error is logged and isBound() returns False.
      """
      self.logger       = logging.getLogger(__name__)
      self.protocol     = protocol
      self.family       = None
      self.bindHost     = bindHost
      self.bindPort     = bindPort
      self.bindFile     = bindFile
      self.bindLabel    = bindLabel
      self.bindSocket   = None
      self.listenerHost = listenerHost
      self.listenerPort = listenerPort
      self.listenerFile = listenerFile
      self.repeatDelay  = repeatDelay
      self.sslContext   = None

      if   bindPort > 0:
         self.family = socket.AF_INET
      elif bindFile:
         self.family   = socket.AF_UNIX
         self.protocol = 'file'

      if self.family:
         # Server side.
         sslFiles = (sslKeyPath,sslCertPath,sslCACertPath)
         if all(sslFiles):
            if all(os.access(sslFile,os.R_OK) for sslFile in sslFiles):
               self.sslContext = SSL.Context(SSL.TLSv1_METHOD)
               self.sslContext.use_privatekey_file(sslKeyPath)
               self.sslContext.use_certificate_file(sslCertPath)
               self.sslContext.load_verify_locations(sslCACertPath)

         self.bindSocket = self.getMessageSocket(reuseAddress)

         bindAttempts = 10
         bindError    = (None,None)
         bound        = False
         for attempt in range(bindAttempts):
            try:
               if   self.family == socket.AF_INET:
                  self.bindSocket.bind((bindHost,bindPort))
               elif self.family == socket.AF_UNIX:
                  self.bindSocket.bind(bindFile)
               self.bindSocket.listen(512)
               if not blocking:
                  self.bindSocket.setblocking(0)
               bound = True
               break
            except Exception:
               # Capture the failure now; sys.exc_info() is not guaranteed to
               # describe it once this except clause has finished.
               bindError = sys.exc_info()[:2]
               if attempt < bindAttempts-1:
                  time.sleep(repeatDelay)

         if not bound:
            # Release the descriptor instead of just dropping the reference.
            self.bindSocket.close()
            self.bindSocket = None
            if   self.family == socket.AF_INET:
               message = "Can't bind to port %d: %s %s" % (bindPort,bindError[0],bindError[1])
               self.logger.log(logging.ERROR,getLogMessage(message))
            elif self.family == socket.AF_UNIX:
               message = "Can't bind to file %s: %s %s" % (bindFile,bindError[0],bindError[1])
               self.logger.log(logging.ERROR,getLogMessage(message))
      else:
         # Client side.
         if   listenerPort > 0:
            self.family = socket.AF_INET
         elif listenerFile:
            self.family = socket.AF_UNIX

         if sslCACertPath:
            if os.access(sslCACertPath,os.R_OK):
               self.sslContext = SSL.Context(SSL.TLSv1_METHOD)
               self.sslContext.set_verify(SSL.VERIFY_PEER,self.__verifyCert) # Demand a certificate
               self.sslContext.load_verify_locations(sslCACertPath)


   def __verifyCert(self,
                    conn,
                    cert,
                    errnum,
                    depth,
                    ok):
      """SSL verify callback for client connections.

      Any certificate OpenSSL itself rejected (errnum != 0) fails.  Chain
      certificates above the leaf (depth != 0, e.g. the CA) are accepted;
      the leaf (depth 0) must additionally carry a commonName equal to
      self.listenerHost.  The ok argument is not consulted.
      """
      if errnum != 0:
         return(False)
      if depth != 0:
         return(True)
      return(cert.get_subject().commonName == self.listenerHost)


   def getProtocol(self):
      """Return the protocol label ('tcp', 'tls', 'file', ...)."""
      return(self.protocol)


   def getMessageSocket(self,
                        reuseAddress=False):
      """Create a new, unconnected stream socket for self.family.

      The socket is wrapped in SSL.Connection(self.sslContext,...) when
      self.protocol is 'tls'.
      """
      sock = socket.socket(self.family,socket.SOCK_STREAM)
      if reuseAddress:
         sock.setsockopt(socket.SOL_SOCKET,socket.SO_REUSEADDR,1)

      if self.protocol == 'tls':
         messageSocket = SSL.Connection(self.sslContext,sock)
      else:
         messageSocket = sock

      return(messageSocket)


   def __connectToListener(self,
                           channel):
      # Connect an unconnected socket to the configured listener endpoint.
      if   self.family == socket.AF_INET:
         channel.connect((self.listenerHost,self.listenerPort))
      elif self.family == socket.AF_UNIX:
         channel.connect(self.listenerFile)


   def isBound(self):
      """Return True when a listening socket is available."""
      return(self.bindSocket is not None)


   def close(self):
      """Close the listening socket, if any."""
      if self.bindSocket:
         self.bindSocket.close()
         self.bindSocket = None


   def boundSocket(self):
      """Return the listening socket (None when not bound)."""
      return(self.bindSocket)


   def boundFileDescriptor(self):
      """Return the file descriptor of the listening socket."""
      return(self.bindSocket.fileno())


   def acceptConnection(self,
                        logConnection=False,
                        determineDetails=False):
      """Accept one connection on the listening socket.

      Returns the new channel, or None when a non-blocking accept has
      nothing pending (EAGAIN/EWOULDBLOCK).  With determineDetails the
      result is (channel,remoteIP,remotePort,remoteHost); the remote fields
      are only filled in for AF_INET and stay None if the lookup fails.
      Other socket errors are logged and re-raised.
      """
      remoteIP   = None
      remotePort = None
      remoteHost = None
      try:
         channel,details = self.bindSocket.accept()
      except socket.error as err:
         if err.args and err.args[0] in (EAGAIN,EWOULDBLOCK):
            # Expected on a non-blocking socket: nothing to accept yet.
            channel = None
         else:
            self.logger.log(logging.ERROR,getLogMessage(traceback.format_exc()))
            raise
      else:
         if   self.family == socket.AF_INET:
            if determineDetails:
               try:
                  remoteIP,remotePort = details
                  remoteHost = socket.gethostbyaddr(remoteIP)[0]
                  if remoteHost:
                     remoteHost = socket.getfqdn(remoteHost)
               except Exception:
                  # Best effort: reverse DNS may legitimately fail.
                  pass
            if logConnection:
               self.logger.log(logging.INFO,getLogMessage("===================================================="))
               if determineDetails:
                  self.logger.log(logging.INFO,getLogMessage("Connection to %s from %s %s" % (self.bindLabel,details,remoteHost)))
               else:
                  self.logger.log(logging.INFO,getLogMessage("Connection to %s from %s" % (self.bindLabel,details)))
         elif self.family == socket.AF_UNIX:
            if logConnection:
               self.logger.log(logging.INFO,getLogMessage("===================================================="))
               self.logger.log(logging.INFO,getLogMessage("Connection to %s" % (self.bindLabel)))

      if determineDetails:
         return(channel,remoteIP,remotePort,remoteHost)
      return(channel)


   def openListenerChannel(self,
                           recordTraceback=False):
      """Open a connection to the configured listener.

      Returns the connected channel, or None on failure.  A refused
      connection is always logged; other failures are logged (with
      traceback) only when recordTraceback is set.  SystemExit is re-raised
      after closing the socket.
      """
      listenerChannel = None
      try:
         listenerChannel = self.getMessageSocket()
         self.__connectToListener(listenerChannel)
      except SystemExit:
         if listenerChannel is not None:
            listenerChannel.close()
         raise
      except Exception as err:
         if err.args and err.args[0] == ECONNREFUSED:
            if len(err.args) > 1:
               detail = err.args[1]
            else:
               detail = err
            self.logger.log(logging.ERROR,getLogMessage("openListenerChannel: %s" % (detail)))
         elif recordTraceback:
            self.logger.log(logging.ERROR,getLogMessage(traceback.format_exc()))
         if listenerChannel is not None:
            listenerChannel.close()
         listenerChannel = None
      except:
         # Deliberately also absorbs non-Exception errors (other than
         # SystemExit, handled above) and reports failure as None.
         if recordTraceback:
            self.logger.log(logging.ERROR,getLogMessage(traceback.format_exc()))
         if listenerChannel is not None:
            listenerChannel.close()
         listenerChannel = None

      return(listenerChannel)


   def closeListenerChannel(self,
                            listenerChannel):
      """Close a channel returned by openListenerChannel(); None is ignored."""
      if listenerChannel:
         listenerChannel.close()


   def __receiveNonBlockingMessage(self,
                                   channel,
                                   bufferSize):
      # Drain whatever is currently readable from a non-blocking channel.
      # Returns the concatenated data, or None when the peer has gone away
      # (EOF, reset, SSL system-call error) before anything was read.
      message = ""

      while True:
         try:
            messageChunk = channel.recv(bufferSize)
         except socket.error as err:
            errnum = err.args[0] if err.args else None
            if   errnum == ECONNRESET:
               # Connection reset by peer
               if not message:
                  message = None
            elif errnum not in (EAGAIN,EWOULDBLOCK):
               # EAGAIN/EWOULDBLOCK only mean there is nothing left to read.
               self.logger.log(logging.ERROR,getLogMessage("Unexpected error in receiveNonBlockingMessage()"))
               self.logger.log(logging.ERROR,getLogMessage(traceback.format_exc()))
            break
         except (SSL.WantReadError,SSL.WantWriteError,SSL.WantX509LookupError):
            # Happens on non-blocking TCP/SSL socket when there's nothing to read
            break
         except SSL.SysCallError as err:
            errnum = err.args[0] if err.args else None
            if errnum not in (-1,ECONNRESET):
               # -1 is an unexpected EOF and ECONNRESET a reset; anything
               # else is worth reporting.
               self.logger.log(logging.ERROR,getLogMessage("Unexpected error in receiveNonBlockingMessage() " + message))
               self.logger.log(logging.ERROR,getLogMessage(traceback.format_exc()))
            # The connection is unusable either way; stop rather than
            # retrying recv() indefinitely.
            if not message:
               message = None
            break
         except Exception:
            self.logger.log(logging.ERROR,getLogMessage("Unexpected error in receiveNonBlockingMessage() " + message))
            self.logger.log(logging.ERROR,getLogMessage(traceback.format_exc()))
            if not message:
               message = None
            break
         else:
            if messageChunk:
               message += messageChunk
            else:
               # Orderly EOF from the peer.
               if not message:
                  message = None
               break

      return(message)


   def __receiveBlockingMessage(self,
                                channel,
                                messageLength,
                                bufferSize):
      # Read messageLength bytes from a blocking channel, at most bufferSize
      # per recv().  If the peer closes early the partial data is returned
      # (and logged); None is returned when nothing at all was received.
      bytesRemaining = messageLength
      message = ""

      try:
         while bytesRemaining > 0:
            # Never request more than is still owed, so bytes the peer sends
            # after this message are left in the socket for the next read.
            messageChunk = channel.recv(min(bufferSize,bytesRemaining))
            if not messageChunk:
               if message:
                  self.logger.log(logging.ERROR,getLogMessage("socket connection broken in receiveBlockingMessage()"))
               else:
                  message = None
               break
            message += messageChunk
            bytesRemaining -= len(messageChunk)
      except Exception:
         self.logger.log(logging.ERROR,getLogMessage("Unexpected error in receiveBlockingMessage() " + message))
         self.logger.log(logging.ERROR,getLogMessage(traceback.format_exc()))
         if message == "":
            message = None

      return(message)


   def receiveMessage(self,
                      channel,
                      messageLength,
                      bufferSize=128):
      """Receive a message from channel.

      A channel without a timeout (gettimeout() is None) is treated as
      blocking and read until messageLength bytes arrive; otherwise whatever
      is currently available is drained and messageLength is ignored.
      Returns the data, or None when the peer closed without sending any.
      """
      if channel.gettimeout() is None:
         return(self.__receiveBlockingMessage(channel,messageLength,bufferSize))
      return(self.__receiveNonBlockingMessage(channel,bufferSize))


   def __logSendError(self,
                      functionName,
                      message,
                      err):
      # Shared error report for the send paths; must be called from inside
      # the except clause so format_exc() sees the active exception.
      if err.args and err.args[0] == EPIPE:
         self.logger.log(logging.ERROR,getLogMessage("%s: Broken pipe" % (functionName)))
      else:
         self.logger.log(logging.ERROR,getLogMessage("Unexpected error in %s(%s)" % (functionName,message[0:20])))
         self.logger.log(logging.ERROR,getLogMessage(traceback.format_exc()))


   def __sendNonBlockingMessage(self,
                                channel,
                                message):
      # Single send() attempt.  Returns the byte count accepted, 0 when the
      # SSL layer cannot make progress right now, -1 on error.
      try:
         transmittedLength = channel.send(message)
         if transmittedLength == 0:
            self.logger.log(logging.ERROR,getLogMessage("socket connection broken in sendNonBlockingMessage()"))
            transmittedLength = -1
      except (SSL.WantReadError,SSL.WantWriteError,SSL.WantX509LookupError):
         transmittedLength = 0
      except Exception as err:
         self.__logSendError("sendNonBlockingMessage",message,err)
         transmittedLength = -1

      return(transmittedLength)


   def __sendBlockingMessage(self,
                             channel,
                             message,
                             fixedBufferSize):
      # Send the whole message.  With fixedBufferSize > 0 exactly that many
      # bytes are sent: the message is space-padded, and an over-long
      # message is truncated (with a warning) because the receiver reads a
      # fixed-size record.  Returns the length of the last send() (> 0 on
      # success), 0 for an empty message, -1 on error.
      if fixedBufferSize > 0:
         payload = "%-*s" % (fixedBufferSize,message)
         if len(payload) > fixedBufferSize:
            self.logger.log(logging.WARNING,getLogMessage("sendBlockingMessage: message truncated to %d bytes" % (fixedBufferSize)))
            payload = payload[:fixedBufferSize]
      else:
         payload = message

      transmittedLength = 0
      try:
         bytesSent = 0
         while bytesSent < len(payload):
            transmittedLength = channel.send(payload[bytesSent:])
            if transmittedLength == 0:
               self.logger.log(logging.ERROR,getLogMessage("socket connection broken in sendBlockingMessage()"))
               transmittedLength = -1
               break
            bytesSent += transmittedLength
      except SSL.Error as err:
         # OpenSSL errors normally carry [(lib,function,reason),...], but
         # subclasses such as SysCallError carry (errno,string) instead.
         try:
            (lib,function,reason) = err.args[0][0]
         except (IndexError,TypeError,ValueError):
            reason = err
         self.logger.log(logging.ERROR,getLogMessage("sendBlockingMessage: SSL error, %s" % (reason)))
         transmittedLength = -1
      except Exception as err:
         self.__logSendError("sendBlockingMessage",message,err)
         transmittedLength = -1

      return(transmittedLength)


   def sendMessage(self,
                   channel,
                   message,
                   fixedBufferSize=0):
      """Send message on channel.

      Channels without a timeout are written completely (padded/truncated
      to fixedBufferSize when it is > 0); otherwise one non-blocking send()
      is attempted.  Returns the length of the last send(), 0 when channel
      is falsy or nothing could be sent, and -1 on error.
      """
      if not channel:
         return(0)
      if channel.gettimeout() is None:
         return(self.__sendBlockingMessage(channel,message,fixedBufferSize))
      return(self.__sendNonBlockingMessage(channel,message))


   def __postRequest(self,
                     message,
                     messageBufferSize,
                     readResponse,
                     recordTraceback):
      # Retry loop shared by the requestMessage*Response() methods.  Each
      # attempt opens a fresh connection, sends message (fixed size
      # messageBufferSize) and calls readResponse(channel), which returns
      # (posted,result).  Attempts that fail or raise are repeated every
      # self.repeatDelay seconds with no upper limit.
      # Returns (number of attempts,result).
      nTry  = 0
      delay = 0
      while True:
         time.sleep(delay)
         delay = self.repeatDelay
         nTry += 1
         channel = None
         try:
            channel = self.getMessageSocket()
            self.__connectToListener(channel)
            if self.__sendBlockingMessage(channel,message,messageBufferSize) > 0:
               posted,result = readResponse(channel)
               if posted:
                  return(nTry,result)
         except Exception:
            if recordTraceback:
               self.logger.log(logging.ERROR,getLogMessage(traceback.format_exc()))
         finally:
            if channel is not None:
               channel.close()


   def requestMessageResponse(self,
                              message,
                              messageBufferSize,
                              responseBufferSize,
                              recordTraceback=False):
      """Send message and read a fixed-size reply of responseBufferSize bytes.

      Retries until a reply is obtained.  Returns (nTry,response).
      NOTE(review): __receiveBlockingMessage reports "nothing received" as
      None rather than "", so a peer that closes without replying also ends
      the retry loop, with response None.
      """
      def readResponse(channel):
         response = self.__receiveBlockingMessage(channel,responseBufferSize,128)
         return(response != "",response)

      return(self.__postRequest(message,messageBufferSize,readResponse,recordTraceback))


   def requestMessageVariableResponse(self,
                                      message,
                                      messageBufferSize,
                                      responseBufferSize,
                                      recordTraceback=False):
      """Send message and read a length-prefixed reply.

      The reply starts with a responseBufferSize-byte header holding the
      body length; a length of 0 yields "".  Retries until a reply is
      obtained.  Returns (nTry,response).
      """
      def readResponse(channel):
         responseHeader = self.__receiveBlockingMessage(channel,responseBufferSize,responseBufferSize)
         if responseHeader == "":
            return(False,"")
         responseLength = int(responseHeader.strip())
         if responseLength > 0:
            response = self.__receiveBlockingMessage(channel,responseLength,responseLength)
            return(response != "",response)
         return(True,"")

      return(self.__postRequest(message,messageBufferSize,readResponse,recordTraceback))


   def requestMessageTimestampResponse(self,
                                       message,
                                       messageBufferSize,
                                       responseBufferSize,
                                       recordTraceback=False):
      """Send message and read a reply prefixed with "length timestamp".

      The responseBufferSize-byte header holds the body length and a
      timestamp token separated by whitespace; a length of 0 yields "".
      Retries until a reply is obtained.
      Returns (nTry,response,responseTimestamp).
      """
      def readResponse(channel):
         responseHeader = self.__receiveBlockingMessage(channel,responseBufferSize,responseBufferSize)
         if responseHeader == "":
            return(False,("",None))
         responseLength,responseTimestamp = responseHeader.strip().split()
         responseLength = int(responseLength)
         if responseLength > 0:
            response = self.__receiveBlockingMessage(channel,responseLength,responseLength)
            return(response != "",(response,responseTimestamp))
         return(True,("",responseTimestamp))

      nTry,(response,responseTimestamp) = self.__postRequest(message,messageBufferSize,readResponse,recordTraceback)
      return(nTry,response,responseTimestamp)


