#!/usr/bin/env python

"""
Pegasus utility for removing files during workflow enactment

Usage: pegasus-cleanup [options]

"""

##
#  Copyright 2007-2011 University Of Southern California
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#  http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing,
#  software distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
##

import os
import re
import sys
import errno
import logging
import optparse
import tempfile
import subprocess
import signal
import string
import stat
import time
from collections import deque
import ConfigParser

try:
    # Python 3.0 and later
    from urllib.parse import unquote
except ImportError:
    # Fall back to Python 2's urllib
    from urllib import unquote


__author__ = "Mats Rynge <rynge@isi.edu>"

# --- regular expressions -------------------------------------------------------------

# Matches a "#" annotation line of the URL list: "#", spaces, one word,
# spaces, and then a token made of word characters and dashes. That token is
# captured as group 1; the main loop uses it as the site name of the URL
# which follows.
re_parse_comment = re.compile(r'^# +[\w]+ +([\w-]+)')
# Splits a non-file URL into protocol (group 1), host part (group 2 - may be
# empty, may contain "user@" and ":port") and path (group 3 - starts with "/"
# and runs up to the first whitespace).
re_parse_url = re.compile(r'([\w]+)://([\w\.\-:@]*)(/[\S]*)')

# --- classes -------------------------------------------------------------------------

class URL:
    """
    One entry of the cleanup list: a URL split up into protocol, host and
    path, plus the name of the site the URL belongs to.
    """

    # class level defaults - set_sitename() and set_url() override these
    # on the instance
    sitename   = ""
    proto      = ""
    host       = ""
    path       = ""

    def set_sitename(self, sitename):
        """
        stores the site name in the form used for environment lookups
        """
        # the site name is used to match against shell variables, so we
        # have to replace dashes with underscores (as we do in the planner)
        self.sitename = sitename.replace("-", "_")

    def set_url(self, url):
        """
        parses the url and stores the resulting proto, host and path
        """
        self.proto, self.host, self.path = self.parse_url(url)

    def parse_url(self, url):
        """
        splits a url into a (proto, host, path) tuple

        Strings without a protocol, file: urls and symlink: urls all come
        back with the proto "file", an empty host and a path in which
        environment variable references have been expanded. For all other
        urls, runs of slashes in the path are collapsed into one.

        Raises RuntimeError if a non-file url can not be parsed.
        """

        # the url should be URL decoded to work with our shell callouts
        url = unquote(url)

        # default protocol is file:// - a string without a ":" is a plain
        # path. It is returned as is (after env var expansion): sending it
        # through the file:// handling below would strip the first
        # component of a relative path as if it was a host name
        if ":" not in url:
            logger.debug("URL without protocol (" + url + ") - assuming file://")
            return "file", "", expand_env_vars(url)

        # translate symlink URLs to file:
        if url.startswith("symlink:"):
            url = "file:" + url[len("symlink:"):]

        # file url is a special case as it can contain relative paths and
        # env vars
        if url.startswith("file:"):
            # file urls can either start with file://[\w]*/ or file: (no //)
            path = re.sub(r"^file:(//[\w\.\-:@]*)?", "", url)
            return "file", "", expand_env_vars(path)

        # other than file urls
        r = re_parse_url.search(url)
        if not r:
            raise RuntimeError("Unable to parse URL: %s" % (url))

        # Parse successful
        proto, host, path = r.group(1), r.group(2), r.group(3)

        # no double slashes in urls
        path = re.sub('//+', '/', path)

        return proto, host, path

    def url(self):
        """
        returns the full url, put together from proto, host and path
        """
        return "%s://%s%s" % (self.proto, self.host, self.path)

    def url_dirname(self):
        """
        returns the url of the directory containing the path
        """
        dn = os.path.dirname(self.path)
        return "%s://%s%s" % (self.proto, self.host, dn)


class Alarm(Exception):
    """
    raised by alarm_handler() once the SIGALRM timer goes off
    """


# --- global variables ----------------------------------------------------------------

# name this program was started as, and the directory it was started from
prog_base = os.path.basename(sys.argv[0])
prog_dir = os.path.normpath(os.path.dirname(sys.argv[0]))

logger = logging.getLogger("my_logger")

