#!/usr/bin/env python

"""
Pegasus utility for transfer of files during workflow enactment

Usage: pegasus-transfer [options]
"""

##
#  Copyright 2007-2013 University Of Southern California
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#  http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing,
#  software distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
##

from collections import deque
import errno
import logging
import math
import optparse
import os
import re
import signal
import stat
import string
import subprocess
import sys
import tempfile
import time


# Author of this script
__author__ = "Mats Rynge <rynge@isi.edu>"


# --- regular expressions -----------------------------------------------------

# Splits a URL of the form proto://host/path into three groups:
#   group 1 - protocol: one or more word characters in front of "://"
#   group 2 - host part: possibly empty, made up of word characters and
#             the characters . - : @
#   group 3 - path: optional, starts with "/" and runs up to the next
#             whitespace character (None when the URL has no path)
re_parse_url = re.compile(r'([\w]+)://([\w\.\-:@]*)(/[\S]*)?')


# --- classes -----------------------------------------------------------------

class Transfer:
    """
    Represents a single transfer request.

    A transfer is a source URL and a destination URL, each kept split up
    into protocol, host and path. Transfers compare on protocols, then
    hosts, then paths (see __cmp__) so that similar transfers end up next
    to each other when sorted.
    """

    pair_id        = 0       # the id of the pair in the input (nth pair)
    src_proto      = ""      # source protocol, for example "gsiftp"
    src_host       = ""      # source host part - empty for file/symlink
    src_path       = ""      # source path
    dst_proto      = ""      # destination protocol
    dst_host       = ""      # destination host part - empty for file/symlink
    dst_path       = ""      # destination path
    allow_grouping = True    # can this transfer be grouped with others?

    def __init__(self, pair_id):
        """
        Initializes the transfer class
        """
        self.pair_id = pair_id

    def set_src(self, url):
        """
        Parses url and stores it as the source of the transfer
        """
        self.src_proto, self.src_host, self.src_path = self.parse_url(url)

    def set_dst(self, url):
        """
        Parses url and stores it as the destination of the transfer
        """
        self.dst_proto, self.dst_host, self.dst_path = self.parse_url(url)

    def parse_url(self, url):
        """
        Splits a URL into a (proto, host, path) tuple.

        URLs without a ":" are treated as file:// URLs. For file and
        symlink URLs the host is always empty and the path is passed
        through expand_env_vars(). All other URLs are matched against
        re_parse_url and have runs of slashes in the path collapsed.

        Raises RuntimeError if the URL can not be parsed.
        """
        # default protocol is file://
        if ":" not in url:
            logger.debug("URL without protocol (" + url +
                        ") - assuming file://")
            url = "file://" + url

        # file and symlink urls are special cases as they can contain
        # relative paths and env vars - they can either start with
        # proto:// or just proto: (no //)
        for local_proto in ("file", "symlink"):
            if url.startswith(local_proto + ":"):
                path = re.sub("^" + local_proto + ":(//)?", "", url)
                return local_proto, "", expand_env_vars(path)

        # other than file/symlink urls
        r = re_parse_url.search(url)
        if not r:
            raise RuntimeError("Unable to parse URL: %s" % (url))

        # Parse successful
        proto = r.group(1)
        host = r.group(2)
        path = r.group(3)

        # the path group is optional in the regular expression
        if path is None:
            path = ""

        # no double slashes in urls
        path = re.sub('//+', '/', path)

        return proto, host, path

    def src_url(self):
        """
        Returns the source as a proto://host/path URL
        """
        return "%s://%s%s" % (self.src_proto, self.src_host, self.src_path)

    def src_url_srm(self):
        """
        srm-copy is using broken urls - wants an extra /

        The extra slash is added for every protocol except srm, which is
        returned unchanged.
        """
        if self.src_proto != "srm":
            return "%s://%s/%s" % (self.src_proto, self.src_host, self.src_path)
        return self.src_url()

    def dst_url(self):
        """
        Returns the destination as a proto://host/path URL
        """
        return "%s://%s%s" % (self.dst_proto, self.dst_host, self.dst_path)

    def dst_url_srm(self):
        """
        srm-copy is using broken urls - wants an extra /

        The extra slash is added for every protocol except srm, which is
        returned unchanged.
        """
        if self.dst_proto != "srm":
            return "%s://%s/%s" % (self.dst_proto, self.dst_host, self.dst_path)
        return self.dst_url()

    def dst_url_dirname(self):
        """
        Returns the URL of the directory containing the destination
        """
        dn = os.path.dirname(self.dst_path)
        return "%s://%s%s" % (self.dst_proto, self.dst_host, dn)

    def groupable(self):
        """
        currently only gridftp allows for grouping
        """
        return self.allow_grouping and \
               (self.src_proto == "gsiftp" or self.dst_proto == "gsiftp")

    def __cmp__(self, other):
        """
        compares first on protos, then on hosts, then on paths - useful
        for grouping similar types of transfers

        Returns -1, 0 or 1. Tuples compare element by element, which
        gives the same ordering as comparing the fields one at a time,
        without depending on the cmp() builtin.
        """
        mine = (self.src_proto, self.dst_proto,
                self.src_host, self.dst_host,
                self.src_path, self.dst_path)
        theirs = (other.src_proto, other.dst_proto,
                  other.src_host, other.dst_host,
                  other.src_path, other.dst_path)
        return (mine > theirs) - (mine < theirs)


class Singleton(type):
    """
    Metaclass implementing the singleton pattern - every class created
    with this metaclass is instantiated at most once, and the same
    instance is handed out on each subsequent call
    """

    # class -> the one instance created for it
    _instances = {}

    def __call__(cls, *args, **kwargs):
        try:
            return cls._instances[cls]
        except KeyError:
            pass
        # first call for this class - build the instance and remember it
        cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
        return cls._instances[cls]


