#!/usr/bin/python2.6

"""
Pegasus utility for checking file integrity after transfers

Usage: pegasus-integrity-check [options]
"""

##
#  Copyright 2007-2017 University Of Southern California
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#  http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing,
#  software distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
##

import cmd
import errno
import glob
import json
import logging
import optparse
import os
import re
import pprint
import stat
import string
from subprocess import STDOUT
import subprocess
import sys
import time
import traceback
import threading
import tempfile


__author__ = "Mats Rynge <rynge@isi.edu>"



# --- classes -----------------------------------------------------------------

class Singleton(type):
    """Metaclass implementing the singleton pattern.

    The first call of a class using this metaclass creates the instance and
    stores it in the shared ``_instances`` registry; every later call returns
    that stored instance. A ``threading.Lock`` is attached to the class as
    ``lock`` right after the instance has been created.
    """
    _instances = {}

    def __call__(cls, *args, **kwargs):
        registry = cls._instances
        if cls in registry:
            # already instantiated - hand out the existing object
            return registry[cls]

        registry[cls] = super(Singleton, cls).__call__(*args, **kwargs)
        cls.lock = threading.Lock()
        return registry[cls]


class Tools(object):
    """Singleton for detecting and maintaining tools we depend on

    Detection results are cached per executable name in the class level
    ``_info`` dictionary, so every tool is only probed once per process.
    All methods serialize access through ``self.lock``, which the
    Singleton metaclass attaches to the class.
    """

    __metaclass__ = Singleton

    # executable name -> detection record (dict with the keys 'full_path',
    # 'version', 'version_major', 'version_minor', 'version_patch'), or
    # None when the executable was searched for but not found
    _info = {}

    def find(self, executable, version_arg=None, version_regex=None, path_prepend=None):
        """
        Locates an executable and, optionally, detects its version.

        executable    - name of the command to look for
        version_arg   - argument making the command print its version;
                        required when version_regex is given
        version_regex - regular expression whose first group captures the
                        version in the command output; None skips the
                        version detection
        path_prepend  - optional list of directories to search before
                        $PATH; every entry is inserted at the front in
                        turn, so the last entry is searched first

        Returns the full path to the executable, or None if it was not
        found. The result is cached, and repeated calls return the same
        value as the first call.
        """
        with self.lock:
            if executable in self._info:
                cached = self._info[executable]
                if cached is None:
                    return None
                # same return value as on the initial detection
                return cached['full_path']

            logger.debug("Trying to detect availability/location of tool: %s"
                         % (executable))

            # figure out the directories to search; a missing PATH and
            # empty entries are tolerated
            search_dirs = [entry for entry in
                           os.environ.get("PATH", "").split(":")
                           if entry != ""]
            if path_prepend is not None:
                for entry in path_prepend:
                    search_dirs.insert(0, entry)

            # now walk the path
            full_path = None
            for entry in search_dirs:
                candidate = entry + "/" + executable
                if os.path.isfile(candidate) and os.access(candidate, os.X_OK):
                    full_path = candidate
                    break

            if full_path is None:
                logger.debug("Command '%s' not found in the current environment"
                             % (executable))
                self._info[executable] = None
                return None

            # The record is built locally and only cached once complete, so
            # a failure during the version detection does not leave a half
            # filled entry behind.
            info = {'full_path': full_path,
                    'version': None,
                    'version_major': None,
                    'version_minor': None,
                    'version_patch': 0}

            # version
            if version_regex is None:
                version = "N/A"
            else:
                version = backticks(executable + " " + version_arg + " 2>&1")
                version = version.replace('\n', "")
                result = re.search(version_regex, version)
                if result:
                    version = result.group(1)
                info['version'] = version

            # if possible, break up version into major, minor, patch
            result = re.search(r"([0-9]+)\.([0-9]+)(\.([0-9]+)){0,1}", version)
            if result:
                info['version_major'] = int(result.group(1))
                info['version_minor'] = int(result.group(2))
                if result.group(4):
                    info['version_patch'] = int(result.group(4))

            self._info[executable] = info

            logger.debug("Tool found: %s   Version: %s   Path: %s"
                         % (executable, version, full_path))
            return full_path

    def full_path(self, executable):
        """ Returns the full path to a given executable, or None if the
        executable has not been searched for or was not found """
        with self.lock:
            info = self._info.get(executable)
            if info is None:
                return None
            return info['full_path']

    def major_version(self, executable):
        """ Returns the detected major version given executable, or None
        if the executable or its version is not known """
        with self.lock:
            info = self._info.get(executable)
            if info is None:
                return None
            return info['version_major']

    def version_comparable(self, executable):
        """ Returns the detected comparable version given executable

        The result is a zero padded string of the form MMMmmmppp (major,
        minor, patch) which can be compared with string comparison. None
        is returned if the executable is not known or if no version
        could be parsed for it.
        """
        with self.lock:
            info = self._info.get(executable)
            if info is None:
                return None
            if info['version_major'] is None or info['version_minor'] is None:
                # no version detected - nothing meaningful to compare
                return None
            return "%03d%03d%03d" % (int(info['version_major']),
                                     int(info['version_minor']),
                                     int(info['version_patch']))