# protocol of a URL -> name of the handler which removes such URLs
tool_map = {
    'file':    'rm',
    'ftp':     'gsiftp',
    'gsiftp':  'gsiftp',
    'sshftp':  'sshftp',
    'irods':   'irods',
    's3':      's3',
    's3s':     's3',
    'gs':      'gs',
    'scp':     'scp',
    'srm':     'srm',
    'symlink': 'rm',
}

# executable name -> path and version details, filled in by check_tool()
tool_info = {}


# --- functions -----------------------------------------------------------------------


def setup_logger(debug_flag):
    """
    attaches a console handler to the logger - logger and handler both run
    at DEBUG level if debug_flag is set, and at INFO level otherwise
    """
    # one level for both the logger and the console handler
    level = logging.INFO
    if debug_flag:
        # debug - from command line
        level = logging.DEBUG

    # log to the console
    console = logging.StreamHandler()
    console.setFormatter(logging.Formatter("%(asctime)s %(levelname)7s:  %(message)s"))

    for target in (logger, console):
        target.setLevel(level)

    logger.addHandler(console)
    logger.debug("Logger has been configured")


def prog_sigint_handler(signum, frame):
    """
    signal handler - logs the signal number and exits with exit code 1
    """
    # Logger.warning() is the documented name - warn() is a deprecated alias
    logger.warning("Exiting due to signal %d" % (signum))
    myexit(1)


def alarm_handler(signum, frame):
    """
    signal handler - turns the signal into an Alarm exception
    """
    raise Alarm()


def expand_env_vars(s):
    """
    returns s with all $NAME and ${NAME} references replaced by whatever
    get_env_var() returns for NAME
    """
    return re.sub(r'\${?([a-zA-Z0-9_]+)}?', get_env_var, s)


def get_env_var(match):
    """
    replacement callback for expand_env_vars() - returns the value of the
    environment variable named by group 1 of the match, or an empty string
    if the variable is not set
    """
    var_name = match.group(1)
    logger.debug("Looking up " + var_name)
    return os.environ.get(var_name, "")


def myexec(cmd_line, timeout_secs, should_log):
    """
    executes shell commands with the ability to time out if the command hangs

    cmd_line     - command line handed to the shell. stderr of the command
                   is merged into its stdout, and the output is not
                   captured
    timeout_secs - number of seconds after which the command is given up on
    should_log   - log the command line at INFO level. It is always logged
                   when the logger is in debug mode

    Raises RuntimeError if the command times out or exits with a non-zero
    exit code.
    """
    if should_log or logger.isEnabledFor(logging.DEBUG):
        logger.info(cmd_line)
    sys.stdout.flush()

    # set up signal handler for timeout
    signal.signal(signal.SIGALRM, alarm_handler)
    signal.alarm(timeout_secs)

    try:
        p = subprocess.Popen(cmd_line + " 2>&1", shell=True)
        try:
            p.communicate()
        except Alarm:
            # terminate() only exists in Python 2.6 and later
            if sys.version_info >= (2, 6):
                p.terminate()
            raise RuntimeError("Command '%s' timed out after %s seconds" % (cmd_line, timeout_secs))
    finally:
        # disarm the timer - otherwise it is still pending once we return
        # and can raise Alarm in unrelated code later on
        signal.alarm(0)

    rc = p.returncode
    if rc != 0:
        raise RuntimeError("Command '%s' failed with error code %s" % (cmd_line, rc))


def backticks(cmd_line):
    """
    runs cmd_line through the shell and returns what the command wrote to
    stdout - just like backticks in perl
    """
    proc = subprocess.Popen(cmd_line, shell=True, stdout=subprocess.PIPE)
    captured, _ = proc.communicate()
    return captured


def check_tool(executable, version_arg, version_regex):
    """
    looks for an executable in the current environment and records the
    findings in tool_info[executable]

    executable    - name of the command to look for
    version_arg   - argument which makes the command print its version
    version_regex - regular expression with the version as group 1, or None
                    if the command should not be asked for its version

    The tool_info entry always has the keys full_path, version,
    version_major, version_minor and version_patch. Whatever could not be
    determined is left as None.
    """
    # initialize the global tool info for this executable
    info = {
        'full_path': None,
        'version': None,
        'version_major': None,
        'version_minor': None,
        'version_patch': None,
    }
    tool_info[executable] = info

    # figure out the full path to the executable
    full_path = backticks("which " + executable + " 2>/dev/null").rstrip('\n')
    if full_path == "":
        logger.info("Command '%s' not found in the current environment" %(executable))
        return
    info['full_path'] = full_path

    # version
    if version_regex is None:
        version = "N/A"
    else:
        version = backticks(executable + " " + version_arg + " 2>&1").replace('\n', "")
        found = re.search(version_regex, version)
        if found:
            version = found.group(1)
        info['version'] = version

    # if possible, break up version into major, minor, patch
    parts = re.search(r"([0-9]+)\.([0-9]+)(\.([0-9]+)){0,1}", version)
    if parts:
        info['version_major'] = int(parts.group(1))
        info['version_minor'] = int(parts.group(2))
        # the patch level is optional
        if parts.group(4):
            info['version_patch'] = int(parts.group(4))

    logger.info("  %-18s Version: %-7s Path: %s" % (executable, version, full_path))