class Tools(object):
    """Singleton for detecting and maintaining tools we depend on

    Detection results are cached in _info, keyed on the executable name.
    An entry is either None (the tool was looked for but not found) or a
    dictionary with the keys full_path, version, version_major,
    version_minor and version_patch.
    """

    __metaclass__ = Singleton

    _info = {}

    def find(self, executable, version_arg, version_regex):
        """
        Detects a tool and remembers the result.

        executable    - name of the command, located with "which"
        version_arg   - argument which makes the command print its version
        version_regex - regular expression whose first group captures the
                        version string in that output, or None to skip
                        the version detection

        Returns None if the tool can not be found. Otherwise the return
        value is the full path on the first detection and the cached info
        dictionary on later calls - callers should only compare the
        result against None.
        """
        if executable in self._info:
            # already looked for - None means it was not found
            return self._info[executable]

        logger.info("Trying to detect availability/location of tool: %s"
                    %(executable))

        # initialize the global tool info for this executable
        info = {
            'full_path': None,
            'version': None,
            'version_major': None,
            'version_minor': None,
            'version_patch': None,
        }
        self._info[executable] = info

        # figure out the full path to the executable
        full_path = backticks("which " + executable + " 2>/dev/null")
        full_path = full_path.rstrip('\n')
        if full_path == "":
            # this used to call logger._info(), which does not exist, so
            # a missing tool raised AttributeError instead of returning
            # None
            logger.info("Command '%s' not found in the current environment"
                        %(executable))
            self._info[executable] = None
            return None
        info['full_path'] = full_path

        # version
        if version_regex is None:
            version = "N/A"
        else:
            version = backticks(executable + " " + version_arg + " 2>&1")
            version = version.replace('\n', "")
            result = re.search(version_regex, version)
            if result:
                version = result.group(1)
            info['version'] = version

        # if possible, break up version into major, minor, patch
        result = re.search(r"([0-9]+)\.([0-9]+)(\.([0-9]+)){0,1}", version)
        if result:
            info['version_major'] = int(result.group(1))
            info['version_minor'] = int(result.group(2))
            # the patch level is optional
            if result.group(4):
                info['version_patch'] = int(result.group(4))

        logger.info("Tool found: %s   Version: %s   Path: %s"
                    % (executable, version, full_path))
        return info['full_path']

    def full_path(self, executable):
        """ Returns the full path to a given executable """
        info = self._info.get(executable)
        if info is not None:
            return info['full_path']
        return None

    def major_version(self, executable):
        """ Returns the detected major version given executable """
        info = self._info.get(executable)
        if info is not None:
            return info['version_major']
        return None
                

class TransferHandlerBase(object):
    """
    Base class for all transfer handlers. Derived classes should set the
    protocol map (for example ["http->file"]) and implement the do_transfer()
    method.
    """

    _name = "BaseHandler"     # used in error messages
    _protocol_map = []        # list of "srcproto->dstproto" strings

    def do_transfer(self, transfer, attempt):
        """
        Handles single transfer - all derived classes should override this
        method

        Raises RuntimeError when not overridden.
        """
        raise RuntimeError("do_transfer() is not implemented in " + self._name)

    def do_multi_transfer(self, transfers, attempt, failed_q=None):
        """
        Handles transfers for a list of similar transfers. This is used mostly
        for efficiency when using protocols such as gsiftp. Implementing this
        method is optional for protocols

        failed_q is optional here so that the signature accepts the same
        arguments as GridFtpHandler.do_multi_transfer(), which appends the
        transfers it could not complete to it.

        Raises RuntimeError when not overridden.
        """
        raise RuntimeError("do_multi_transfer() is not implemented in " +
                           self._name)

    def protocol_map_check(self, src_proto, dst_proto):
        """
        Checks to see if a src/dst protocol pair can be handled by the handler.
        This is the base for the automatic handler detection in the
        TransferSet class.

        Returns True if "src_proto->dst_proto" is listed in the protocol
        map of the handler, False otherwise.
        """
        item = str(src_proto) + "->" + str(dst_proto)
        return (item in self._protocol_map)


class CpHandler(TransferHandlerBase):
    """
    Uses the system cp command to copy local file to file
    """

    _name = "CpHandler"
    _protocol_map = ["file->file"]

    def do_transfer(self, transfer, attempt):
        """
        Copies transfer.src_path to transfer.dst_path with /bin/cp.

        Returns True if the copy succeeded, or if source and destination
        already are the same file, and False if the cp command failed.
        """
        prepare_local_dir(os.path.dirname(transfer.dst_path))

        # make sure src and target are not the same file - the paths can
        # differ, so let samefile() compare device and inode. Comparing
        # the inode number alone is not enough as inode numbers are only
        # unique within one filesystem.
        if os.path.exists(transfer.src_path) \
           and os.path.exists(transfer.dst_path) \
           and os.path.samefile(transfer.src_path, transfer.dst_path):
            logger.warning("cp: src (%s) and dst (%s) already exists"
                           % (transfer.src_path, transfer.dst_path))
            return True

        cmd = "/bin/cp -f -L '%s' '%s'" \
            % (transfer.src_path, transfer.dst_path)
        try:
            myexec(cmd, default_subshell_timeout, True)
        except RuntimeError:
            # sys.exc_info() is used to get hold of the exception as it
            # works the same in all Python versions
            logger.error(sys.exc_info()[1])
            return False
        stats_add(transfer.dst_path)
        return True


class FDTHandler(TransferHandlerBase):
    """
    Handler for FDT (Fast Data Transfer - http://monalisa.cern.ch/FDT/)
    """

    _name = "FDTHandler"
    _protocol_map = ["file->fdt", "fdt->file"]

    def do_transfer(self, transfer, attempt):
        """
        Transfers a file to or from an FDT server by running fdt.jar from
        the current working directory.

        Returns True if the transfer succeeded, and False if fdt.jar
        could not be downloaded or the transfer command failed.
        """

        # download fdt.jar on demand - it can not be shipped with Pegasus due
        # to licensing
        if not os.path.exists("fdt.jar"):
            cmd = "wget -nv -O fdt.jar http://monalisa.cern.ch/FDT/lib/fdt.jar"
            try:
                myexec(cmd, 10*60, True)
            except RuntimeError:
                logger.error(sys.exc_info()[1])
                # wget -O creates the output file even when the download
                # fails. Remove it, otherwise the exists() check above
                # would skip the download on every later attempt.
                try:
                    os.unlink("fdt.jar")
                except OSError:
                    # best effort - there might not be a file to remove
                    pass
                return False

        cmd = "echo | java -jar fdt.jar"
        if transfer.dst_proto == "file":
            # fdt->file
            prepare_local_dir(os.path.dirname(transfer.dst_path))
            cmd += " " + transfer.src_host + ":" + transfer.src_path
            cmd += " " + transfer.dst_path
            local_file = transfer.dst_path
        else:
            # file->fdt
            cmd += " " + transfer.src_path
            cmd += " " + transfer.dst_host + ":" + transfer.dst_path
            local_file = transfer.src_path

        try:
            myexec(cmd, default_subshell_timeout, True)
            # only count the file once it has actually been transferred
            stats_add(local_file)
        except RuntimeError:
            logger.error(sys.exc_info()[1])
            return False
        return True
    


