#!/usr/bin/env python
import sys
import os
import xml.sax.handler
import xml.sax
from optparse import OptionParser

# Fill colors (hex RGB strings) for job nodes. emit_dot() hands out one
# entry per distinct transformation name, in order of first appearance;
# once the list is exhausted every further transformation shares the
# last entry.
COLORS = [
    "#1b9e77",
    "#d95f02",
    "#7570b3",
    "#e7298a",
    "#66a61e",
    "#e6ab02",
    "#a6761d",
    "#666666",
    "#8dd3c7",
    "#bebada",
    "#fb8072",
    "#80b1d3",
    "#fdb462",
    "#b3de69",
    "#fccde5",
    "#d9d9d9",
    "#bc80bd",
    "#ccebc5",
    "#ffed6f",
    "#ffffb3"
]

class DAG:
    """
    Container for a parsed workflow.

    jobs:  dict mapping job id -> Job
    files: dict mapping file name -> File
    """
    def __init__(self):
        # Both tables start out empty; the parsers fill them in.
        self.jobs, self.files = dict(), dict()

class Node:
    """
    Base class for anything that can appear as a vertex in the diagram.

    id:       unique identifier of the node (None until assigned)
    label:    text shown for the node (None until assigned)
    level:    depth of the node in the workflow, starts at 0
    parents:  list of nodes this node depends on
    children: list of nodes that depend on this node
    """
    def __init__(self):
        self.id = self.label = None
        self.level = 0
        # Two separate lists; they must never alias each other.
        self.parents, self.children = [], []

class Job(Node):
    """
    A job (or sub-workflow) vertex of the DAG.

    xform:   name of the transformation the job runs (None until assigned)
    inputs:  list of File objects the job reads
    outputs: list of File objects the job writes
    """
    def __init__(self):
        Node.__init__(self)
        self.xform = None
        # Two separate lists; they must never alias each other.
        self.inputs, self.outputs = [], []

class File(Node):
    """
    A file vertex of the DAG. It carries no state beyond what Node
    provides (id, label, level, parents, children).
    """
    def __init__(self):
        Node.__init__(self)

class DAXHandler(xml.sax.handler.ContentHandler):
    """
    SAX content handler that builds a DAG from a Pegasus DAX document.

    <job>, <dax> and <dag> elements become Job objects, <uses> elements
    nested inside them become File objects attached to the job's inputs
    and/or outputs, and <child>/<parent> elements become edges between
    jobs. After parsing, the result is available as self.dag.
    """
    # Elements that describe a vertex (job or sub-workflow) of the DAG
    _JOB_ELEMENTS = ("job", "dax", "dag")

    def __init__(self):
        xml.sax.handler.ContentHandler.__init__(self)
        self.last_job = None    # Job whose element is currently open
        self.lastchild = None   # 'ref' of the <child> currently open
        self.dag = DAG()
        self.jobs = self.dag.jobs
        self.files = self.dag.files

    def startElement(self, name, attrs):
        """
        Dispatch an opening tag to the matching handler. Elements that
        are not relevant to the graph are ignored.
        """
        if name in self._JOB_ELEMENTS:
            self._start_job(name, attrs)
        elif name == "uses":
            self._start_uses(attrs)
        elif name == "child":
            self.lastchild = attrs.get("ref")
        elif name == "parent":
            self._start_parent(attrs)

    def endElement(self, name):
        """
        Leave the scope opened by <child> or by a job element so that
        later <parent>/<uses> elements are not attributed to it.
        """
        if name == "child":
            self.lastchild = None
        elif name in self._JOB_ELEMENTS:
            self.last_job = None

    def _start_job(self, name, attrs):
        """Create a Job from a <job>, <dax> or <dag> element."""
        job = Job()

        job.id = attrs.get("id")
        if job.id is None:
            raise Exception("Invalid DAX: attribute 'id' missing")

        if name == "job":
            job.xform = attrs.get("name")
            if job.xform is None:
                raise Exception("Invalid DAX: job name missing for job %s" % job.id)
            # Fully qualified transformation name: namespace::name:version
            ns = attrs.get("namespace")
            version = attrs.get("version")
            if ns is not None:
                job.xform = ns + "::" + job.xform
            if version is not None:
                job.xform = job.xform + ":" + version
        elif name == "dax":
            job.xform = "pegasus::dax"
        else:
            job.xform = "pegasus::dag"

        # Label preference: node-label, then file, then transformation
        job.label = attrs.get("node-label")
        if job.label is None:
            job.label = attrs.get("file")
            if job.label is None:
                job.label = job.xform

        self.jobs[job.id] = job
        self.last_job = job

    def _start_uses(self, attrs):
        """Attach the file named by a <uses> element to the open job."""
        if self.last_job is None:
            # A <uses> that is not nested in a job element says nothing
            # about job data flow, so it is skipped instead of crashing
            # on a missing job or being attached to the previous job.
            # NOTE(review): this is believed to be the case for <uses>
            # under <transformation> in DAX 3.x - confirm against schema.
            return

        namespace = attrs.get("namespace")
        version = attrs.get("version")
        filename = attrs.get("name") or attrs.get("file")

        if filename is None:
            raise Exception("name attribute missing on <uses>")

        if namespace is not None:
            filename = namespace + "::" + filename

        if version is not None:
            filename = filename + ":" + version

        # Files are shared between jobs, so reuse an existing File
        if filename in self.files:
            f = self.files[filename]
        else:
            f = File()
            f.id = f.label = filename
            self.files[filename] = f

        # A missing link attribute is treated as an input
        link = (attrs.get("link") or "input").lower()
        if link == "input":
            self.last_job.inputs.append(f)
        elif link == "output":
            self.last_job.outputs.append(f)
        elif link == "inout":
            self.last_job.inputs.append(f)
            self.last_job.outputs.append(f)
        elif link == "none":
            pass
        else:
            raise Exception("Unrecognized link value: %s" % link)

    def _start_parent(self, attrs):
        """Add an edge from the referenced parent to the open <child>."""
        if self.lastchild is None:
            raise Exception("Invalid DAX: <parent> outside <child>")
        pid = attrs.get("ref")
        # Both lookups raise KeyError if the job has not been declared
        child = self.jobs[self.lastchild]
        parent = self.jobs[pid]
        child.parents.append(parent)
        parent.children.append(child)