def _move_to_front(entries, entry):
    """
    makes entry the first element of the entries list, dropping the first
    earlier occurrence of it if there is one
    """
    if entry in entries:
        entries.remove(entry)
    entries.insert(0, entry)


def check_env_and_tools():
    """
    sets up PATH, LD_LIBRARY_PATH and DYLD_LIBRARY_PATH in os.environ,
    points irodsAuthFileName at .irodsA in the current working directory,
    and then probes for the external tools with check_tool()
    """

    # PATH setup
    path_entries = os.environ.get('PATH', "/usr/bin:/bin").split(':')

    # is /usr/bin in the path?
    if "/usr/bin" not in path_entries:
        path_entries.append("/usr/bin")
        path_entries.append("/bin")

    # fink on macos x
    if os.path.exists("/sw/bin") and "/sw/bin" not in path_entries:
        path_entries.append("/sw/bin")

    # add our own path (basic way to get to other Pegasus tools)
    my_bin_dir = os.path.normpath(os.path.dirname(sys.argv[0]))
    if my_bin_dir not in path_entries:
        path_entries.append(my_bin_dir)

    # need LD_LIBRARY_PATH for Globus tools. An unset or empty variable has
    # to result in an empty list: "".split(':') is [''], and that empty
    # entry would end up as a trailing ":" in the variable we export
    ld_library_path = os.environ.get('LD_LIBRARY_PATH', "")
    ld_library_path_entries = []
    if ld_library_path != "":
        ld_library_path_entries = ld_library_path.split(':')

    # if PEGASUS_HOME is set, prepend it to the PATH (we want it early to
    # override other cruft)
    if "PEGASUS_HOME" in os.environ:
        _move_to_front(path_entries, os.environ['PEGASUS_HOME'] + "/bin")

    # if GLOBUS_LOCATION is set, prepend it to the PATH and LD_LIBRARY_PATH
    # (we want it early to override other cruft)
    if "GLOBUS_LOCATION" in os.environ:
        _move_to_front(path_entries, os.environ['GLOBUS_LOCATION'] + "/bin")
        _move_to_front(ld_library_path_entries, os.environ['GLOBUS_LOCATION'] + "/lib")

    os.environ['PATH'] = ":".join(path_entries)
    os.environ['LD_LIBRARY_PATH'] = ":".join(ld_library_path_entries)
    os.environ['DYLD_LIBRARY_PATH'] = ":".join(ld_library_path_entries)
    logger.info("PATH=" + os.environ['PATH'])
    logger.info("LD_LIBRARY_PATH=" + os.environ['LD_LIBRARY_PATH'])

    # irods requires a password hash file
    os.environ['irodsAuthFileName'] = os.getcwd() + "/.irodsA"

    # tools we might need later
    check_tool("lcg-del", "--version", r"lcg_util-([\.0-9a-zA-Z]+)")
    check_tool("srm-rm", "-version", r"srm-copy[ \t]+([\.0-9a-zA-Z]+)")
    check_tool("irm", "-h", r"Version[ \t]+([\.0-9a-zA-Z]+)")
    check_tool("pegasus-gridftp", "", None)
    check_tool("pegasus-s3", "help", None)
    check_tool("gsutil", "version", r"gsutil version: ([\.0-9a-zA-Z]+)")


def _shell_quote(arg):
    """
    returns arg wrapped in single quotes, so that the shell hands it to
    the command verbatim
    """
    # a single quote can not be escaped inside single quotes - close the
    # quoting, add the quote inside double quotes, and open it up again
    return "'" + arg.replace("'", "'\"'\"'") + "'"