class GridFtpHandler(TransferHandlerBase):
    """
    Transfers to/from and between GridFTP servers using globus-url-copy
    """

    _name = "GridFtpHandler"
    _protocol_map = [
                    "file->gsiftp",
                    "gsiftp->file",
                    "gsiftp->gsiftp",
                    "ftp->ftp",
                    "ftp->gsiftp",
                    "gsiftp->ftp",
                    "http->gsiftp"
                    ]

    def _require_guc(self):
        """
        makes sure globus-url-copy has been detected - raises RuntimeError
        if the tool can not be found
        """
        tools = Tools()
        if tools.find("globus-url-copy", "-version",
                      r"([0-9]+\.[0-9]+)") is None:
            raise RuntimeError("Unable to do gsiftp transfers because" +
                               " globus-url-copy could not be found")

    def do_transfer(self, transfer, attempt):
        """
        gsiftp - use globus-url-copy for transfers

        Returns True if the transfer succeeded, False otherwise. Raises
        RuntimeError if globus-url-copy is not available.
        """
        self._require_guc()

        third_party = transfer.src_proto == "gsiftp" \
                      and transfer.dst_proto == "gsiftp"

        return self._exec_transfers([transfer], attempt, True, third_party)

    def do_multi_transfer(self, full_list, attempt, failed_q):
        """
        gsiftp - globus-url-copy for now, maybe uberftp in the future

        Consumes full_list (it is empty when this method returns), runs
        the transfers in groups of similar transfers, and appends the
        transfers which failed to failed_q. Raises RuntimeError if
        globus-url-copy is not available.
        """
        if len(full_list) == 0:
            return

        self._require_guc()

        # create lists with similar (same src host/path, same dst host/path)
        # url pairs
        while len(full_list) > 0:

            similar_list = []

            curr = full_list.pop()
            prev = curr
            # the first transfer decides the mode for the whole group
            third_party = curr.src_proto == "gsiftp" \
                          and curr.dst_proto == "gsiftp"

            # each transfer is compared against the one taken just before it
            while self._check_similar(curr, prev):

                similar_list.append(curr)

                if len(full_list) == 0:
                    break
                else:
                    prev = curr
                    curr = full_list.pop()

            if not self._check_similar(curr, prev):
                # the last pair is not part of the set and needs to be put
                # back on the end of the list it was popped from
                full_list.append(curr)

            if len(similar_list) == 0:
                break

            # we now have a list of similar transfers - break up and send the
            # first one with create dir and the rest with no create dir options
            first_list = [similar_list.pop()]
            mkdir_done = self._exec_transfers(first_list, attempt,
                                              True, third_party)

            # first attempt get some extra tries - this is to drill down on
            # guc options
            if attempt == 1 and not mkdir_done:
                mkdir_done = self._exec_transfers(first_list, attempt,
                                                  True, third_party)
                if not mkdir_done:
                    mkdir_done = self._exec_transfers(first_list,
                                                      attempt, True,
                                                      third_party)

            if mkdir_done:
                # run the rest of the group - but limit the number of entries
                # for each pipeline
                for chunk in self._split_similar(similar_list):
                    if not self._exec_transfers(chunk, attempt,
                                                False, third_party):
                        for t in chunk:
                            failed_q.append(t)
            else:
                # mkdir job failed - all subsequent jobs will fail
                failed_q.append(first_list[0])
                for t in similar_list:
                    failed_q.append(t)

    def _exec_transfers(self, transfers, attempt, create_dest, third_party):
        """
        sub to gsiftp() - transfers a list of urls

        The url pairs are written to a temporary file which is handed to
        globus-url-copy with -f. The temporary file is always removed,
        also when something fails along the way.

        Returns True if globus-url-copy succeeded, False otherwise.
        Raises RuntimeError if the temporary file can not be created.
        """
        global gsiftp_failures

        tmp_error = "Unable to create tmp file for" \
                  + " globus-url-copy transfers"

        # create tmp file with transfer src/dst pairs
        try:
            tmp_fd, tmp_name = tempfile.mkstemp(prefix="pegasus-transfer-",
                                                suffix=".lst")
        except Exception:
            raise RuntimeError(tmp_error)

        transfer_success = False
        try:
            try:
                # only text is written to the file
                tmp_file = os.fdopen(tmp_fd, "w")
            except Exception:
                os.close(tmp_fd)
                raise RuntimeError(tmp_error)

            delayed_file_stat = []
            try:
                for t in transfers:
                    logger.debug("   adding %s %s" % (t.src_url(), t.dst_url()))

                    # delay stating until we have finished the transfers
                    if t.src_proto == "file":
                        delayed_file_stat.append(t.src_path)
                    elif t.dst_proto == "file":
                        delayed_file_stat.append(t.dst_path)

                    tmp_file.write("%s %s\n" % (t.src_url(), t.dst_url()))
            finally:
                tmp_file.close()

            logger.info("Grouped %d similar gsiftp transfers together in"
                        " temporary file %s" %(len(transfers), tmp_name))

            # build command line for globus-url-copy
            tools = Tools()
            cmd = tools.full_path('globus-url-copy')
            cmd += self._guc_options(attempt, create_dest, third_party)
            cmd += " -f " + tmp_name
            try:
                myexec(cmd, default_subshell_timeout, True)
                transfer_success = True
            except Exception:
                # sys.exc_info() is used to get hold of the exception as
                # it works the same in all Python versions
                logger.error(sys.exc_info()[1])
                gsiftp_failures += 1

            if transfer_success:
                # stat the files
                for filename in delayed_file_stat:
                    stats_add(filename)
        finally:
            # never leave the list file behind
            os.unlink(tmp_name)

        return transfer_success

    def _guc_options(self, attempt, create_dest, third_party):
        """
        determine a set of globus-url-copy options based on how previous
        transfers went

        Returns the options as a string starting with a space, ready to
        be appended to the command.
        """
        global gsiftp_failures

        tools = Tools()
        options = ""

        # make output from guc match our current log level
        if logger.isEnabledFor(logging.DEBUG):
            options += " -verbose"

        # should we try to create directories?
        if create_dest:
            options += " -create-dest"

        # concurrency is safe for both pull/push and 3rd party - the
        # major version is None if it could not be detected
        guc_major = tools.major_version('globus-url-copy')
        if guc_major is not None and guc_major >= 5 \
           and gsiftp_failures == 0:
            options += " -concurrency 4"

        # Only do third party transfers for gsiftp->gsiftp. For other
        # combinations, fall back to settings which will work well over,
        # for example, NAT
        if third_party:

            # parallelism
            options += " -parallel 4"

            # -fast should be supported by all servers today
            options += " -fast"

            # -pipeline only for 3rd party transfers and Globus 4.2 and above
            # this is experimental so only allow this for the first attempt
            #if (attempt == 1 and \
            #    tool_info['globus-version']['version_major'] == 5 \
            #        or (tool_info['globus-version']['version_major'] >= 4 \
            #            and tool_info['globus-version']['version_minor'] >= 2)):
            #    options += " -pipeline"
        else:
            # gsiftp<->file transfers
            options += " -no-third-party-transfers" \
                     + " -no-data-channel-authentication"

        return options

    def _check_similar(self, a, b):
        """
        compares two url_pairs, and determines if they are similar enough
        to be grouped together in one transfer input file - that is, same
        source host, same destination host, same source directory and same
        destination directory
        """
        if a.src_host != b.src_host:
            return False
        if a.dst_host != b.dst_host:
            return False
        if os.path.dirname(a.src_path) != os.path.dirname(b.src_path):
            return False
        if os.path.dirname(a.dst_path) != os.path.dirname(b.dst_path):
            return False
        return True

    def _split_similar(self, full_list):
        """
        splits up a long list of similar transfers into smaller
        pieces which can easily be handled by g-u-c

        Returns a list of lists with at most 1000 transfers each.
        """
        size = 1000
        return [full_list[start:start + size]
                for start in range(0, len(full_list), size)]
  


