# @package      hubzero-submit-distributor
# @file         HarvestLocal.py
# @copyright    Copyright (c) 2012-2020 The Regents of the University of California.
# @license      http://opensource.org/licenses/MIT MIT
#
# Copyright (c) 2012-2020 The Regents of the University of California.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
# HUBzero is a registered trademark of The Regents of the University of California.
#
import sys
import os
import re
import shutil
import time
import datetime
import logging

from hubzero.submit.LogMessage   import getLogIDMessage as getLogMessage
from hubzero.submit.JobStatistic import JobStatistic
from hubzero.submit.JobOutput    import JobOutput
from hubzero.submit.HarvestCore  import HarvestCore

class HarvestLocal(HarvestCore):
   """Harvest the results of a job submitted through the 'local' venue mechanism.

   Extends HarvestCore with the local-venue specific steps: remote
   post-processing (PEGASUS only), retrieval/extraction of staged-out result
   tarballs, and cleanup of work/scratch files and identities.
   """

   def __init__(self,
                remoteMonitors,
                hubUserName,
                hubUserId,
                currentWorkingDirectory,
                batchCommands,
                isParametric,
                runName,
                localJobId,
                instanceId,
                harvestInfo,
                siteInfo,
                timeHistoryLogs):
      """Capture job/site settings and prepare the instance and scratch directories.

      batchCommands must provide the script names 'postProcessJob',
      'transmitResults' and 'cleanupJob'; they are resolved relative to
      siteInfo['remoteBinDirectory'].  instanceId is either an instance
      number string or of the form 'WF;<nInstances>', in which case this
      object represents instance "0" of nInstances.
      """
      HarvestCore.__init__(self,harvestInfo,siteInfo,remoteMonitors,timeHistoryLogs)

      self.logger                = logging.getLogger(__name__)
      self.jobOutput             = JobOutput()
      self.venueMechanism        = 'local'
      self.hubUserName           = hubUserName
      self.hubUserId             = hubUserId
      # These paths are later concatenated into command strings given to
      # executeCommand; '$' is escaped, presumably so a shell does not expand it.
      remoteScriptsDirectory     = self._escapeShellDollar(siteInfo['remoteBinDirectory'])
      self.postProcessJob        = os.path.join(remoteScriptsDirectory,batchCommands['postProcessJob'])
      self.transmitResults       = os.path.join(remoteScriptsDirectory,batchCommands['transmitResults'])
      self.cleanupJob            = os.path.join(remoteScriptsDirectory,batchCommands['cleanupJob'])
      self.isParametric          = isParametric
      self.runName               = runName
      self.localJobId            = localJobId
      if instanceId.startswith('WF;'):
         self.instanceId         = "0"
         self.nInstances         = int(instanceId.split(';')[1])
      else:
         self.instanceId         = instanceId
         self.nInstances         = 1
      self.remoteBatchSystem     = siteInfo['remoteBatchSystem']
      self.siteMonitorDesignator = siteInfo['siteMonitorDesignator']
      self.remoteUser            = siteInfo['remoteUser']
      self.logUserRemotely       = siteInfo['logUserRemotely']
      self.sharedUserSpace       = siteInfo['sharedUserSpace']
      self.stageFiles            = siteInfo['stageFiles']
      self.sshIdentityPath       = ""
      self.removeIdentity        = False

      # jobStatistics, nCpus, event, venue and remoteJobIdNumber are
      # presumably initialised by HarvestCore.__init__ — not visible here.
      self.jobIndex = int(self.instanceId)
      if self.jobIndex not in self.jobStatistics:
         self.jobStatistics[self.jobIndex] = JobStatistic(self.nCpus)
      self.jobStatistics[self.jobIndex]['jobSubmissionMechanism'] = self.venueMechanism + self.remoteBatchSystem
      if self.event:
         self.jobStatistics[self.jobIndex]['event'] = self.event
      if self.venue:
         self.jobStatistics[self.jobIndex]['venue'] = self.venue
      if self.remoteJobIdNumber:
         self.jobStatistics[self.jobIndex]['remoteJobIdNumber'] = self.remoteJobIdNumber

      self.isBatchJob = True
      if self.instanceId != "0":
         self.instanceDirectory = os.path.join(currentWorkingDirectory,self.runName,self.instanceId)
      else:
         self.instanceDirectory = os.path.join(currentWorkingDirectory,self.runName)
      if not os.path.isdir(self.instanceDirectory):
         os.makedirs(self.instanceDirectory)

      self.createIdentityPaths()

      self.workingDirectory        = currentWorkingDirectory
      if siteInfo['remoteScratchDirectory']:
         # NOTE(review): time.mktime() interprets its argument as *local* time,
         # so passing a UTC timetuple yields an epoch shifted by the UTC offset.
         # The value is only used as a directory-name prefix here; left as-is in
         # case other tooling depends on it — TODO confirm before switching to
         # int(time.time()).
         epoch = int(time.mktime(datetime.datetime.utcnow().timetuple()))
         self.scratchDirectory     = os.path.join(siteInfo['remoteScratchDirectory'],"%s_%s" % (str(epoch),localJobId))
         if self.instanceId != "0":
            self.scratchDirectory += "_%s" % (self.instanceId)
      else:
         self.scratchDirectory     = self.instanceDirectory


   def createIdentityPaths(self):
      """Resolve the SSH identity (and x509 proxy, if unset) for this job.

      Does nothing when sshIdentityPath is already set.  A remoteUser starting
      with 'USER:' selects the 'personalPKI' identity; any other remoteUser
      selects 'communitySSH'.  Any remoteUser starting with 'USER' is replaced
      by the hub user name.  An existing x509SubmitProxy is never overwritten.
      """
      if self.sshIdentityPath:
         return

      # Decide the identity type before remoteUser is overwritten below.
      if self.remoteUser.startswith('USER:'):
         sshIdentityType = 'personalPKI'
      else:
         sshIdentityType = 'communitySSH'
      if self.remoteUser.startswith('USER'):
         self.remoteUser = self.hubUserName

      identityPaths = self.remoteIdentityManager.queryUserIdentities(self.siteInfo['identityManagers'],
                                                                     self.hubUserName)
      if sshIdentityType in identityPaths:
         self.sshIdentityPath = identityPaths[sshIdentityType]
         self.logger.log(logging.INFO,getLogMessage("IDENTITY = " + self.sshIdentityPath))
      if 'x509' in identityPaths and not self.x509SubmitProxy:
         self.x509SubmitProxy = identityPaths['x509']
         self.logger.log(logging.INFO,getLogMessage("IDENTITY = " + self.x509SubmitProxy))


   def postProcess(self):
      """Run the remote postProcessJob script for a submitted PEGASUS job.

      Only the primary instance runs it: any non-parametric job, or instance
      "0" of a parametric job.  jobPostProcessed is set only after the script
      has been run, so other batch systems leave it unchanged.
      """
      if self.jobPostProcessed:
         return
      if self.remoteBatchSystem != 'PEGASUS' or not self.jobSubmitted:
         return
      if self.isParametric and self.instanceId != "0":
         return

      if not self.sharedUserSpace:
         remoteWorkingDirectory = self._escapeShellDollar(self.workingDirectory)
      else:
         # NOTE(review): unlike the other paths, instanceDirectory is passed
         # without '$' escaping — kept as in the original; confirm intent.
         remoteWorkingDirectory = self.instanceDirectory
      remoteScratchDirectory = self._escapeShellDollar(self.scratchDirectory)
      command = " ".join([self.postProcessJob,remoteWorkingDirectory,remoteScratchDirectory,"PEGASUS"])
      self.logger.log(logging.INFO,getLogMessage("command = " + command))
      exitStatus,stdOutput,stdError = self.executeCommand(command)
      self.logger.log(logging.INFO,getLogMessage(stdOutput))
      if exitStatus:
         self.logger.log(logging.INFO,getLogMessage(stdError))

      self.jobPostProcessed = True


   def retrieveFiles(self):
      """Fetch and unpack this instance's staged-out results tarball.

      Acts only when the site has a shared user space and stages files.  If
      no candidate tarball exists and the job was submitted, the
      transmitResults script is run to produce one.  The first candidate
      tarball found is extracted into the instance directory and removed on
      success.  Failures set the job's exitCode to 12 and append the captured
      command output to the run's .stdout/.stderr files.  filesRetrieved is
      set only if the last command run (if any) exited with status 0.
      """
      exitStatus = 0

      if self.sharedUserSpace and self.stageFiles:
         outTarFile      = "%s_%s_output.tar" % (self.localJobId,self.instanceId)
         gzOutTarFile    = "%s.gz" % (outTarFile)
         parentDirectory = os.path.dirname(self.instanceDirectory)
         # Candidates in search order: instance directory, paths relative to
         # the process's current directory, then the instance's parent directory.
         stageOutTarFiles = [os.path.join(self.instanceDirectory,gzOutTarFile),
                             os.path.join(self.instanceDirectory,outTarFile),
                             gzOutTarFile,
                             outTarFile,
                             os.path.join(parentDirectory,gzOutTarFile),
                             os.path.join(parentDirectory,outTarFile)]

         fetchResults = not any(os.path.isfile(stageOutTarFile) for stageOutTarFile in stageOutTarFiles)
         if fetchResults and self.jobSubmitted:
            command = self.transmitResults + " " + self.instanceDirectory + " " + outTarFile
            self.logger.log(logging.INFO,getLogMessage("command = " + command))
            exitStatus,stdOutput,stdError = self.executeCommand(command)
            self.logger.log(logging.INFO,getLogMessage(stdOutput))
            if exitStatus:
               self.jobStatistics[self.jobIndex]['exitCode'] = 12
               # stdout has already been logged above; only stderr is kept with the run output.
               self._appendToRunStdFile('stderr',stdError)

         for stageOutTarFile in stageOutTarFiles:
            if os.path.isfile(stageOutTarFile):
               tarOptions = "xzmf" if stageOutTarFile.endswith('.gz') else "xmf"
               command = "tar " + tarOptions + " " + stageOutTarFile + \
                         " --ignore-case --exclude '*hub-proxy.*' -C " + self.instanceDirectory
               self.logger.log(logging.INFO,getLogMessage("command = " + command))
               exitStatus,stdOutput,stdError = self.executeCommand(command)
               if exitStatus == 0:
                  try:
                     os.remove(stageOutTarFile)
                  except (IOError,OSError):
                     # Best effort: a leftover tarball does not invalidate the extracted results.
                     self.logger.log(logging.WARNING,getLogMessage("%s could not be removed" % (stageOutTarFile)))
               else:
                  self.jobStatistics[self.jobIndex]['exitCode'] = 12
                  self._appendToRunStdFile('stdout',stdOutput)
                  self._appendToRunStdFile('stderr',stdError)
               # Only the first tarball found is processed.
               break

      if not exitStatus:
         self.filesRetrieved = True


   def cleanupFiles(self):
      """Remove per-job working files once; sets filesCleanedup when done.

      PEGASUS: for the primary instance of a submitted job, collect Pegasus
      time/output files and run the cleanupJob script; then delete the local
      'work' and 'InProcessResults' directories.  BOINC: for submitted jobs,
      copy instance 0's transferred-timestamp file for non-zero instances and
      run cleanupJob.  In all cases the identity file (if flagged for removal)
      and staging/log files are cleaned up.
      """
      if self.filesCleanedup:
         return

      if   self.remoteBatchSystem == 'PEGASUS':
         if self.jobSubmitted and (not self.isParametric or self.instanceId == "0"):
            remoteScratchDirectory = self._escapeShellDollar(self.scratchDirectory)
            timeResults = self.timeHistoryLogs['timeResults']
            if self.isParametric:
               self.jobOutput.getPegasusStdTimeFiles(self.instanceDirectory,
                                                     remoteScratchDirectory,
                                                     timeResults)
            else:
               self.jobOutput.processPegasusFiles(self.instanceDirectory,
                                                  remoteScratchDirectory,
                                                  timeResults)
            self._runCleanupJob()

         for subDirectoryName in ('work','InProcessResults'):
            subDirectoryPath = os.path.join(self.instanceDirectory,subDirectoryName)
            if os.path.isdir(subDirectoryPath):
               shutil.rmtree(subDirectoryPath,True)

         self.cleanupPegasusFiles()
      elif self.remoteBatchSystem == 'BOINC':
         if self.jobSubmitted:
            if self.instanceId != "0":
               # Reuse the batch-level (instance 0) transferred timestamp for this instance.
               timestampTransferredFile = self.timeHistoryLogs['timestampTransferred']
               batchTransferredTimeFile = timestampTransferredFile.replace('_'+self.instanceId,'_0')
               batchTransferredTimePath = os.path.join(self.instanceDirectory,batchTransferredTimeFile)
               if os.path.exists(batchTransferredTimePath):
                  shutil.copy2(batchTransferredTimePath,os.path.join(self.instanceDirectory,timestampTransferredFile))

            self._runCleanupJob()

      if self.removeIdentity and os.path.isfile(self.sshIdentityPath):
         try:
            os.remove(self.sshIdentityPath)
         except (IOError,OSError):
            # Best effort: failing to delete the identity must not abort cleanup.
            self.logger.log(logging.WARNING,getLogMessage("%s could not be removed" % (self.sshIdentityPath)))

      self.cleanupStageInTarFile()
      self.cleanupScriptTemplateLogFiles()

      if self.sharedUserSpace and self.stageFiles:
         self.cleanupStageOutTarFile()

      self.filesCleanedup = True


   def _runCleanupJob(self):
      """Run the cleanupJob script on the instance and scratch directories, logging its stdout."""
      remoteScratchDirectory = self._escapeShellDollar(self.scratchDirectory)
      command = " ".join([self.cleanupJob,self.instanceDirectory,remoteScratchDirectory,self.remoteBatchSystem])
      self.logger.log(logging.INFO,getLogMessage("command = " + command))
      self.logger.log(logging.INFO,getLogMessage(self.executeCommand(command)[1]))


   def _appendToRunStdFile(self,extension,text):
      """Append text to the run's '.<extension>' file in the instance directory.

      The file is '<runName>_<instanceId>.<extension>' for non-zero instances
      and '<runName>.<extension>' otherwise.  Empty text writes nothing.
      Open/write failures are logged rather than raised.
      """
      if not text:
         return
      if self.instanceId != "0":
         stdFileName = "%s_%s.%s" % (self.runName,self.instanceId,extension)
      else:
         stdFileName = "%s.%s" % (self.runName,extension)
      stdFile = os.path.join(self.instanceDirectory,stdFileName)
      try:
         with open(stdFile,'a') as fpStd:
            try:
               fpStd.write(text)
            except (IOError,OSError):
               self.logger.log(logging.ERROR,getLogMessage("%s could not be written" % (stdFile)))
      except (IOError,OSError):
         self.logger.log(logging.ERROR,getLogMessage("%s could not be opened" % (stdFile)))


   @staticmethod
   def _escapeShellDollar(path):
      """Return path with every '$' preceded by a backslash."""
      return path.replace('$',r'\$')