def rm(urls):
    """
    removes locally using /bin/rm

    One /bin/rm is run for the path of every url. -f is added if failures
    are to be ignored, and -r for recursive removes. A failing command is
    logged and the RuntimeError is passed on to the caller.
    """
    base_cmd = "/bin/rm"
    if options.ignore_failures:
        base_cmd += " -f"
    if options.recursive:
        base_cmd += " -r"

    for url in urls:
        # the path comes from a URL decoded entry of the url list, so it
        # has to be protected against being interpreted by the shell
        cmd = base_cmd + " " + _shell_quote(url.path)
        try:
            myexec(cmd, 6*60*60, True)
        except RuntimeError:
            # sys.exc_info() works with all Python versions
            logger.error(sys.exc_info()[1])
            raise

def scp(urls):
    """
    removes using ssh+rm

    For every url, /bin/rm is run on the host of the url via /usr/bin/ssh.
    A failing command is logged and the RuntimeError is passed on to the
    caller.
    """
    for url in urls:

        # the host can come with a ":port" suffix - without one, use the
        # default ssh port
        host = re.sub(':.*', '', url.host)
        port_match = re.search(':([0-9]+)', url.host)
        if port_match:
            port = port_match.group(1)
        else:
            port = "22"

        # a private key for the site wins over the generic one
        identity = ""
        site_key = "SSH_PRIVATE_KEY_" + url.sitename
        if site_key in os.environ:
            identity = " -i " + os.environ[site_key]
        elif "SSH_PRIVATE_KEY" in os.environ:
            identity = " -i " + os.environ['SSH_PRIVATE_KEY']

        rm_flags = ""
        if options.ignore_failures:
            rm_flags += " -f"
        if options.recursive:
            rm_flags += " -r"

        # the rm command is passed to ssh as one double quoted argument
        cmd = "/usr/bin/ssh" + identity \
            + " -o PasswordAuthentication=no" \
            + " -o UserKnownHostsFile=/dev/null" \
            + " -o StrictHostKeyChecking=no" \
            + " -p " + port \
            + " " + host + " " \
            + " \"/bin/rm" + rm_flags \
            + " " + url.path + "\""

        try:
            myexec(cmd, 6*60*60, True)
        except RuntimeError:
            logger.error(sys.exc_info()[1])
            raise

def gsiftp(urls):
    """
    remove files on gridftp servers - delegate run to pegasus-gridftp

    The urls are written to a temporary list file which is handed to one
    pegasus-gridftp invocation. The list file is removed again no matter
    how this function is left. If X509_USER_PROXY_<sitename> is set for the
    site of the first url, it is exported as X509_USER_PROXY.

    Raises RuntimeError if the list file can not be created or if
    pegasus-gridftp fails.
    """

    try:
        tmp_fd, tmp_name = tempfile.mkstemp(prefix="pegasus-cleanup-", suffix=".lst", dir="/tmp")
        tmp_file = os.fdopen(tmp_fd, "w+b")
    except Exception:
        raise RuntimeError("Unable to create tmp file for pegasus-gridftp cleanup")

    # the list file exists from here on - the finally block makes sure it
    # is also cleaned up if we are interrupted or hit an unexpected error
    try:
        try:
            for url in urls:
                tmp_file.write("%s\n" %(url.url()))
        finally:
            tmp_file.close()

        key = "X509_USER_PROXY_" + urls[0].sitename
        if key in os.environ:
            os.environ["X509_USER_PROXY"] = os.environ[key]

        # use pegasus-gridftp
        cmd = prog_dir + "/pegasus-gridftp rm -f"
        if options.recursive:
            cmd += " -r"

        # make output match our current log level
        if logger.isEnabledFor(logging.DEBUG):
            cmd += " -v"

        cmd += " -i " + tmp_name

        try:
            myexec(cmd, 12*60*60, True)
        except RuntimeError:
            logger.error(sys.exc_info()[1])
            raise
    finally:
        os.unlink(tmp_name)