class HTTPHandler(TransferHandlerBase):
    """
    Downloads http and https URLs to local files with wget
    """

    _name = "HTTPHandler"
    _protocol_map = ["http->file", "https->file"]

    def do_transfer(self, transfer, attempt):
        """
        Fetches transfer.src_url() into transfer.dst_path.

        Returns True if wget succeeded and False if it failed. Raises
        RuntimeError if wget is not available.
        """
        tools = Tools()
        wget_info = tools.find("wget", "--version", r"([0-9]+\.[0-9]+)")
        if wget_info is None:
            raise RuntimeError("Unable to do http/https transfers becuase" +
                               " wget could not be found")

        env = os.environ

        # Open Science Grid sites can inform us about local Squid proxies
        if "http_proxy" not in env and "OSG_SQUID_LOCATION" in env:
            env['http_proxy'] = env['OSG_SQUID_LOCATION']

        # but only allow squid caching for the first try - after that go to
        # the source
        if attempt > 1 and "http_proxy" in env:
            logger.info("Disabling HTTP proxy due to previous failures")
            del env['http_proxy']

        prepare_local_dir(os.path.dirname(transfer.dst_path))

        # let the wget output follow our own log level
        cmd = tools.full_path('wget')
        if logger.isEnabledFor(logging.DEBUG):
            verbosity = " -v"
        else:
            verbosity = " -q"
        cmd = cmd + verbosity
        # NOTE: certificates of https servers are not verified
        cmd = cmd + " --no-check-certificate -O '%s' '%s'" \
                    % (transfer.dst_path, transfer.src_url())

        try:
            myexec(cmd, default_subshell_timeout, True)
            stats_add(transfer.dst_path)
        except RuntimeError:
            logger.error(sys.exc_info()[1])
            return False
        return True


class IRodsHandler(TransferHandlerBase):
    """
    Handler for iRods - http://www.irods.org/
    """

    _name = "IRodsHandler"
    _protocol_map = ["file->irods", "irods->file"]

    def _irods_login(self):
        """
        log in to irods by using the iinit command - if the file named by
        the irodsAuthFileName environment variable already exists, we are
        already logged in

        The password is read from the file named by the irodsEnvFile
        environment variable. Raises RuntimeError if one of the two
        environment variables is missing or no password can be found.
        """
        if "irodsAuthFileName" not in os.environ:
            raise RuntimeError("Missing irodsAuthFileName - unable to do"
                               + " irods transfers")
        if os.path.exists(os.environ['irodsAuthFileName']):
            return

        # read password from env file
        if not "irodsEnvFile" in os.environ:
            raise RuntimeError("Missing irodsEnvFile - unable to do irods "
                               + " transfers")
        password = None
        h = open(os.environ['irodsEnvFile'], 'r')
        try:
            for line in h:
                items = line.split(" ", 2)
                # a line without a value can not provide the password
                if len(items) > 1 and items[0].lower() == "irodspassword":
                    password = items[1].strip(" \t'\"\r\n")
        finally:
            h.close()
        if password is None:
            raise RuntimeError("No irodsPassword specified in irods env file")

        # The password is handed to iinit via a short lived file. Make
        # sure only the owner can read it, and that it is removed even if
        # iinit fails.
        owner_only = stat.S_IRUSR | stat.S_IWUSR
        fd = os.open(".irodsAc", os.O_WRONLY | os.O_CREAT | os.O_TRUNC,
                     owner_only)
        try:
            h = os.fdopen(fd, "w")
            try:
                # the mode given to os.open() is ignored for a file which
                # already existed
                os.chmod(".irodsAc", owner_only)
                h.write(password + "\n")
            finally:
                h.close()

            cmd = "cat .irodsAc | iinit"
            myexec(cmd, 5*60, True)
        finally:
            os.unlink(".irodsAc")

    def do_transfer(self, transfer, attempt):
        """
        irods - use the icommands to interact with irods

        Returns True if the transfer succeeded and False if it failed.
        Raises RuntimeError if iget is not available or the login fails.
        """

        tools = Tools()
        if tools.find("iget", "-h", r"Version[ \t]+([\.0-9a-zA-Z]+)") is None:
            raise RuntimeError("Unable to do irods transfers becuase iget"
                               + " could not be found in the current path")

        # log in to irods
        try:
            self._irods_login()
        except Exception:
            # sys.exc_info() is used to get hold of the exception as it
            # works the same in all Python versions
            logger.error(sys.exc_info()[1])
            raise RuntimeError("Unable to log into irods")

        if transfer.dst_proto == "file":
            # irods->file
            prepare_local_dir(os.path.dirname(transfer.dst_path))
            cmd = "iget -f '" + transfer.src_path + "' '" + transfer.dst_path + "'"
            local_file = transfer.dst_path
        else:
            # file->irods
            cmd = "imkdir -p '" + os.path.dirname(transfer.dst_path) + "'"
            try:
                myexec(cmd, 60*60, True)
            except Exception:
                # ignore errors from the mkdir command - if the directory
                # is really missing the iput below fails
                pass
            cmd = "iput -f '" + transfer.src_path + "' '" +  transfer.dst_path + "'"
            local_file = transfer.src_path

        try:
            myexec(cmd, default_subshell_timeout, True)
            # stats
            stats_add(local_file)
        except Exception:
            logger.error(sys.exc_info()[1])
            return False
        return True