class TimedCommand(object):
    """ Provides a shell callout with a timeout

    The command line is executed through the shell in its own process
    group, with stdout and stderr combined into a temporary file. The
    command runs in a helper thread so that the caller can give up
    waiting after timeout_secs seconds.
    """

    def __init__(self, cmd, env_overrides=None, timeout_secs=6*60*60, log_cmd=True, log_outerr=True):
        """
        cmd           - shell command line to execute
        env_overrides - optional dict of environment variables to set or
                        replace for the command; the dict is copied
        timeout_secs  - seconds to wait before the command is killed
        log_cmd       - log the command line before running it
        log_outerr    - log the combined stdout/stderr after the run
        """
        self._cmd = cmd
        self._env_overrides = dict(env_overrides) if env_overrides else {}
        self._timeout_secs = timeout_secs
        self._log_cmd = log_cmd
        self._log_outerr = log_outerr
        self._process = None
        self._out_file = None
        self._outerr = ""

        # used in exceptions
        self._cmd_for_exc = cmd

    def run(self):
        """
        Runs the command and waits for it to finish or time out.

        Raises RuntimeError if the command can not be started, times out
        or exits with a non-zero exit code.
        """
        # filled in by the helper thread if the process can not be started
        start_errors = []

        def target():
            # custom environment for the subshell
            sub_env = os.environ.copy()
            for key, value in self._env_overrides.items():
                logger.debug("ENV override: %s = %s" % (key, value))
                sub_env[key] = value

            try:
                self._process = subprocess.Popen(self._cmd, shell=True, env=sub_env,
                                                 stdout=self._out_file,
                                                 stderr=subprocess.STDOUT,
                                                 preexec_fn=os.setpgrp)
            except Exception as err:
                # an exception would be lost in the thread - hand it over
                start_errors.append(err)
                return
            self._process.communicate()

        # forget the process of an earlier run
        self._process = None

        if self._log_cmd or logger.isEnabledFor(logging.DEBUG):
            logger.info(self._cmd)
            # provide a short version in exceptions
            self._cmd_for_exc = re.sub(' .*', ' ...', self._cmd)

        sys.stdout.flush()

        # temp file for the stdout/stderr
        self._out_file = tempfile.TemporaryFile(prefix="pegasus-transfer-", suffix=".out")
        try:
            thread = threading.Thread(target=target)
            thread.start()

            thread.join(self._timeout_secs)
            if thread.is_alive():
                self._kill_process_group()
                thread.join()
                # log the output
                output = self._read_output()
                if len(output) > 0:
                    logger.info(output)
                raise RuntimeError("Command timed out after %d seconds: %s"
                                   % (self._timeout_secs, self._cmd_for_exc))

            # log the output
            self._outerr = self._read_output()
            if self._log_outerr and len(self._outerr) > 0:
                logger.info(self._outerr)
        finally:
            # the temp file is released on every path, including errors
            self._out_file.close()

        if start_errors:
            raise RuntimeError("Command could not be started (%s): %s"
                               % (start_errors[0], self._cmd_for_exc))

        if self._process.returncode != 0:
            raise RuntimeError("Command exited with non-zero exit code (%d): %s"
                               % (self._process.returncode, self._cmd_for_exc))

    def _kill_process_group(self):
        """
        best effort termination of the command and its process group
        """
        try:
            # os.killpg did not work in all environments
            kill_cmd = "kill -TERM -%d" % (os.getpgid(self._process.pid))
            kp = subprocess.Popen(kill_cmd, shell=True)
            kp.communicate()
            self._process.terminate()
        except Exception as err:
            # the process may already be gone, or may never have started
            logger.debug("Unable to kill timed out command: %s" % (err))

    def _read_output(self):
        """
        returns the stripped contents of the stdout/stderr temp file
        """
        self._out_file.seek(0)
        data = self._out_file.read()
        if not isinstance(data, str):
            # the temp file is binary - only decode when bytes is not str
            data = data.decode("utf-8", "replace")
        return data.strip()

    def get_outerr(self):
        """
        returns the combined stdout and stderr from the command
        """
        return self._outerr

    def get_exit_code(self):
        """
        returns the exit code from the process, or None if the command
        has not been run or could not be started
        """
        if self._process is None:
            return None
        return self._process.returncode

    