def sshftp(urls):
    """
    remove sshftp based files using globus-url-copy

    For every url, globus-url-copy is asked to copy file:///dev/null onto
    the url. Nothing is done for recursive removes. If
    SSH_PRIVATE_KEY_<sitename> is set for the site of a url, it is exported
    as SSH_PRIVATE_KEY.

    Raises RuntimeError if globus-url-copy is not available or fails.
    """

    if tool_info['globus-url-copy']['full_path'] is None:
        raise RuntimeError("Unable to do sshftp remove because globus-url-copy could not be found")

    if options.recursive:
        logger.info("Skipping cleanup as recursive deletes are not possible for sshftp:// URLS")
        return

    for url in urls:

        # build command line for globus-url-copy
        cmd = tool_info['globus-url-copy']['full_path']
        # make output match our current log level
        if logger.isEnabledFor(logging.DEBUG):
            cmd += " -dbg"

        key = "SSH_PRIVATE_KEY_" + url.sitename
        if key in os.environ:
            os.environ["SSH_PRIVATE_KEY"] = os.environ[key]

        cmd += " file:///dev/null"
        cmd += " " + url.url()

        try:
            myexec(cmd, 6*60*60, True)
        except RuntimeError:
            # only command failures are translated - a catch-all here would
            # also swallow the exit requested by the SIGINT handler
            logger.error(sys.exc_info()[1])
            raise RuntimeError("globus-url-copy command failed")


def irods_login(sitename):
    """
    log in to irods by using the iinit command - if the auth file already
    exists, we are already logged in

    If irodsEnvFile_<sitename> is set, it is exported as irodsEnvFile. The
    password is read from the irodsPassword line of that file and piped
    into iinit via the scratch file .irodsAc in the current working
    directory.

    Raises RuntimeError if irodsEnvFile is not set, if the file has no
    irodsPassword entry, or if iinit fails.
    """
    key = "irodsEnvFile_" + sitename
    if key in os.environ:
        os.environ["irodsEnvFile"] = os.environ[key]

    f = os.environ['irodsAuthFileName']
    if os.path.exists(f):
        return

    # read password from env file
    if "irodsEnvFile" not in os.environ:
        raise RuntimeError("Missing irodsEnvFile - unable to do irods transfers")
    password = None
    h = open(os.environ['irodsEnvFile'], 'r')
    try:
        for line in h:
            items = line.split(" ", 2)
            # a key without a value is treated like a missing password
            if len(items) > 1 and items[0].lower() == "irodspassword":
                password = items[1].strip(" \t'\"\r\n")
    finally:
        h.close()
    if password is None:
        raise RuntimeError("No irodsPassword specified in irods env file")

    # the scratch file holds the clear text password - create it with
    # access for the owner only, and remove it even if iinit fails
    try:
        fd = os.open(".irodsAc", os.O_WRONLY | os.O_CREAT | os.O_TRUNC,
                     stat.S_IRUSR | stat.S_IWUSR)
        h = os.fdopen(fd, "w")
        try:
            h.write(password + "\n")
        finally:
            h.close()

        cmd = "cat .irodsAc | iinit"
        myexec(cmd, 6*60*60, True)
    finally:
        if os.path.exists(".irodsAc"):
            os.unlink(".irodsAc")


def irods(urls):
    """
    irods - use the icommands to interact with irods

    Logs in for the site of the first url, and then runs one "irm -f" for
    the path of every url ("-r" is added for recursive removes). A failing
    command is logged and the exception is passed on to the caller.

    Raises RuntimeError if irm is not available or if the login fails.
    """

    if tool_info['irm']['full_path'] is None:
        raise RuntimeError("Unable to do irods removes because irm could" +
                           " not be found in the current path")

    # log in to irods
    try:
        irods_login(urls[0].sitename)
    except Exception:
        logger.error(sys.exc_info()[1])
        raise RuntimeError("Unable to log into irods")

    for url in urls:
        cmd = "irm -f " + url.path
        if options.recursive:
            cmd += " -r"
        try:
            myexec(cmd, 6*60*60, True)
        except Exception:
            logger.error(sys.exc_info()[1])
            raise

def srm(urls):
    """
    srm - use lcg-del or srm-rm

    If X509_USER_PROXY_<sitename> is set for the site of the first url, it
    is exported as X509_USER_PROXY. One remove command is run per url. A
    failing command is logged and the exception is passed on to the caller.
    """

    proxy_key = "X509_USER_PROXY_" + urls[0].sitename
    if proxy_key in os.environ:
        os.environ["X509_USER_PROXY"] = os.environ[proxy_key]

    # lcg-del is the first choice, srm-rm the fallback
    if tool_info['lcg-del']['full_path'] is not None:
        remover = "lcg-del -b -l -D srmv2 --force"
    elif tool_info['srm-rm']['full_path'] is not None:
        remover = "srm-rm"
    else:
        raise RuntimeError("Unable to do srm remove becuase lcg-del/srm-rm could not be found")

    for url in urls:
        cmd = remover + " " + url.url()

        try:
            myexec(cmd, 6*60*60, True)
        except Exception:
            logger.error(sys.exc_info()[1])
            raise