class S3Handler(TransferHandlerBase):
    """
    Handler for S3 and S3 compatible services - transfers are done with
    the pegasus-s3 command
    """

    _name = "S3Handler"
    _protocol_map = [
                    "file->s3",
                    "file->s3s",
                    "s3->file",
                    "s3s->file",
                    "s3->s3",
                    "s3->s3s",
                    "s3s->s3",
                    "s3s->s3s"
                    ]

    buckets_created = {}

    def do_transfer(self, transfer, attempt):
        """
        Runs pegasus-s3 cp, get or put depending on which ends of the
        transfer are S3 URLs.

        Returns True if the command succeeded and False if it failed.
        Raises RuntimeError if pegasus-s3 is not available.
        """
        tools = Tools()
        if tools.find("pegasus-s3", "help", None) is None:
            raise RuntimeError("Unable to do S3 transfers becuase"
                               + " pegasus-s3 could not be found")

        s3_protos = ("s3", "s3s")
        src_is_s3 = transfer.src_proto in s3_protos
        dst_is_s3 = transfer.dst_proto in s3_protos

        # the local end of the transfer, if there is one
        local_filename = None

        # use cp for s3->s3 transfers, and get/put when one end is a file://
        if src_is_s3 and dst_is_s3:
            cmd = "pegasus-s3 cp -f -c '%s' '%s'" \
                  % (transfer.src_url(), transfer.dst_url())
        elif transfer.dst_proto == "file":
            # download from S3
            local_filename = transfer.dst_path
            prepare_local_dir(os.path.dirname(transfer.dst_path))
            cmd = "pegasus-s3 get '%s' '%s'" \
                  % (transfer.src_url(), transfer.dst_path)
        else:
            # upload to S3
            local_filename = transfer.src_path
            cmd = "pegasus-s3 put -f -b '%s' '%s'" \
                  % (transfer.src_path, transfer.dst_url())

        try:
            myexec(cmd, default_subshell_timeout, True)
            if local_filename is not None:
                stats_add(local_filename)
        except Exception:
            logger.error(sys.exc_info()[1])
            return False
        return True


class SRMHandler(TransferHandlerBase):
    """
    Handler for SRM transfers - uses lcg-cp if available, srm-copy
    otherwise
    """

    # this used to be "S3Handler", which made messages built from _name
    # point at the wrong handler
    _name = "SRMHandler"
    _protocol_map = ["srm->file", "file->srm", "gsiftp->srm", "srm->gsiftp"]

    def do_transfer(self, transfer, failed_q):
        """
        srm - lcg-cp is the preferred client, srm-copy the backup one
              Is this generic enough? Do we need to handle space tokens?

        Returns True if the transfer succeeded and False if it failed.
        Raises RuntimeError if neither lcg-cp nor srm-copy is available.

        NOTE(review): the base class calls the second parameter
        "attempt". It is not used here, and the name is kept as is so
        that existing callers are not affected.
        """

        tools = Tools()
        if tools.find("lcg-cp",
                      "--version",
                      r"lcg_util-([\.0-9a-zA-Z]+)") is None \
           and \
           tools.find("srm-copy",
                      "-version",
                      r"srm-copy[ \t]+([\.0-9a-zA-Z]+)") is None:
            raise RuntimeError("Unable to do srm transfers because"
                               + " lcg-cp/srm-copy could not be found")

        if transfer.dst_proto == "file":
            prepare_local_dir(os.path.dirname(transfer.dst_path))

        # both ends are remote servers
        remote_protos = ("gsiftp", "srm")
        third_party = transfer.src_proto in remote_protos \
                      and transfer.dst_proto in remote_protos

        debug = logger.isEnabledFor(logging.DEBUG)

        # prefer lcg-cp
        if tools.full_path('lcg-cp') is not None:
            cmd = "lcg-cp"
            if debug:
                cmd = cmd + " -v"
            cmd = cmd + " -b -D srmv2 '%s' '%s'" \
                  % (transfer.src_url_srm(), transfer.dst_url_srm())
        else:
            cmd = "srm-copy '%s' '%s' -mkdir" \
                  % (transfer.src_url_srm(), transfer.dst_url_srm())
            if third_party:
                cmd = cmd + " -parallelism 4 -3partycopy"
            if not debug:
                cmd = cmd + " >/dev/null"

        try:
            myexec(cmd, 6*60*60, True)
        except Exception:
            # sys.exc_info() is used to get hold of the exception as it
            # works the same in all Python versions
            logger.error(sys.exc_info()[1])
            return False
        return True
                

class ScpHandler(TransferHandlerBase):
    """
    Uses scp to copy to/from remote hosts
    """

    _name = "ScpHandler"
    _protocol_map = ["scp->file", "file->scp"]

    def do_transfer(self, transfer, attempt):
        """
        Copies a file from or to a remote host with /usr/bin/scp. If the
        SSH_PRIVATE_KEY environment variable is set, it is passed on as
        the identity file.

        Returns True if the transfer succeeded and False if preparing the
        remote directory or the copy itself failed.
        """
        global remote_dirs_created
        cmd = "/usr/bin/scp"
        if "SSH_PRIVATE_KEY" in os.environ:
            cmd += " -i " + os.environ['SSH_PRIVATE_KEY']
        cmd += " -q -B -o StrictHostKeyChecking=no"
        try:
            if transfer.dst_proto == "file":
                # scp->file - quoted the same way as the upload case
                # below, so that the same kind of paths work in both
                # directions
                prepare_local_dir(os.path.dirname(transfer.dst_path))
                cmd += " '" + transfer.src_host + ":" + transfer.src_path + "'"
                cmd += " '" + transfer.dst_path + "'"
                local_file = transfer.dst_path
            else:
                # file->scp - create the remote directory once per
                # host/directory
                mkdir_key = "scp://" + transfer.dst_host + ":" \
                          + os.path.dirname(transfer.dst_path)
                if not mkdir_key in remote_dirs_created:
                    self._prepare_scp_dir(transfer.dst_host,
                                         os.path.dirname(transfer.dst_path))
                    remote_dirs_created[mkdir_key] = True
                cmd += " '" + transfer.src_path + "'"
                cmd += " '" + transfer.dst_host + ":" + transfer.dst_path + "'"
                local_file = transfer.src_path

            myexec(cmd, default_subshell_timeout, True)
            # only count the file once it has actually been transferred
            stats_add(local_file)

        except RuntimeError:
            # sys.exc_info() is used to get hold of the exception as it
            # works the same in all Python versions
            logger.error(sys.exc_info()[1])
            return False
        return True


    def _prepare_scp_dir(self, rhost, rdir):
        """
        makes sure a remote path exists before putting files into it, by
        running mkdir -p on the remote host over ssh
        """
        cmd = "/usr/bin/ssh"
        if "SSH_PRIVATE_KEY" in os.environ:
            cmd += " -i " + os.environ['SSH_PRIVATE_KEY']
        cmd += " -q -o StrictHostKeyChecking=no"
        cmd += " " + rhost + " '/bin/mkdir -p " + rdir + "'"
        myexec(cmd, default_subshell_timeout, True)