# --- global variables ----------------------------------------------------------------

# resolved directory of this program, and the name it was invoked as
prog_dir = os.path.realpath(os.path.dirname(sys.argv[0]))
prog_base = os.path.basename(sys.argv[0])

# logger shared by all the functions and classes in this file
logger = logging.getLogger("my_logger")

# --- functions ----------------------------------------------------------------


def setup_logger(debug_flag):
    """
    Attaches a console handler to the module logger. Both the logger and
    the handler log at DEBUG level if debug_flag is set, and at INFO level
    otherwise.
    """
    level = logging.DEBUG if debug_flag else logging.INFO

    # log to the console, with a prefix identifying this tool
    handler = logging.StreamHandler()
    handler.setLevel(level)
    handler.setFormatter(logging.Formatter("Integrity check: %(message)s"))

    # make logger/console match
    logger.setLevel(level)
    logger.addHandler(handler)
    logger.debug("Logger has been configured")


def backticks(cmd_line):
    """
    Runs the given command line through the shell and returns what it
    wrote to stdout, like backticks do in perl. stderr is not captured.
    """
    child = subprocess.Popen(cmd_line, shell=True, stdout=subprocess.PIPE)
    captured, _ = child.communicate()
    return captured


def json_object_decoder(obj):
    """
    utility function used by json.load() to parse some known objects into
    equivalent Python objects

    Objects with a 'type' of 'transfer', 'mkdir' or 'remove' are converted
    into Transfer, Mkdir and Remove instances. Anything else is returned
    unchanged.

    NOTE(review): Transfer, Mkdir and Remove are not defined in the part of
    the file visible here - confirm they exist before relying on this hook.
    """
    kind = None
    if 'type' in obj:
        kind = obj['type']

    if kind == 'transfer':
        transfer = Transfer()
        # sources
        for src in obj['src_urls']:
            src_priority = int(src['priority']) if 'priority' in src else None
            transfer.add_src(src['site_label'], src['url'], src_priority)
        # destinations
        for dst in obj['dest_urls']:
            dst_priority = int(dst['priority']) if 'priority' in dst else None
            transfer.add_dst(dst['site_label'], dst['url'], dst_priority)
        return transfer

    if kind == 'mkdir':
        mkdir = Mkdir()
        mkdir.set_url(obj['target']['site_label'], obj['target']['url'])
        return mkdir

    if kind == 'remove':
        remove = Remove()
        remove.set_url(obj['target']['site_label'], obj['target']['url'])
        if 'recursive' in obj['target']:
            remove.set_recursive(obj['target']['recursive'])
        return remove

    # not one of the known types
    return obj


def read_meta_data(f):
    """
    Reads meta data entries from a JSON formatted file

    f - path of the file to read

    Returns the decoded JSON document. If the file can not be opened or
    parsed, the problem is logged and an empty list is returned.
    """
    data = []
    try:
        # the with statement closes the file even if the parsing fails
        with open(f, "r") as fp:
            data = json.load(fp)
    except Exception as err:
        logger.critical('Error parsing the meta data: ' + str(err))
    return data


def myexit(rc):
    """
    Exits the program with the given exit code by raising SystemExit
    """
    sys.exit(rc)