def s3(urls):
    """
    s3 - uses pegasus-s3 to interact with Amazon S3

    One "pegasus-s3 rm" is run per url. If S3CFG_<sitename> is set for the
    site of a url, it is exported as S3CFG. A failing command is logged and
    the exception is passed on to the caller.
    """

    if tool_info['pegasus-s3']['full_path'] is None:
        raise RuntimeError("Unable to do S3 transfers becuase pegasus-s3 could not be found")

    for url in urls:

        target = url.url()

        # PM-790: recursive deletes are really a pattern matching. For example,
        # if told to remove the foo/bar directory, we need to translate it into
        # foo/bar/*
        if options.recursive:
            # first make sure there are no trailing slashes - the first
            # character is never stripped
            target = target[:1] + target[1:].rstrip("/")
            target += "/*"
            logger.info("Transformed remote URL to " + target)

        cfg_key = "S3CFG_" + url.sitename
        if cfg_key in os.environ:
            os.environ["S3CFG"] = os.environ[cfg_key]

        flags = ""
        if options.debug:
            flags += " -v"
        if options.ignore_failures:
            flags += " -f"
        cmd = "pegasus-s3 rm" + flags + " " + target

        try:
            myexec(cmd, 6*60*60, True)
        except Exception:
            logger.error(sys.exc_info()[1])
            raise


def gs(urls):
    """
    gs - uses gsutil to interact with Google Storage

    BOTO_CONFIG_<sitename> and GOOGLE_PKCS12_<sitename> for the site of the
    first url are exported as BOTO_CONFIG and GOOGLE_PKCS12 if they are
    set. gsutil is run with a temporary copy of the boto config file which
    has gs_service_key_file set to the value of GOOGLE_PKCS12. The copy is
    removed again no matter how this function is left.

    Failures of the gsutil commands are logged, but otherwise ignored.

    Raises RuntimeError if gsutil is not available, or if the temporary
    boto config file can not be created or written.
    """

    if tool_info['gsutil']['full_path'] is None:
        raise RuntimeError("Unable to do Google Storage removes because gsutil could not be found")

    key = "BOTO_CONFIG_" + urls[0].sitename
    if key in os.environ:
        os.environ["BOTO_CONFIG"] = os.environ[key]

    key = "GOOGLE_PKCS12_" + urls[0].sitename
    if key in os.environ:
        os.environ["GOOGLE_PKCS12"] = os.environ[key]

    # we need to update the boto config file to specify the full
    # path to the PKCS12 file
    try:
        tmp_fd, tmp_name = tempfile.mkstemp(prefix="pegasus-transfer-", suffix=".lst")
        tmp_file = os.fdopen(tmp_fd, "w+b")
    except Exception:
        raise RuntimeError("Unable to create tmp file for gs boto file")

    # the temporary config exists from here on - the finally block makes
    # sure it does not stay around if the conversion fails
    try:
        try:
            try:
                conf = ConfigParser.SafeConfigParser()
                conf.read(os.environ["BOTO_CONFIG"])
                conf.set("Credentials", "gs_service_key_file", os.environ["GOOGLE_PKCS12"])
                conf.write(tmp_file)
            finally:
                tmp_file.close()
        except Exception:
            logger.error(sys.exc_info()[1])
            raise RuntimeError("Unable to convert boto config file")

        for url in urls:

            cmd = "BOTO_CONFIG=%s gsutil rm -a" %(tmp_name)
            if options.ignore_failures or options.recursive:
                cmd += " -f"
            if options.recursive:
                cmd += " -r"
            cmd += " " + url.url()

            try:
                myexec(cmd, 6*60*60, True)
            except Exception:
                # ignore all gsutil errors for now - this will be fixed in
                # a future version. At least leave a trace in the log
                logger.warning("Ignoring gsutil failure: %s" % (sys.exc_info()[1]))
    finally:
        os.unlink(tmp_name)
        
        
def urls_groupable(a, b):
    """
    tells if two urls can be handed to the same tool invocation, which is
    the case if they share both the site name and the protocol
    """
    return a.sitename == b.sitename and a.proto == b.proto