class SymlinkHandler(TransferHandlerBase):
    """
    Sets up symlinks - this is often used when data is local, but needs a
    reference in cwd
    """
    
    _name = "SymlinkHandler"
    _protocol_map = ["file->symlink", "symlink->symlink"]

    def do_transfer(self, transfer, attempt):

        prepare_local_dir(os.path.dirname(transfer.dst_path))

        # we do not allow dangling symlinks
        if not os.path.exists(transfer.src_path):
            logger.warning("Symlink source (%s) does not exist"
                           % (transfer.src_path))
            failed_q.append(transfer)
            return True

        if os.path.exists(transfer.src_path) \
           and os.path.exists(transfer.dst_path):
            # make sure src and target are not the same file - have to
            # compare at the inode level as paths can differ
            src_inode = os.stat(transfer.src_path)[stat.ST_INO]
            dst_inode = os.stat(transfer.dst_path)[stat.ST_INO]
            if src_inode == dst_inode:
                logger.warning("symlink: src (%s) and dst (%s) already exists"
                               % (transfer.src_path, transfer.dst_path))
                return True

        cmd = "ln -f -s '%s' '%s'" % (transfer.src_path, transfer.dst_path)
        try:
            myexec(cmd, 60, True)
        except RuntimeError, err:
            logger.error(err)
            return False
        return True


class TransferSet:
    """
    A transfer set is a set of similar transfers, similar in the sense
    that all the transfers have the same source and destination protocols
    """

    _transfers = None
    _available_handlers = []
    _primary_handler = None
    _secondary_handler = None
    _tmp_file = None

    def __init__(self, transfers_l):

        self._transfers = transfers_l

        # load all the handlers - does the order matter?
        self._available_handlers.append( CpHandler() )
        self._available_handlers.append( FDTHandler() )
        self._available_handlers.append( GridFtpHandler() )
        self._available_handlers.append( HTTPHandler() )
        self._available_handlers.append( IRodsHandler() )
        self._available_handlers.append( S3Handler() )
        self._available_handlers.append( SRMHandler() )
        self._available_handlers.append( ScpHandler() )
        self._available_handlers.append( SymlinkHandler() )

        src_proto = transfers_l[0].src_proto
        dst_proto = transfers_l[0].dst_proto

        # can we find one handler which can handle both source
        # and destination protocols directly?
        for h in self._available_handlers:
            if h.protocol_map_check(src_proto, dst_proto):
                self._primary_handler = h
                logger.debug("Selected %s for handling these transfers" 
                             %(h._name))
                return

        # we need to split the transfer from src to local file,
        # and then transfer the local file to the dst
        for h in self._available_handlers:
            if h.protocol_map_check(src_proto, "file"):
                self._primary_handler = h
                break
        for h in self._available_handlers:
            if h.protocol_map_check("file", dst_proto):
                self._secondary_handler = h
                break
        if self._primary_handler == None or self._secondary_handler == None:
            raise RuntimeError("Unable to find handlers for '%s' to '%s'"
                               %(src_proto, dst_proto))

        logger.debug("Selected %s and %s for handling these transfers"
                     %(self._primary_handler._name,
                       self._secondary_handler._name))


    def do_transfers(self, failed_q, attempt):
        """
        given a list of transfers, figure out what handlers are needed
        and then execute the transfers
        """

        self._tmp_name = None
        if self._secondary_handler != None:
            # we have a two stage transfer to deal with and we need a temp file
            self._tmp_fd, self._tmp_name = \
                tempfile.mkstemp(prefix="pegasus-transfer-two-stage-",
                                 suffix=".data")
            # need to open the permission up to make sure files downstream
            # get sane permissions to inherit
            os.chmod(self._tmp_name, 0644)
            logger.debug("Using temporary file %s for transfers" 
                         %(self._tmp_name))

        # gsiftp multi transfer case
        if (self._secondary_handler == None) and \
           (self._transfers[0].src_proto == "gsiftp" \
            or self._transfers[0].dst_proto == "gsiftp"):
            try:
                success =  self._primary_handler.do_multi_transfer(
                                                    self._transfers,
                                                    attempt,
                                                    failed_q)
            except Exception, e:
                if logger.isEnabledFor(logging.DEBUG):
                    logger.exception("Exception while doing transfer:")
                else:
                    logger.error(e)
            return
        
        # standard src->dst single transfer case
        for i, transfer in enumerate(self._transfers):
            
            # We are being extra careful to detect failures here. We are 
            # considering both a False being returned or an exception being
            # thrown as a failed transfer
            success = False
            if self._secondary_handler == None:
                # one handler to rule them all!
                try:
                    success = self._primary_handler.do_transfer(transfer,
                                                                attempt)
                except Exception, e:
                    if logger.isEnabledFor(logging.DEBUG):
                        logger.exception("Exception while doing transfer:")
                    else:
                        logger.error(e)
            else:
                # break up the transfer into two, but keep a handle to the main
                # transfer as that is the one which will have to go back to the
                # failed queue in case of failure
                t_one = Transfer(transfer.pair_id)
                t_one.set_src(transfer.src_url())
                t_one.set_dst("file://" + self._tmp_name)
                t_two =  Transfer(transfer.pair_id)
                t_two.set_src("file://" + self._tmp_name)
                t_two.set_dst(transfer.dst_url())
                try:
                    success = self._primary_handler.do_transfer(t_one,
                                                                attempt) \
                              and \
                              self._secondary_handler.do_transfer(t_two,
                                                                  attempt)
                except Exception, e:
                    if logger.isEnabledFor(logging.DEBUG):
                        logger.exception("Exception while doing transfer:")
                    else:
                        logger.error(e)

            if success == False:
                failed_q.append(transfer)
        
        # remove temp file
        if self._tmp_name is not None:
            logger.debug("Removing temporary file %s" %(self._tmp_name))
            try:
                os.unlink(self._tmp_name)
            except:
                pass


class Alarm(Exception):
    """
    Exception used to signal that an alarm timer has gone off
    """


# --- global variables ----------------------------------------------------------------