def parse_daxfile(fname):
    """
    Parse DAG from a Pegasus DAX file.

    fname: path of the DAX (XML) file to read
    Returns the DAG built by DAXHandler. Errors raised by the XML
    parser or by the handler propagate to the caller.
    """
    handler = DAXHandler()
    parser = xml.sax.make_parser()
    parser.setContentHandler(handler)
    f = open(fname,"r")
    try:
        parser.parse(f)
    finally:
        # Close the file even when the document is rejected
        f.close()
    return handler.dag
    
def parse_xform_name(path):
    """
    Parse the transformation name from a submit script. Usually the
    transformation is in a special classad called '+pegasus_wf_xformation'.
    For special pegasus jobs (create_dir, etc.) set the name manually.

    path: path of the submit script; it does not have to exist
    Returns the transformation name as a string. If neither the file
    name prefix nor the submit script identifies it, the name is guessed
    by removing ".sub" and all digits from the file name.
    """
    fname = os.path.basename(path)

    # Handle special cases: auxiliary jobs are recognized by file name
    special_prefixes = (
        ("create_dir_", "pegasus::create_dir"),
        ("stage_in_", "pegasus::stage_in"),
        ("stage_out_", "pegasus::stage_out"),
        ("stage_inter_", "pegasus::stage_inter"),
        ("stage_worker_", "pegasus::stage_worker"),
        ("register_", "pegasus::register"),
        ("clean_up_", "pegasus::clean_up"),
    )
    for prefix, xform in special_prefixes:
        if fname.startswith(prefix):
            return xform

    # Get it from the submit file
    if os.path.isfile(path):
        f = open(path,'r')
        try:
            for line in f:
                if '+pegasus_wf_xformation' in line:
                    # The value is the first double-quoted string on
                    # the line; a line without quotes is skipped so the
                    # fallback below still gets a chance.
                    parts = line.split('"')
                    if len(parts) > 1:
                        return parts[1]
        finally:
            f.close()

    # Otherwise, guess the xform by stripping digits from the name
    name = fname.replace(".sub","")
    return "".join(c for c in name if not '0' <= c <= '9')
    