def handle_removes(urls):
    """
    removes the files with the given urls

    The protocol of the first url picks the handler for the whole list, so
    all urls are expected to be for the same protocol. The program exits
    with exit code 1 if there is no handler for the protocol. Exceptions
    raised by the handler are passed on to the caller.
    """
    # the message has to name the protocol of the urls we were given
    proto = urls[0].proto
    if proto not in tool_map:
        logger.critical("Error: This tool does not know how to remove from %s://" % (proto))
        myexit(1)

    tool = tool_map[proto]
    handlers = {
        "rm":     rm,
        "scp":    scp,
        "gsiftp": gsiftp,
        "sshftp": sshftp,
        "irods":  irods,
        "srm":    srm,
        "s3":     s3,
        "gs":     gs,
    }
    if tool not in handlers:
        logger.critical("Error: No mapping for the tool '%s'" %(tool))
        myexit(1)

    # globus-url-copy is only looked for when it is actually needed
    if tool == "sshftp":
        check_tool("globus-url-copy", "-version", r"globus-url-copy: ([\.0-9a-zA-Z]+)")

    handlers[tool](urls)


def myexit(rc):
    """
    ends the program by raising SystemExit with rc as the exit code
    """
    sys.exit(rc)


# --- main ----------------------------------------------------------------------------

# everything written to stderr from here on goes to stdout
sys.stderr = sys.stdout

# command line options
prog_usage = "usage: " + prog_base + " [options]"
parser = optparse.OptionParser(usage=prog_usage)
parser.add_option(
    "-f", "--file",
    action="store",
    dest="file",
    help="File containing URLs to be removed. If not given, list is read from stdin.")
parser.add_option(
    "-i", "--ignore-failures",
    action="store_true",
    dest="ignore_failures",
    help="When specified, the tool ignores any failures and always exits with 0.")
parser.add_option(
    "-r", "--recursive",
    action="store_true",
    dest="recursive",
    help="Enable recursive removal for tools that support it.")
parser.add_option(
    "-d", "--debug",
    action="store_true",
    dest="debug",
    help="Enables debugging ouput.")

# parse the command line - the logger is set up as soon as we know if
# debugging was asked for
options, args = parser.parse_args()
setup_logger(options.debug)

# Die nicely when asked to (Ctrl+C, system shutdown)
signal.signal(signal.SIGINT, prog_sigint_handler)

# the URLs come from the file given with --file, and from stdin otherwise
if options.file is not None:
    logger.info("Reading URLs from %s" % (options.file))
    try:
        input_file = open(options.file, 'r')
    except Exception:
        logger.critical('Error reading url list: %s' % (sys.exc_info()[1]))
        myexit(1)
else:
    logger.info("Reading URLs from stdin")
    input_file = sys.stdin

# set up the environment and look for the tools - give up on any problem
try:
    check_env_and_tools()
except Exception:
    logger.critical(sys.exc_info()[1])
    myexit(1)

# the URLs to remove, in the order they were given
url_q = deque()

# Read the input line by line. A "#" line sets the site name for the URL on
# the next line, and every other line is a URL. The site name is reset
# after every URL.
line_nr = 0
sitename = ""
try:
    for line in input_file.readlines():
        line_nr += 1

        # lines of 3 characters or less are skipped
        if len(line) <= 3:
            continue
        line = line.rstrip('\n')

        if line[0] != '#':
            url = URL()
            url.set_sitename(sitename)
            url.set_url(line)
            url_q.append(url)
            sitename = ""
            continue

        r = re_parse_comment.search(line)
        if not r:
            logger.critical('Unable to parse comment on line %d'
                            %(line_nr))
            myexit(1)
        sitename = r.group(1)
except Exception:
    logger.critical('Error handling url: %s' % (sys.exc_info()[1]))
    myexit(1)

# Work through the queue. A failing group is logged and turns the exit code
# into 1, but does not stop the remaining groups from being handled.
exit_code = 0
while url_q:

    # take the next URL, plus all the URLs directly following it which can
    # be handled by the same tool
    u_main = url_q.popleft()
    u_list = [u_main]
    while url_q and urls_groupable(u_main, url_q[0]):
        u_list.append(url_q.popleft())

    # magic!
    try:
        handle_removes(u_list)
    except Exception:
        logger.error(sys.exc_info()[1])
        exit_code = 1

# sometimes we just want to ignore the failures
if options.ignore_failures:
    exit_code = 0

myexit(exit_code)