# directory and file name this program was started as
prog_dir  = os.path.normpath(os.path.dirname(sys.argv[0]))
prog_base = os.path.basename(sys.argv[0])

logger = logging.getLogger("my_logger")

# timeout (in seconds) for when shelling out - 6 hours
default_subshell_timeout = 6 * 60 * 60

# remote directories which have already been created, so that we
# do not have to try to create them over and over again
remote_dirs_created = {}

# gsiftp failure count - used to provide sane globus-url-copy options
gsiftp_failures = 0

# stats - start/end timestamps and the number of bytes accounted for
stats_start = stats_end = stats_total_bytes = 0


# --- functions ----------------------------------------------------------------


def setup_logger(debug_flag):
    """
    Attaches a console handler to the module level logger. Both the
    logger and the handler are set to DEBUG if debug_flag is set, and to
    INFO otherwise.
    """
    if debug_flag:
        level = logging.DEBUG
    else:
        level = logging.INFO

    # log to the console, with the level matching the logger
    console = logging.StreamHandler()
    console.setLevel(level)
    console.setFormatter(
        logging.Formatter("%(asctime)s %(levelname)7s:  %(message)s"))

    logger.setLevel(level)
    logger.addHandler(console)
    logger.debug("Logger has been configured")


def prog_sigint_handler(signum, frame):
    """
    signal handler - logs the signal number and exits with code 1
    """
    message = "Exiting due to signal %d" % (signum)
    logger.warning(message)
    myexit(1)


def alarm_handler(signum, frame):
    """
    signal handler - turns the signal into an Alarm exception
    """
    raise Alarm()


def expand_env_vars(s):
    re_env_var = re.compile(r'\${?([a-zA-Z][a-zA-Z0-9_]+)}?')
    s = re.sub(re_env_var, get_env_var, s)
    return s


def get_env_var(match):
    """
    Given a regular expression match with the variable name in group 1,
    returns the value of that environment variable, or an empty string
    if the variable is not set
    """
    name = match.group(1)
    logger.debug("Looking up " + name)
    return os.environ.get(name, "")


def myexec(cmd_line, timeout_secs, should_log):
    """
    executes shell commands with the ability to time out if the command hangs

    cmd_line is run through the shell. It is logged at info level if
    should_log is set or if debug logging is enabled.

    Raises RuntimeError if the command does not finish within
    timeout_secs seconds, or if it exits with a non-zero code.
    """
    if should_log or logger.isEnabledFor(logging.DEBUG):
        logger.info(cmd_line)
    sys.stdout.flush()

    # set up signal handler for timeout
    signal.signal(signal.SIGALRM, alarm_handler)
    signal.alarm(timeout_secs)

    p = None
    try:
        try:
            p = subprocess.Popen(cmd_line, shell=True)
            p.communicate()
        except Alarm:
            # Popen.terminate() does not exist before Python 2.6
            if p is not None and sys.version_info >= (2, 6):
                p.terminate()
            raise RuntimeError("Command '%s' timed out after %s seconds"
                               % (cmd_line, timeout_secs))
    finally:
        # always cancel the pending alarm - otherwise it goes off later,
        # in whatever unrelated code happens to be running at that time
        signal.alarm(0)

    rc = p.returncode
    if rc != 0:
        raise RuntimeError("Command '%s' failed with error code %s"
                           % (cmd_line, rc))


def backticks(cmd_line):
    """
    runs cmd_line through the shell and returns what the command wrote to
    stdout - just like backticks in perl
    """
    proc = subprocess.Popen(cmd_line, shell=True, stdout=subprocess.PIPE)
    captured = proc.communicate()
    return captured[0]


def _prepend_unique(entries, entry):
    """
    puts entry first in the entries list (in place), removing the first
    already existing occurrence of it
    """
    try:
        entries.remove(entry)
    except ValueError:
        pass
    entries.insert(0, entry)


def env_setup():
    """
    Sets up the environment for the tools we shell out to: updates PATH,
    LD_LIBRARY_PATH and DYLD_LIBRARY_PATH (taking PEGASUS_HOME and
    GLOBUS_LOCATION into account) and points irodsAuthFileName to a
    .irodsA file in the current working directory.
    """
    
    # PATH setup
    path = "/usr/bin:/bin"
    if "PATH" in os.environ:
        path = os.environ['PATH']
    path_entries = path.split(':')
    
    # is /usr/bin in the path?
    if not("/usr/bin" in path_entries):
        path_entries.append("/usr/bin")
        path_entries.append("/bin")

    # fink on macos x
    if os.path.exists("/sw/bin") and not("/sw/bin" in path_entries):
        path_entries.append("/sw/bin")
       
    # need LD_LIBRARY_PATH for Globus tools. Do not split an unset/empty
    # value: that gives one empty entry, which ends up as a trailing ':'
    # (i.e. the current directory) once something is prepended below
    ld_library_path_entries = []
    if os.environ.get('LD_LIBRARY_PATH'):
        ld_library_path_entries = os.environ['LD_LIBRARY_PATH'].split(':')
    
    # if PEGASUS_HOME is set, prepend it to the PATH (we want it early to
    # override other cruft)
    if "PEGASUS_HOME" in os.environ:
        _prepend_unique(path_entries, os.environ['PEGASUS_HOME'] + "/bin")
    
    # if GLOBUS_LOCATION is set, prepend it to the PATH and LD_LIBRARY_PATH 
    # (we want it early to override other cruft)
    if "GLOBUS_LOCATION" in os.environ:
        _prepend_unique(path_entries, os.environ['GLOBUS_LOCATION'] + "/bin")
        _prepend_unique(ld_library_path_entries,
                        os.environ['GLOBUS_LOCATION'] + "/lib")

    os.environ['PATH'] = ":".join(path_entries)
    os.environ['LD_LIBRARY_PATH'] = ":".join(ld_library_path_entries)
    os.environ['DYLD_LIBRARY_PATH'] = ":".join(ld_library_path_entries)
    logger.info("PATH=" + os.environ['PATH'])
    logger.info("LD_LIBRARY_PATH=" + os.environ['LD_LIBRARY_PATH'])
    
    # irods requires a password hash file
    os.environ['irodsAuthFileName'] = os.getcwd() + "/.irodsA"



def prepare_local_dir(path):
    """
    makes sure a local path exists before putting files into it
    """
    if not(os.path.exists(path)):
        logger.debug("Creating local directory " + path)
        try:
            os.makedirs(path, 0755)
        except os.error, err:
            # if dir already exists, ignore the error
            if not(os.path.isdir(path)):
                raise RuntimeError(err)