def parse_dagfile(fname):
    """
    Parse a DAG from a dagfile.

    Only two kinds of lines are interpreted, everything else is ignored:
      JOB <id> <submit file> ...
      PARENT <id> [<id> ...] CHILD <id> [<id> ...]
    Every listed parent is linked to every listed child.

    fname: path of the DAGMan file; relative submit file paths are
           resolved against its directory
    Returns the DAG. Raises Exception for a malformed JOB/PARENT line
    and KeyError if a PARENT line names a job that has not been defined.
    """
    dagdir = os.path.dirname(fname)
    dag = DAG()
    jobs = dag.jobs
    f = open(fname,'r')
    try:
        for line in f:
            rec = line.split()
            if not rec:
                continue
            # Compare the whole keyword so that other commands that
            # merely start with "JOB" (e.g. JOBSTATE_LOG) are ignored.
            keyword = rec[0]
            if keyword == "JOB":
                if len(rec) < 3:
                    raise Exception("Invalid line: %s" % line.strip())
                job = Job()
                job.id = rec[1] # Job id
                subfile = rec[2] # submit script
                if not os.path.isabs(subfile):
                    subfile = os.path.join(dagdir,subfile)
                job.xform = parse_xform_name(subfile)
                job.label = job.id
                jobs[job.id] = job
            elif keyword == "PARENT":
                if len(rec) < 4 or "CHILD" not in rec:
                    raise Exception("Invalid line: %s" % line.strip())
                sep = rec.index("CHILD")
                parent_ids = rec[1:sep]
                child_ids = rec[sep+1:]
                if not parent_ids or not child_ids:
                    raise Exception("Invalid line: %s" % line.strip())
                for pid in parent_ids:
                    p = jobs[pid]
                    for cid in child_ids:
                        c = jobs[cid]
                        p.children.append(c)
                        c.parents.append(p)
    finally:
        # Close the file even when a line is rejected
        f.close()

    return dag

def remove_xforms(dag, xforms):
    """
    Remove transformations in the DAG by name

    dag:    the DAG to modify in place
    xforms: collection of transformation names; every job whose xform
            is in it is deleted from dag.jobs and unlinked from its
            parents and children

    The parents and children of a removed job are not connected to
    each other. Progress is reported on stderr so that it cannot end
    up inside a diagram that is being written to stdout.
    """
    if not xforms:
        return
    jobs = dag.jobs
    # Iterate over a snapshot of the keys because entries are deleted
    for job_id in list(jobs):
        job = jobs[job_id]
        if job.xform not in xforms:
            continue
        sys.stderr.write("Removing %s\n" % job.id)
        for p in job.parents:
            p.children.remove(job)
        for c in job.children:
            c.parents.remove(job)
        del jobs[job_id]
            
def inv_reachable(a, b):
    """
    Is a reachable from b using reverse edges? Reverse edges are
    used because it is a little more efficient than forward edges
    assuming that a node is more likely to have children than
    parents. Walks the child->parent edges instead of the
    parent->child edges.

    a: node to start from
    b: node to look for among the ancestors of a
    Returns True if b is a proper ancestor of a, otherwise False (a is
    not considered reachable from itself unless it is its own ancestor).
    """
    pending = [a]
    # Each ancestor is expanded only once. Without this, a node that
    # can be reached over many different paths is walked once per path.
    seen = set()
    while pending:
        n = pending.pop()
        for p in n.parents:
            if p == b:
                return True
            if p not in seen:
                seen.add(p)
                pending.append(p)
    return False
            
def simplify(dag):
    """
    Simplify a DAG by removing redundant edges. Redundant edges are edges
    that go from a grandparent to a grandchild. In other words, they are
    edges that, if removed, do not change the dependencies in the workflow.
    We want to remove these because they clutter up the diagram and make
    it hard to read.

    dag: the DAG to simplify; it is modified in place (the level of
         every job is set and redundant edges are dropped from the
         parents/children lists)
    Returns the same dag object. Every removed edge is reported on
    stderr. The input must be acyclic.
    """
    # Find roots
    roots = [j for j in dag.jobs.values() if len(j.parents) == 0]

    # Assign surrogate root. It only points down to the real roots;
    # they do not list it as a parent.
    root = Job()
    root.level = 0
    root.children.extend(roots)

    # Label every node with the length of the longest path from the
    # surrogate root. A node is expanded the first time it is met and
    # again only when its level grows, which is enough to propagate the
    # new level to its descendants.
    pending = [root]
    seen = set()
    while pending:
        n = pending.pop()
        for c in n.children:
            grew = c.level < n.level + 1
            if grew:
                c.level = n.level + 1
            if grew or c not in seen:
                seen.add(c)
                pending.append(c)

    # Eliminate any redundant edges. The outgoing edges of each node
    # need to be examined only once because removing a redundant edge
    # does not change which nodes can reach which.
    pending = [root]
    visited = set()
    while pending:
        n = pending.pop()
        # Iterate over a copy: n.children is modified inside the loop
        for c in n.children[:]:
            if c not in visited:
                visited.add(c)
                pending.append(c)
            # An edge spanning a single level lies on a longest path
            # and can never be redundant.
            if c.level - n.level > 1:
                # Take the edge out and see whether c still depends on
                # n through some other path.
                c.parents.remove(n)
                if inv_reachable(c, n):
                    sys.stderr.write(
                        "Removing redunant edge: %s -> %s\n" %
                        (n.id, c.id))
                    n.children.remove(c)
                else:
                    c.parents.append(n)

    return dag
    