def generate_sha256(fname):
    """
    Generates a sha256 hash for the given file

    The hash is calculated with openssl if available, otherwise with
    sha256sum.

    Returns the hash as a 64 character hex string, or None if no tool is
    available, the command fails, or the output does not end in a valid
    sha256 hash.
    """

    tools = Tools()
    tools.find("openssl", "version", r"([0-9]+\.[0-9]+\.[0-9]+)")
    tools.find("sha256sum", "--version", r"([0-9]+\.[0-9]+)")

    # The command is run through the shell: single quote the file name so
    # that spaces and shell meta characters are passed on literally.
    quoted_fname = "'" + fname.replace("'", "'\"'\"'") + "'"

    if tools.full_path("openssl"):
        cmd = tools.full_path("openssl")
        openssl_version = tools.version_comparable("openssl")
        # an unknown version is treated like a version >= 1.1.0
        if openssl_version is not None and openssl_version < "001001000":
            cmd += " sha -sha256"
        else:
            cmd += " sha256"
        cmd += " " + quoted_fname
    elif tools.full_path("sha256sum"):
        cmd = tools.full_path("sha256sum")
        cmd += " " + quoted_fname + " | sed 's/ .*//'"
    else:
        logger.error("openssl or sha256sum not found!")
        return None

    # generate the checksum - run() raises RuntimeError when the command
    # times out or exits with a non-zero exit code
    tc = TimedCommand(cmd, log_cmd=False, log_outerr=False)
    try:
        tc.run()
    except RuntimeError as err:
        details = tc.get_outerr() or str(err)
        logger.error("Unable to determine sha256: " + details)
        return None

    # the hash is the last 64 characters of the output
    sha256 = tc.get_outerr().strip()[-64:]
    if not re.match("^[0-9a-fA-F]{64}$", sha256):
        logger.warning("Unable to determine sha256 of " + fname)
        return None

    return sha256


def check_integrity(fname, meta_data):
    """
    Checks the integrity of a file given a set of metadata

    fname     - the file to check; also the "_id" to look for in meta_data
    meta_data - list of meta data entries; the expected checksum is taken
                from the "checksum.value" key in the "_attributes" of the
                entry for the file. If several entries match, the last
                one carrying a checksum is used.

    Returns True if the calculated sha256 matches the expected one, and
    False otherwise, including when the meta data has no checksum for the
    file or when the checksum can not be calculated.
    """

    # find the expected checksum in the metadata - done first as it is
    # cheap compared to hashing the file; malformed entries are skipped
    expected_sha256 = None
    for entry in meta_data:
        if not isinstance(entry, dict) or entry.get("_id") != fname:
            continue
        attributes = entry.get("_attributes")
        if not isinstance(attributes, dict):
            continue
        if attributes.get("checksum.value") is not None:
            expected_sha256 = attributes["checksum.value"]
    if expected_sha256 is None:
        logger.error("No checksum in the meta data for " + fname)
        return False

    current_sha256 = generate_sha256(fname)
    if current_sha256 is None:
        logger.error("%s: Unable to calculate the checksum" % (fname))
        return False

    # now compare them
    if current_sha256 != expected_sha256:
        logger.error("%s: Expected checksum (%s) does not match the calculated checksum (%s)"
                     % (fname, expected_sha256, current_sha256))
        return False

    logger.info("%s: Checksum verified" % (fname))

    return True


# --- main ----------------------------------------------------------------------------

def main():
    """
    Command line entry point

    Exactly one of --generate-sha256 and --verify has to be given. With
    --generate-sha256 the hash of the file is printed. With --verify the
    file is checked against the checksums in the *.meta files in the
    current working directory. The exit code is 0 on success, 1 otherwise.
    """

    # dup stderr onto stdout
    sys.stderr = sys.stdout

    # Configure command line option parser
    prog_usage = "usage: %s [options]" % (prog_base)
    parser = optparse.OptionParser(usage=prog_usage)

    parser.add_option("", "--generate-sha256", action="store", dest="generate_file",
                      help="Generate a SHA256 hash for the given file.")
    parser.add_option("", "--verify", action="store", dest="verify_file",
                      help="Verify the hash for the given file.")
    parser.add_option("", "--debug", action="store_true", dest="debug",
                      help="Enables debugging output.")

    # Parse command line options
    (options, args) = parser.parse_args()
    setup_logger(options.debug)

    # sanity checks - the same "is not None" test also selects the mode
    # below, so an option that was given is never silently ignored
    generate = options.generate_file is not None
    verify = options.verify_file is not None
    if generate == verify:
        logger.error("One, and only one, of --generate-sha256 and --verify needs to be specified")
        parser.print_help()
        myexit(1)

    if generate:
        results = generate_sha256(options.generate_file)
        if not results:
            myexit(1)
        print(results)

    if verify:
        # read all the .meta files in the current working dir, in sorted
        # order so that the result does not depend on the directory order
        meta_data = []
        for meta_file in sorted(glob.glob("*.meta")):
            logger.debug("Loading metadata from %s" % (meta_file))
            meta_data.extend(read_meta_data(meta_file))

        if options.debug:
            pprint.PrettyPrinter(indent=4).pprint(meta_data)

        # now check the file
        results = check_integrity(options.verify_file, meta_data)
        if not results:
            myexit(1)

    myexit(0)


# only run the command line tool when executed as a script, not on import
if __name__ == "__main__":
    main()
    