def transfers_groupable(a, b):
    """
    compares two transfers, and determines if they are similar enough to
    be grouped together for one tool: both have to be groupable, and the
    source and destination protocols have to match
    """
    if not (a.groupable() and b.groupable()):
        return False
    return (a.src_proto, a.dst_proto) == (b.src_proto, b.dst_proto)


def stats_add(filename):
    """
    adds the size of the given local file to the global byte counter.
    Any error (for example the file not existing) is ignored.
    """
    global stats_total_bytes
    try:
        stats_total_bytes += os.path.getsize(filename)
    except Exception:
        pass # ignore


def stats_summarize():
    """
    logs a summary of the number of bytes transferred, the elapsed time
    (stats_end - stats_start) and the resulting transfer rate
    """
    if stats_total_bytes == 0:
        logger.info("Stats: no local files in the transfer set")
        return

    total_secs = stats_end - stats_start

    # guard against a zero elapsed time, and make sure we never end up
    # with an integer division
    if total_secs > 0:
        Bps = stats_total_bytes / float(total_secs)
    else:
        Bps = 0.0

    logger.info("Stats: %sB transferred in %.0f seconds. Rate: %sB/s (%sb/s)" \
                % (iso_prefix_formatted(stats_total_bytes), total_secs, 
                   iso_prefix_formatted(Bps), iso_prefix_formatted(Bps*8)))
    logger.info("NOTE: stats do not include third party gsiftp/srm transfers")


def iso_prefix_formatted(n):
    """
    Formats a number with one decimal and a K/M/G/T prefix, using 1024
    as the base. The number and the prefix are separated by a space, and
    the space is kept when no prefix applies:

        0 -> "0.0 ", 1024 -> "1.0 K", 1536 -> "1.5 K"
    """
    n = float(n)
    # largest unit first - a value exactly on a boundary belongs to the
    # larger unit (1024 is "1.0 K", not "1024.0 ")
    for prefix, factor in (("T", 1024.0 ** 4),
                           ("G", 1024.0 ** 3),
                           ("M", 1024.0 ** 2),
                           ("K", 1024.0)):
        if n >= factor:
            return "%.1f %s" % (n / factor, prefix)
    return "%.1f " % (n)


def myexit(rc):
    """
    exits the program with the given exit code, by raising SystemExit
    """
    sys.exit(rc)


# --- main ----------------------------------------------------------------------------

def main():
    global stats_start
    global stats_end
    
    # dup stderr onto stdout
    sys.stderr = sys.stdout
    
    # Configure command line option parser
    prog_usage = "usage: %s [options]" % (prog_base)
    parser = optparse.OptionParser(usage=prog_usage)
    
    parser.add_option("-f", "--file", action = "store", dest = "file",
                      help = "File containing URL pairs to be transferred." +
                             " If not given, list is read from stdin.")
    parser.add_option("", "--max-attempts", action = "store", type="int",
                      dest = "max_attempts", default = 3,
                      help = "Number of attempts allowed for each transfer." +
                             " Default is 3.")
    parser.add_option("-d", "--debug", action = "store_true", dest = "debug",
                      help = "Enables debugging ouput.")
    
    # Parse command line options
    (options, args) = parser.parse_args()
    setup_logger(options.debug)
    
    # Die nicely when asked to (Ctrl+C, system shutdown)
    signal.signal(signal.SIGINT, prog_sigint_handler)
    
    attempts_max = options.max_attempts
    
    # stdin or file input?
    if options.file == None:
        logger.info("Reading URL pairs from stdin")
        input_file = sys.stdin
    else:
        logger.info("Reading URL pairs from %s" % (options.file))
        try:
            input_file = open(options.file, 'r')
        except Exception, err:
            logger.critical('Error reading url pair list: %s' % (err))
            myexit(1)
    
    # check environment
    try:
        env_setup()
    except Exception, err:
        logger.critical(err)
        myexit(1)
    
    # queues to track the work
    transfer_q = deque()
    failed_q = deque()
    
    # fill the transfer queue with user provided entries
    line_nr = 0
    pair_nr = 0
    inputs = []
    url_first = True
    try:
        for line in input_file.readlines():
            line_nr += 1
            if line[0] != '#' and len(line) > 4:
                line = line.rstrip('\n')
                if url_first:
                    pair_nr += 1
                    url_pair = Transfer(pair_nr)
                    url_pair.set_src(line)
                    url_first = False
                else:
                    url_pair.set_dst(line)
                    inputs.append(url_pair)
                    url_first = True
    except Exception, err:
        logger.critical('Error reading url pair list: %s' % (err))
        myexit(1)
    
    # we will now sort the list as some tools (gridftp) can optimize when
    # given a group of similar transfers
    logger.info("Sorting the tranfers based on transfer type and source/destination")
    inputs.sort()
    
    transfer_q = deque(inputs)
    
    # start the stats time
    stats_start = time.time()
    
    # attempt transfers until the queue is empty
    done = False
    attempt_current = 0
    while not done:
    
        attempt_current = attempt_current + 1
        logger.info('-' * 80)
        logger.info("Starting transfers - attempt %d" % (attempt_current))
    
        # do the transfers
        while transfer_q:
            
            t_main = transfer_q.popleft()
            
            # create a list of transfers to pass to underlying tool
            t_list = []
            t_list.append(t_main)
    
            try:
                t_next = transfer_q[0]
            except IndexError, err:
                t_next = False
            while t_next and transfers_groupable(t_main, t_next):
                t_list.append(t_next)
                transfer_q.popleft()
                try:
                    t_next = transfer_q[0]
                except IndexError, err:
                    t_next = False
    
            # magic!
            ts = TransferSet(t_list)
            ts.do_transfers(failed_q, attempt_current)
    
            logger.debug("%d items in failed_q" %(len(failed_q)))
        
        # are we done?
        if attempt_current == attempts_max or not failed_q:
            done = True
            break
        
        # retry failed transfers with a delay
        if failed_q and attempt_current < attempts_max:
            time.sleep(10) # do not sleep too long - we want to give quick
                            # feedback on failures to the workflow
        while failed_q:
            t = failed_q.popleft()
            t.allow_grouping = False # only allow grouping on the first try
            transfer_q.append(t)
    
    logger.info('-' * 80)
    
    # end the stats timer and show summary
    stats_end = time.time()
    stats_summarize()
    
    if failed_q:
        logger.critical("Some transfers failed! See above, and possibly stderr.")
        myexit(1)
    
    logger.info("All transfers completed successfully.")
    
    myexit(0)


# only run the transfers when executed as a program, not when imported
if __name__ == "__main__":
    main()
    