def emit_dot(dag, use_xforms=False, outfile="/dev/stdout", width=8.0, height=10.0, files=False):
    """
    Write a DOT-formatted diagram.
    dag: The DAG to draw
    use_xforms: Use transformation names instead of job names
    outfile: The file name to write the diagram out to.
    width: The width of the diagram
    height: The height of the diagram
    files: Should files be used as edges? True/False

    With files=True, file nodes are emitted and edges run between jobs
    and the files they read/write; otherwise edges run from each job to
    its children.
    """
    xforms = {} # Keep track of transformation names to assign colors

    out = open(outfile,'w')
    try:
        out.write("""digraph dag {
    size="%s,%s"
    ratio=fill
    node [style=filled,color="#444444",fillcolor="#ffed6f"]
    edge [arrowhead=normal,arrowsize=1.0]
    \n""" % (width,height))

        for j in dag.jobs.values():
            if use_xforms:
                label = j.xform
            else:
                label = j.label
            if j.xform not in xforms:
                # Colors are handed out in order of first appearance.
                # Just in case we run out of colors, all remaining
                # transformations share the last one.
                xforms[j.xform] = min(len(xforms), len(COLORS)-1)
            color = xforms[j.xform]
            out.write('\t"%s" [shape=ellipse,fillcolor="%s",label="%s"]\n' % (j.id,COLORS[color],label))

        out.write('\n')

        if files:
            for f in dag.files.values():
                out.write('\t"%s" [shape=rect,label="%s"]\n' % (f.id,f.label))

        out.write('\n')

        for j in dag.jobs.values():
            if files:
                for i in j.inputs:
                    out.write('\t"%s" -> "%s"\n' % (i.id,j.id))
                for o in j.outputs:
                    out.write('\t"%s" -> "%s"\n' % (j.id,o.id))
            else:
                for c in j.children:
                    out.write('\t"%s" -> "%s"\n' % (j.id,c.id))

        out.write("}\n")
    finally:
        # Close the output even if writing fails part way through
        out.close()
    
def main():
    """
    Command line entry point: parse the options, load the DAG from the
    single DAGFILE argument (DAGMan format if the name ends in ".dag",
    DAX otherwise), optionally remove jobs and redundant edges, and
    write the DOT diagram.
    """
    optparser = OptionParser(
        usage="%prog [options] DAGFILE",
        description="""Parses DAGFILE and generates a DOT-formatted
graphical representation of the DAG. DAGFILE can be a Condor
DAGMan file, or a Pegasus DAX file.""")

    # (flags, settings) for every supported option, in help order
    option_table = [
        (("-s", "--nosimplify"), dict(
            action="store_false", dest="simplify", default=True,
            help="Do not simplify the graph by removing redundant edges. [default: False]")),
        (("-x", "--xforms"), dict(
            action="store_true", dest="xforms", default=False,
            help="Use transformation names as labels instead of the default label")),
        (("-o", "--output"), dict(
            action="store", dest="outfile", metavar="FILE", default="/dev/stdout",
            help="Write output to FILE [default: stdout]")),
        (("-r", "--remove"), dict(
            action="append", dest="remove", metavar="XFORM", default=[],
            help="Remove jobs from the workflow by transformation name")),
        (("-W", "--width"), dict(
            action="store", dest="width", default="8.0",
            help="Width of the digraph [default: 8.0]")),
        (("-H", "--height"), dict(
            action="store", dest="height", default="10.0",
            help="Height of the digraph [default: 10.0]")),
        (("-f", "--files"), dict(
            action="store_true", dest="files", default=False,
            help="Include files. This option is only valid for DAX files. [default: false]")),
    ]
    for flags, settings in option_table:
        optparser.add_option(*flags, **settings)

    opts, positional = optparser.parse_args()

    # Exactly one DAGFILE is expected; error() does not return
    if not positional:
        optparser.error("Please specify DAGFILE")
    elif len(positional) > 1:
        optparser.error("Invalid argument")

    path = positional[0]
    if not path.endswith(".dag"):
        graph = parse_daxfile(path)
    else:
        # A DAGMan file carries no file information
        opts.files = False
        graph = parse_dagfile(path)

    remove_xforms(graph, opts.remove)

    # Simplification is skipped when files are drawn
    if opts.simplify and not opts.files:
        graph = simplify(graph)

    emit_dot(graph, use_xforms=opts.xforms, outfile=opts.outfile,
             width=opts.width, height=opts.height, files=opts.files)
    
# Run the command line interface only when executed as a script,
# not when the module is imported.
if __name__ == '__main__':
    main()
