import sys
import re
from ast import *
from lists import *




#-------------------------------------------------------------
# global parameters
#-------------------------------------------------------------

# When true, debug() prints its arguments; run() sets it temporarily.
DEBUG = False
# sys.setrecursionlimit(10000)


# canMove() accepts two nodes as a match when the diff cost is at most
# MOVE_RATIO times the sum of their sizes.
MOVE_RATIO     = 0.2
# moveCandidate(): minimum nodeSize for a non-definition node.
MOVE_SIZE      = 10
# closure(): maximum number of getmoves() rounds.
MOVE_ROUND     = 5

# diffSubNode() returns None once depth reaches FRAME_DEPTH ...
FRAME_DEPTH    = 1
# ... or when either node is smaller than FRAME_SIZE.
FRAME_SIZE     = 20

# diffFunctionDef() multiplies the cost by this when function names differ.
NAME_PENALTY   = 1
# In the visible part of the file the next two are referenced only from
# commented-out code inside diffNode().
IF_PENALTY     = 1
ASSIGN_PENALTY = 1


#-------------------------------------------------------------
# utilities
#-------------------------------------------------------------

# short alias for isinstance, used by the node type tests in this file
IS = isinstance


def debug(*args):
    """Print the tuple of arguments, but only while DEBUG is true."""
    if not DEBUG:
        return
    # A single parenthesised argument prints the same tuple repr with the
    # Python 2 print statement and with the Python 3 print function.
    print(args)


def dot():
    """Write a single '.' to stdout (no newline) as a progress mark."""
    out = sys.stdout
    out.write('.')


def isAlpha(c):
    """Tell whether c is '_', an ASCII digit or an ASCII letter."""
    if c == '_':
        return True
    if '0' <= c <= '9':
        return True
    if 'a' <= c <= 'z':
        return True
    return 'A' <= c <= 'Z'


def div(m, n):
    """Return m / n as a float; when n is zero return m unchanged."""
    return m if n == 0 else m / float(n)


# for debugging
def ps(s):
    """Parse s and return its first statement.

    An expression statement is unwrapped to the expression it holds.
    """
    first = parse(s).body[0]
    return first.value if IS(first, Expr) else first


def sz(s):
    """Debug helper: nodeSize of the parsed source s in test mode, minus one."""
    tree = parse(s)
    return nodeSize(tree, True) - 1


def dp(s):
    """Debug helper: return the ast dump() text of the parsed source s."""
    tree = parse(s)
    return dump(tree)


def run(name, closure=True, debug=False):
    """Diff the file pair <name>1.py / <name>2.py.

    name    -- common prefix; '1.py' and '2.py' are appended to it.
    closure -- passed to diff() as its third argument.
    debug   -- value the global DEBUG flag takes while diff() runs.

    The previous DEBUG value is restored afterwards, even when diff()
    raises.
    """
    global DEBUG

    fullname1 = name + '1.py'
    fullname2 = name + '2.py'

    olddebug = DEBUG
    DEBUG = debug
    try:
        diff(fullname1, fullname2, closure)
    finally:
        # do not leave the module in debug mode after a failure
        DEBUG = olddebug


def demo():
    """Run the diff on the 'demo' file pair (demo1.py and demo2.py)."""
    run(name='demo')


def go():
    """Run the diff on the 'heavy' file pair (heavy1.py and heavy2.py)."""
    run(name='heavy')


def pf():
    """Profile run('heavy') with cProfile, output sorted by cumulative time."""
    from cProfile import run as profile
    profile("run('heavy')", sort="cumulative")




#------------------------ file system support -----------------------

def pyFileName(filename):
    """Return the part of filename between the last '/' and the last '.py'.

    Raises ValueError when filename contains no '.py'.
    """
    # rfind yields -1 when there is no '/', which makes start equal to 0
    start = filename.rfind('/') + 1
    end = filename.rindex('.py')
    return filename[start:end]


## file system support
def parseFile(filename):
    """Read a Python source file, parse it and return the improved AST.

    The tree is passed through improveAST() with side 'left' before it is
    returned.  The file is closed as soon as its text has been read.
    """
    with open(filename, 'r') as f:
        lines = f.read()
    tree = parse(lines)
    improveAST(tree, lines, filename, 'left')
    return tree




#-------------------------------------------------------------
#            tests and operations on AST nodes
#-------------------------------------------------------------

# get list of fields from a node
def nodeFields(node):
    """Return the values of node's fields, in the order given by node._fields.

    The 'ctx' field is skipped, and so is any field name listed in _fields
    that the node does not actually have as an attribute.
    """
    return [getattr(node, field)
            for field in node._fields
            if field != 'ctx' and hasattr(node, field)]



# get full source text where the node is from
def nodeSource(node):
    """Return node.nodeSource, or None when the node has no such attribute."""
    return getattr(node, 'nodeSource', None)



# utility for getting exact source code part of the node
def src(node):
    """Return the slice of node.nodeSource between nodeStart and nodeEnd."""
    text = node.nodeSource
    return text[node.nodeStart:node.nodeEnd]



def nodeStart(node):
    """Return node.nodeStart, or 0 when the node has no such attribute."""
    return getattr(node, 'nodeStart', 0)



def nodeEnd(node):
    """Return node.nodeEnd (AttributeError if the attribute is missing)."""
    end = node.nodeEnd
    return end



def isAtom(x):
    """True when the exact type of x is int, str, bool or float."""
    kind = type(x)
    return kind in (int, str, bool, float)



def isDef(node):
    """True when node is a FunctionDef or a ClassDef."""
    return IS(node, (FunctionDef, ClassDef))



# whether a node is a "frame" which can contain others and be
# labeled
def isFrame(node):
    """True when node's exact type is ClassDef, FunctionDef, Import or ImportFrom."""
    kind = type(node)
    return kind in (ClassDef, FunctionDef, Import, ImportFrom)



def isEmptyContainer(node):
    """True for a List or Tuple node without elements, or a Dict node without keys."""
    if IS(node, (List, Tuple)):
        return node.elts == []
    if IS(node, Dict):
        return node.keys == []
    return False


def sameDef(node1, node2):
    """True when both nodes are functions, or both are classes, with equal names."""
    for kind in (FunctionDef, ClassDef):
        if IS(node1, kind) and IS(node2, kind):
            return node1.name == node2.name
    return False


def differentDef(node1, node2):
    """True when both nodes are definitions (see isDef) with different names."""
    if not (isDef(node1) and isDef(node2)):
        return False
    return node1.name != node2.name


# decide whether it is reasonable to consider two nodes to be
# moves of each other
def canMove(node1, node2, c):
    """Decide whether two nodes may be treated as moves of each other.

    Definitions with the same name always qualify.  Otherwise the cost c
    must not exceed MOVE_RATIO times the combined size of the two nodes.
    """
    if sameDef(node1, node2):
        return True
    budget = (nodeSize(node1) + nodeSize(node2)) * MOVE_RATIO
    return c <= budget


# whether the node is considered deleted or inserted because
# the other party matches a substructure of it.
def nodeFramed(node, changes):
    """True when some frame change in changes has node as its orig or cur."""
    return any(c.isFrame and (node == c.orig or node == c.cur)
               for c in changes)



# helper for turning nested if statements into sequences,
# otherwise we will be trapped in the nested structure and find
# too many differences
def serializeIf(node):
    """Flatten nested if/else chains into a flat list of nodes.

    An If node becomes a copy that has an empty orelse, followed by the
    flattened contents of the original orelse.  A list is flattened element
    by element.  Any other value is returned as a one-element list.

    The copy carries over lineno, col_offset, nodeStart, nodeEnd, nodeSource
    and fileName from the original If node.
    """
    if IS(node, list):
        flat = []
        for n in node:
            flat += serializeIf(n)
        return flat

    if not IS(node, If):
        return [node]

    if not hasattr(node, 'nodeEnd'):
        # single-argument form: same text under Python 2 and Python 3
        print("has no end: %s" % (node,))

    newif = If(node.test, node.body, [])
    for attr in ('lineno', 'col_offset', 'nodeStart', 'nodeEnd',
                 'nodeSource', 'fileName'):
        setattr(newif, attr, getattr(node, attr))
    return [newif] + serializeIf(node.orelse)


def nodeName(node):
    """Return the id of a Name, the name of a function or class, else None."""
    if IS(node, Name):
        return node.id
    if IS(node, (FunctionDef, ClassDef)):
        return node.name
    return None


def attr2str(node):
    """Render a Name or a chain of attribute accesses as a dotted string.

    Returns None when node is neither a Name nor an Attribute, or when the
    innermost value of an Attribute chain is not a Name (e.g. f(x).y).
    """
    if IS(node, Name):
        return node.id
    if not IS(node, Attribute):
        return None
    prefix = attr2str(node.value)
    if prefix is None:
        return None
    return prefix + "." + node.attr


### utility for counting size of terms
def nodeSize(node, test=False):
    """Count the size of a term.

    node -- an AST node, a list of nodes, an atom (see isAtom) or anything
            else.
    test -- when true, ignore cached sizes and print every size computed.

    Lists count as the sum of their elements.  Atoms, Name, Num and Str
    count as 1.  An Expr counts as its value.  Any other AST node counts as
    1 plus the sizes of its fields (see nodeFields).  Everything else
    counts as 0.  The result is cached on AST nodes in the attribute
    'nodeSize'.
    """
    if not test and hasattr(node, 'nodeSize'):
        ret = node.nodeSize

    elif IS(node, list):
        ret = sum(nodeSize(x, test) for x in node)

    elif isAtom(node) or IS(node, (Name, Num, Str)):
        ret = 1

    elif IS(node, Expr):
        ret = nodeSize(node.value, test)

    elif IS(node, AST):
        ret = 1 + sum(nodeSize(x, test) for x in nodeFields(node))

    else:
        ret = 0

    if test:
        # single-argument form so the line is valid on Python 2 and 3
        print("node: %s size= %s" % (node, ret))

    if IS(node, AST):
        node.nodeSize = ret

    return ret




#------------------------------- types ------------------------------
# global storage of running stats
class Stat:
    """Plain attribute holder; counters are attached to the instance below."""
    pass


stat = Stat()



# The difference between nodes are stored as a Change structure.
class Change:
    """One difference between an original node and a current node.

    orig    -- node on the original side, or None for an insertion.
    cur     -- node on the current side, or None for a deletion.
    cost    -- cost of the change.  It is ignored when orig or cur is None
               (the size of the other side is used instead).  The string
               'all' means the sum of both sizes.
    isFrame -- flag stored on the change; read by nodeFramed() and
               getmoves().
    """
    def __init__(self, orig, cur, cost, isFrame=False):
        self.orig = orig
        self.cur = cur
        if orig is None:
            self.cost = nodeSize(cur)
        elif cur is None:
            self.cost = nodeSize(orig)
        elif cost == 'all':
            self.cost = nodeSize(orig) + nodeSize(cur)
        else:
            self.cost = cost
        self.isFrame = isFrame

    def __repr__(self):
        fr = "F" if self.isFrame else "-"

        def hole(x):
            # a missing side is shown as "[]"
            return "[]" if x is None else x

        parts = [hole(self.orig), hole(self.cur), self.cost,
                 self.similarity(), fr]
        return "(C:" + ":".join(str(p) for p in parts) + ")"

    def similarity(self):
        """Return 1 - cost / (size of orig + size of cur); see div()."""
        total = nodeSize(self.orig) + nodeSize(self.cur)
        return 1 - div(self.cost, total)



# Three major kinds of changes:
# * modification
# * deletion
# * insertion
def modifyNode(node1, node2, cost):
    """Wrap a Change from node1 to node2 with the given cost via loner()."""
    change = Change(node1, node2, cost)
    return loner(change)

def delNode(node):
    """Wrap a deletion Change (cur is None) for node via loner()."""
    size = nodeSize(node)
    return loner(Change(node, None, size))

def insNode(node):
    """Wrap an insertion Change (orig is None) for node via loner()."""
    size = nodeSize(node)
    return loner(Change(None, node, size))



# general cache table for acceleration
class Cache:
    """Small dictionary-backed lookup table used for memoisation."""
    def __init__(self):
        self.table = {}

    def __repr__(self):
        return "Cache:" + str(self.table)

    def __len__(self):
        return len(self.table)

    def put(self, key, value):
        """Store value under key, replacing any earlier entry."""
        self.table[key] = value

    def get(self, key):
        """Return the value stored under key, or None when key is absent."""
        # dict.get() yields None for a missing key, exactly like the former
        # has_key() test; has_key() no longer exists in Python 3.
        return self.table.get(key)



# 2-D array table for memoization of dynamic programming
def createTable(x, y):
    """Return (x+1) independent rows of (y+1) None cells each."""
    return [[None] * (y + 1) for _ in range(x + 1)]

def tableLookup(t, x, y):
    """Return the cell in row x, column y of table t."""
    row = t[x]
    return row[y]

def tablePut(t, x, y, v):
    """Store v in row x, column y of table t."""
    row = t[x]
    row[y] = v





#-------------------------------------------------------------
#                  string distance function
#-------------------------------------------------------------

### cache of string distances, keyed by the (s1, s2) pair
strDistCache = Cache()


def clearStrDistCache():
    """Drop all remembered string distances by installing a fresh Cache."""
    global strDistCache
    strDistCache = Cache()


### string distance function
def strDist(s1, s2):
    """Return a normalised distance between two strings.

    The result is twice the edit distance from dist1() divided by the
    combined length of the strings.  Results are remembered in
    strDistCache.

    When either string is longer than 100 characters no edit distance is
    computed.  The answer is then 0 for equal strings and 2.0 otherwise,
    and it is not cached.
    """
    cached = strDistCache.get((s1, s2))
    if cached is not None:
        return cached

    if len(s1) > 100 or len(s2) > 100:
        return 2.0 if s1 != s2 else 0

    table = createTable(len(s1), len(s2))
    d = dist1(table, s1, s2)
    ret = div(2*d, len(s1) + len(s2))

    strDistCache.put((s1, s2), ret)
    return ret


# the main dynamic programming part
# similar to the structure of diffList
def dist1(table, s1, s2):
    """Edit distance between s1 and s2, memoised in table.

    table is indexed by (len(s1), len(s2)).  The recursion only ever
    passes suffixes of the original strings, so the two lengths identify
    the sub-problem.

    Costs: equal characters 0, characters differing only in case 1, any
    other pair 2, and dropping a character from either string 1.
    """
    def memo(v):
        tablePut(table, len(s1), len(s2), v)
        return v

    cached = tableLookup(table, len(s1), len(s2))
    if cached is not None:
        return cached

    if s1 == '':
        return memo(len(s2))
    if s2 == '':
        return memo(len(s1))

    if s1[0] == s2[0]:
        d0 = 0
    elif s1[0].lower() == s2[0].lower():
        d0 = 1
    else:
        d0 = 2

    d0 = d0 + dist1(table, s1[1:], s2[1:])
    d1 = 1 + dist1(table, s1[1:], s2)
    d2 = 1 + dist1(table, s1, s2[1:])
    return memo(min(d0, d1, d2))




#-------------------------------------------------------------
#                        diff of nodes
#-------------------------------------------------------------

# number of non-list diffNode calls so far; drives the progress dots
stat.diffCount = 0
def diffNode(node1, node2, env1, env2, depth, move):
    """Compute the changes between node1 and node2.

    Returns a pair (changes, cost).

    node1, node2 -- AST nodes, lists of nodes, or field values.
    env1, env2   -- environments consulted with lookup() when comparing
                    Name nodes.
    depth        -- passed on to diffSubNode(), which stops at FRAME_DEPTH.
    move         -- when false, trysub() never attempts a substructural
                    diff.
    """

    # try substructural diff
    # Takes a (changes, cost) pair and returns either that pair or, when
    # move is set and the pair is too costly for canMove(), the result of
    # diffSubNode() if that is not None.
    def trysub((changes, cost)):
        if not move:
            return (changes, cost)
        elif canMove(node1, node2, cost):
            return (changes, cost)
        else:
            mc1 = diffSubNode(node1, node2, env1, env2, depth, move)
            if mc1 <> None:
                return mc1
            else:
                return (changes, cost)

    # a list against a single node: wrap the single node in a list
    if IS(node1, list) and not IS(node2, list):
        return diffNode(node1, [node2], env1, env2, depth, move)

    if not IS(node1, list) and IS(node2, list):
        return diffNode([node1], node2, env1, env2, depth, move)

    # two lists: flatten if-chains, then run the list diff with a fresh
    # memo table.  Note that diffList is given depth 0 here, not depth.
    if (IS(node1, list) and IS(node2, list)):
        node1 = serializeIf(node1)
        node2 = serializeIf(node2)
        table = createTable(len(node1), len(node2))
        return diffList(table, node1, node2, env1, env2, 0, move)

    # statistics
    stat.diffCount += 1
    if stat.diffCount % 1000 == 0:
        dot()

    if node1 == node2:
        return (modifyNode(node1, node2, 0), 0)

    # numbers: cost 0 when equal, 1 otherwise
    if IS(node1, Num) and IS(node2, Num):
        if node1.n == node2.n:
            return (modifyNode(node1, node2, 0), 0)
        else:
            return (modifyNode(node1, node2, 1), 1)

    # string literals: cost is the string distance of their contents
    if IS(node1, Str) and IS(node2, Str):
        cost = strDist(node1.s, node2.s)
        return (modifyNode(node1, node2, cost), cost)

    # names: cost 0 only when both are bound to the same non-None value in
    # their environments; otherwise the string distance of the identifiers
    if (IS(node1, Name) and IS(node2, Name)):
        v1 = lookup(node1.id, env1)
        v2 = lookup(node2.id, env2)
        if v1 <> v2 or (v1 == None and v2 == None):
            cost = strDist(node1.id, node2.id)
            return (modifyNode(node1, node2, cost), cost)
        else:                           # same variable
            return (modifyNode(node1, node2, 0), 0)

    # attribute chains (and attribute vs. name) are compared as dotted
    # strings when attr2str() can render both sides
    if (IS(node1, Attribute) and IS(node2, Name) or
        IS(node1, Name) and IS(node2, Attribute) or
        IS(node1, Attribute) and IS(node2, Attribute)):
        s1 = attr2str(node1)
        s2 = attr2str(node2)
        if s1 <> None and s2 <> None:
            cost = strDist(s1, s2)
            return (modifyNode(node1, node2, cost), cost)
        # else fall through for things like f(x).y vs x.y

    # if (IS(node1, ClassDef) and IS(node2, ClassDef)):
    #     (m1, c1) = diffNode(node1.bases, node2.bases, env1, env2, depth, move)
    #     (m2, c2) = diffNode(node1.body, node2.body, env1, env2, depth, move)
    #     (m3, c3) = diffNode(node1.decorator_list, node2.decorator_list,
    #                         env1, env2, depth, move)
    #     changes = append(m1, m2, m3)
    #     cost = c1 + c2 + c3 + strDist(node1.name, node2.name)
    #     return trysub((changes, cost))

    # if (IS(node1, FunctionDef) and IS(node2, FunctionDef)):
    #     return trysub(diffFunctionDef(node1, node2,
    #                                   env1, env2, depth, move))

    # if (IS(node1, Assign) and IS(node2, Assign)):
    #     (m1, c1) = diffNode(node1.targets, node2.targets,
    #                         env1, env2, depth, move)
    #     (m2, c2) = diffNode(node1.value, node2.value,
    #                         env1, env2, depth, move)
    #     return (append(m1, m2), c1 * ASSIGN_PENALTY + c2)

    # # flatten nested if nodes
    # if IS(node1, If) and IS(node2, If):
    #     seq1 = serializeIf(node1)
    #     seq2 = serializeIf(node2)
    #     if len(seq1) > 1 and len(seq2) > 1:
    #         return diffNode(seq1, seq2, env1, env2, depth, move)
    #     else:
    #         (m0, c0) = diffNode(node1.test, node2.test, env1, env2, depth, move)
    #         (m1, c1) = diffNode(node1.body, node2.body, env1, env2, depth, move)
    #         (m2, c2) = diffNode(node1.orelse, node2.orelse, env1, env2, depth, move)
    #         changes = append(m0, m1, m2)
    #         cost = c0 * IF_PENALTY + c1 + c2
    #         return trysub((changes, cost))

    if IS(node1, Module) and IS(node2, Module):
        return diffNode(node1.body, node2.body, env1, env2, depth, move)

    # other AST nodes
    # Same node type: diff the fields pairwise and add up the costs.
    if (IS(node1, AST) and IS(node2, AST) and
        type(node1) == type(node2)):

        fs1 = nodeFields(node1)
        fs2 = nodeFields(node2)
        changes, cost = nil, 0

        # NOTE(review): fs2 is indexed with positions taken from fs1, so
        # this assumes both nodes yield the same number of fields; since
        # nodeFields() skips missing attributes, confirm that always holds.
        for i in xrange(len(fs1)):
            (m, c) = diffNode(fs1[i], fs2[i], env1, env2, depth, move)
            changes = append(m, changes)
            cost += c

        return trysub((changes, cost))

    # NOTE(review): isEmptyContainer() is only true for List/Tuple/Dict AST
    # nodes, and two AST nodes of the same type already returned above, so
    # this branch looks unreachable -- confirm before relying on it.
    if (type(node1) == type(node2) and
             isEmptyContainer(node1) and isEmptyContainer(node2)):
        return (modifyNode(node1, node2, 0), 0)

    # all unmatched types and unequal values
    return trysub((append(delNode(node1), insNode(node2)),
                   nodeSize(node1) + nodeSize(node2)))





###################### diff of a FunctionDef #####################

# separate out because it is too long

def diffFunctionDef(node1, node2, env1, env2, depth, move):
    """Diff two FunctionDef nodes; returns (changes, cost).

    Compares positional parameters pairwise, then the *vararg and **kwarg
    names, then defaults and body.  Parameters renamed between the two
    versions are bound to the same node in both environments before the
    body is compared.

    Reads node.varargName and node.kwargName, which addMissingNames()
    attaches to FunctionDef nodes.  In the visible part of the file the
    only call to this function is inside commented-out code in diffNode().
    """

    # positionals
    len1 = len(node1.args.args)
    len2 = len(node2.args.args)

    # rest holds the surplus positionals of the longer parameter list
    if len1 < len2:
        minlen = len1
        rest = node2.args.args[minlen:]
    else:
        minlen = len2
        rest = node1.args.args[minlen:]

    ma = nil
    for i in xrange(minlen):
        a1 = node1.args.args[i]
        a2 = node2.args.args[i]
        # renamed parameter: bind both names to a2 so that lookup() gives
        # the same value on either side
        if IS(a1, Name) and IS(a2, Name) and a1.id <> a2.id:
            env1 = ext(a1.id, a2, env1)
            env2 = ext(a2.id, a2, env2)
        # NOTE(review): c1 is never added to the running cost ca -- confirm
        # whether renaming positional parameters is meant to be free.
        (m1, c1) = diffNode(a1, a2, env1, env2, depth, move)
        ma = append(m1, ma)

    # handle rest of the positionals
    ca = 0
    if rest <> []:
        if len1 < len2:
            for arg in rest:
                ma = append(insNode(arg), ma)
                ca += nodeSize(arg)
        else:
            for arg in rest:
                ma = append(delNode(arg), ma)
                ca += nodeSize(arg)

    # vararg
    va1 = node1.varargName
    va2 = node2.varargName
    if va1 <> None and va2 <> None:
        if va1.id <> va2.id:
            env1 = ext(va1.id, va2, env1)
            env2 = ext(va2.id, va2, env2)
        cost = strDist(va1.id, va2.id)
        ma = append(modifyNode(va1, va2, cost), ma)
        ca += cost
    elif va1 <> None or va2 <> None:
        # present on one side only: costs the size of the existing name
        cost = nodeSize(va1) if va1 <> None else nodeSize(va2)
        ma = append(modifyNode(va1, va2, cost), ma)
        ca += cost

    # kwarg
    ka1 = node1.kwargName
    ka2 = node2.kwargName
    if ka1 <> None and ka2 <> None:
        if ka1.id <> ka2.id:
            env1 = ext(ka1.id, ka2, env1)
            env2 = ext(ka2.id, ka2, env2)
        cost = strDist(ka1.id, ka2.id)
        ma = append(modifyNode(ka1, ka2, cost), ma)
        ca += cost
    elif ka1 <> None or ka2 <> None:
        # present on one side only: costs the size of the existing name
        cost = nodeSize(ka1) if ka1 <> None else nodeSize(ka2)
        ma = append(modifyNode(ka1, ka2, cost), ma)
        ca += cost

    # defaults and body
    (md, cd) = diffNode(node1.args.defaults, node2.args.defaults,
                        env1, env2, depth, move)
    (mb, cb) = diffNode(node1.body, node2.body, env1, env2, depth, move)

    # sum up cost. penalize functions with different names.
    cost = ca + cd + cb + strDist(node1.name, node2.name)
    if node1.name <> node2.name:
        cost = cost * NAME_PENALTY

    return (append(ma, md, mb), cost)





########################## diff of a list ##########################

# diffList is the main part of dynamic programming

def diffList(table, ls1, ls2, env1, env2, depth, move):
    """Diff two lists of nodes; returns (changes, cost).

    table memoises results by (len(ls1), len(ls2)).  Recursive calls only
    ever pass suffixes (ls[1:]) of the lists handed to the first call, so
    the two lengths identify the sub-problem.
    """

    # store v for the current (len(ls1), len(ls2)) and return it
    def memo(v):
        tablePut(table, len(ls1), len(ls2), v)
        return v

    # Both lists non-empty: try matching the two heads, and unless that
    # match is good enough for canMove(), also try deleting the head of
    # ls1 and inserting the head of ls2, then keep the cheapest option.
    def guess(table, ls1, ls2, env1, env2):
        (m0, c0) = diffNode(ls1[0], ls2[0], env1, env2, depth, move)
        (m1, c1) = diffList(table, ls1[1:], ls2[1:], env1, env2, depth, move)
        cost1 = c1 + c0

        # record a change for the pair of frame nodes themselves, unless
        # m0 already contains a frame change involving one of them
        if ((isFrame(ls1[0]) and
             isFrame(ls2[0]) and
             not nodeFramed(ls1[0], m0) and
             not nodeFramed(ls2[0], m0))):
            frameChange = modifyNode(ls1[0], ls2[0], c0)
        else:
            frameChange = nil

        # short cut 1 (func and classes with same names)
        if canMove(ls1[0], ls2[0], c0):
            return (append(frameChange, m0, m1), cost1)

        else:  # do more work
            (m2, c2) = diffList(table, ls1[1:], ls2, env1, env2, depth, move)
            (m3, c3) = diffList(table, ls1, ls2[1:], env1, env2, depth, move)
            cost2 = c2 + nodeSize(ls1[0])
            cost3 = c3 + nodeSize(ls2[0])

            # definitions with different names are never matched head to
            # head here; they fall through to deletion or insertion
            if (not differentDef(ls1[0], ls2[0]) and
                cost1 <= cost2 and cost1 <= cost3):
                return (append(frameChange, m0, m1), cost1)
            elif (cost2 <= cost3):
                return (append(delNode(ls1[0]), m2), cost2)
            else:
                return (append(insNode(ls2[0]), m3), cost3)

    # cache look up
    cached = tableLookup(table, len(ls1), len(ls2))
    if (cached <> None):
        return cached

    if (ls1 == [] and ls2 == []):
        return memo((nil, 0))

    elif (ls1 <> [] and ls2 <> []):
        return memo(guess(table, ls1, ls2, env1, env2))

    # only ls2 has elements left: everything in it is an insertion
    elif ls1 == []:
        d = nil
        for n in ls2:
            d = append(insNode(n), d)
        return memo((d, nodeSize(ls2)))

    # only ls1 has elements left: everything in it is a deletion
    else: # ls2 == []:
        d = nil
        for n in ls1:
            d = append(delNode(n), d)
        return memo((d, nodeSize(ls1)))




###################### diff into a subnode #######################

# Subnode diff is only used in the moving phase. There is no
# need to compare the substructure of two nodes in the first
# run, because they will be reconsidered if we just consider
# them to be complete deletion and insertions.

def diffSubNode(node1, node2, env1, env2, depth, move):
    """Try to match the smaller node against a field of the larger one.

    Returns (changes, cost) for the first field of the larger node that
    canMove() accepts as a match for the smaller node, or None when no
    field qualifies.  Also returns None when depth has reached
    FRAME_DEPTH, when either node is smaller than FRAME_SIZE, when the
    nodes are not both AST nodes, or when they have equal size.
    """

    if (depth >= FRAME_DEPTH or
        nodeSize(node1) < FRAME_SIZE or
        nodeSize(node2) < FRAME_SIZE):
        return None

    if IS(node1, AST) and IS(node2, AST):

        if nodeSize(node1) == nodeSize(node2):
            return None

        # look through expression statements
        if IS(node1, Expr):
            node1 = node1.value

        if IS(node2, Expr):
            node2 = node2.value

        # node1 is smaller: look for it inside the fields of node2
        if (nodeSize(node1) < nodeSize(node2)):
            for f in nodeFields(node2):
                (m0, c0) = diffNode(node1, f, env1, env2, depth+1, move)
                if canMove(node1, f, c0):
                    if not IS(f, list):
                        m1 = modifyNode(node1, f, c0)
                    else:
                        m1 = nil
                    framecost = nodeSize(node2) - nodeSize(node1)
                    # NOTE(review): Change() ignores its cost argument when
                    # orig is None, so this change's own cost attribute is
                    # nodeSize(node2) rather than framecost; the returned
                    # total below does use framecost.
                    m2 = loner(Change(None, node2, framecost, True))
                    return (append(m2, m1, m0), c0 + framecost)

        # node2 is smaller: look for it inside the fields of node1
        if (nodeSize(node1) > nodeSize(node2)):
            for f in nodeFields(node1):
                (m0, c0) = diffNode(f, node2, env1, env2, depth+1, move)
                if canMove(f, node2, c0):
                    framecost = nodeSize(node1) - nodeSize(node2)
                    if not IS(f, list):
                        m1 = modifyNode(f, node2, c0)
                    else:
                        m1 = nil
                    # same remark as above, with cur being None
                    m2 = loner(Change(node1, None, framecost, True))
                    return (append(m2, m1, m0), c0 + framecost)

    return None





##########################################################################
##                          move detection
##########################################################################
def moveCandidate(node):
    """True for definitions and for nodes of size MOVE_SIZE or more."""
    if isDef(node):
        return True
    return nodeSize(node) >= MOVE_SIZE


# running totals updated by getmoves()
stat.moveCount = 0
stat.moveSavings = 0
def getmoves(ds, round=0):
    """Pair deletions with insertions in ds that look like moved code.

    ds    -- the current changes.
    round -- round number, used only in the progress message.

    Returns (matched, newChanges, total): the deletion and insertion
    changes that were paired, the changes produced by diffing each pair,
    and the summed cost of those diffs.
    """

    # non-frame deletions and insertions that qualify as move candidates
    dels = pylist(filterlist(lambda p: (p.cur == None and
                                        moveCandidate(p.orig) and
                                        not p.isFrame),
                             ds))
    adds = pylist(filterlist(lambda p: (p.orig == None and
                                        moveCandidate(p.cur) and
                                        not p.isFrame),
                             ds))

    # print "dels=", dels
    # print "adds=", adds

    matched = []
    newChanges, total = nil, 0

    print("\n[getmoves #%d] %d * %d = %d pairs of nodes to consider ..."
          % (round, len(dels), len(adds), len(dels) * len(adds)))

    for d0 in dels:
        for a0 in adds:
            (node1, node2) = (d0.orig, a0.cur)
            (changes, cost) = diffNode(node1, node2, nil, nil, 0, True)
            nterms = nodeSize(node1) + nodeSize(node2)

            if (canMove(node1, node2, cost) or
                nodeFramed(node1, changes) or
                nodeFramed(node2, changes)):

                matched.append(d0)
                matched.append(a0)
                # removing from adds during the loop is safe only because
                # of the break at the end of this branch
                adds.remove(a0)
                newChanges = append(changes, newChanges)
                total += cost

                # add a change for the pair of definitions itself when no
                # frame change in changes involves either node
                if (not nodeFramed(node1, changes) and
                    not nodeFramed(node2, changes) and
                    isDef(node1) and isDef(node2)):
                    newChanges = append(modifyNode(node1, node2, cost),
                                        newChanges)

                stat.moveSavings += nterms
                stat.moveCount +=1
                if stat.moveCount % 1000 == 0:
                    dot()

                # each deletion is paired with at most one insertion
                break

    print("\n\t%d matched pairs found with %d new changes."
          % (len(pylist(matched)), len(pylist(newChanges))))

    # print "matches=", matched
    # print "newChanges=", newChanges

    return (matched, newChanges, total)



# Get moves repeatedly because new moves may introduce new
# deletions and insertions.

def closure(res):
    """Apply move detection repeatedly to a (changes, cost) result.

    getmoves() is run for at most MOVE_ROUND rounds, stopping early after a
    round that matched nothing.  Each round removes the matched deletions
    and insertions from the changes, adds the changes found for the
    matched pairs, and adjusts the cost: the cost of the new diffs is
    added and the sizes of the matched nodes are subtracted.

    Returns the updated (changes, cost) pair.
    """
    (changes, cost) = res
    matched = None          # None (not []) so that the first round runs
    moveround = 1

    while moveround <= MOVE_ROUND and matched != []:
        (matched, newChanges, c) = getmoves(changes, moveround)
        moveround += 1
        # the lambda parameter is named ch so it no longer shadows the
        # cost c returned by getmoves()
        changes = filterlist(lambda ch: ch not in matched, changes)
        changes = append(newChanges, changes)
        savings = sum(nodeSize(p.orig) + nodeSize(p.cur) for p in matched)
        cost = cost + c - savings
    return (changes, cost)





#-------------------------------------------------------------
#                   improvements to the AST
#-------------------------------------------------------------

allNodes1 = set()
allNodes2 = set()


def improveNode(node, s, idxmap, filename, side):
    """Annotate node and everything below it with source information.

    node     -- an AST node or a list of nodes; other values are ignored.
    s        -- full source text the tree was parsed from.
    idxmap   -- line-start table from buildIndexMap(s).
    filename -- stored on every node as fileName.
    side     -- 'left' registers nodes in allNodes1, any other value in
                allNodes2.
    """
    if IS(node, list):
        for n in node:
            improveNode(n, s, idxmap, filename, side)
        return

    if not IS(node, AST):
        return

    registry = allNodes1 if side == 'left' else allNodes2
    registry.add(node)

    # positions first, then the extra name fields, which are looked up
    # relative to those positions
    findNodeStart(node, s, idxmap)
    findNodeEnd(node, s, idxmap)
    addMissingNames(node, s, idxmap)

    node.nodeSource = s
    node.fileName = filename

    # nodeFields() is evaluated after addMissingNames() has run on node
    for f in nodeFields(node):
        improveNode(f, s, idxmap, filename, side)



def improveAST(node, s, filename, side):
    """Build the index map for s and annotate the whole tree under node."""
    improveNode(node, s, buildIndexMap(s), filename, side)




#-------------------------------------------------------------
#            finding start and end index of nodes
#-------------------------------------------------------------

def findNodeStart(node, s, idxmap):
    """Return the offset in s where node starts, caching it as nodeStart.

    node   -- an AST node or a list of nodes.
    s      -- source text.
    idxmap -- line-start table from buildIndexMap(s).

    Returns None for values without position information, including an
    empty list and a Module with an empty body.  Raises TypeError when no
    start is found for a node that has a lineno.
    """
    if hasattr(node, 'nodeStart'):
        return node.nodeStart

    elif IS(node, list):
        # an empty list has no first element to ask
        ret = findNodeStart(node[0], s, idxmap) if node else None

    elif IS(node, Module):
        # an empty source file parses to a Module with an empty body
        ret = findNodeStart(node.body[0], s, idxmap) if node.body else None

    elif IS(node, BinOp):
        # a binary operation starts where its left operand starts
        leftstart = findNodeStart(node.left, s, idxmap)
        if leftstart is not None:
            ret = leftstart
        else:
            ret = mapIdx(idxmap, node.lineno, node.col_offset)

    elif hasattr(node, 'lineno'):
        if node.col_offset >= 0:
            ret = mapIdx(idxmap, node.lineno, node.col_offset)
        else:                           # special case for """ strings
            # negative column: scan backwards for the opening quotes
            i = mapIdx(idxmap, node.lineno, node.col_offset)
            while i > 0 and i+2 < len(s) and s[i:i+3] != '"""':
                i -= 1
            ret = i
    else:
        ret = None

    if ret is None and hasattr(node, 'lineno'):
        raise TypeError("got None for node that has lineno", node)

    if IS(node, AST) and ret is not None:
        node.nodeStart = ret

    return ret




def _strLiteralEnd(s, i):
    """Return the offset just past the string literal starting at s[i].

    Prefix letters (r, u, b in either case) are skipped, and both
    triple-quoted forms are recognised.  The closing quote is located with
    endSeq(), which does not take escaped quotes into account.
    """
    while i < len(s) and s[i] in 'rRuUbB':
        i += 1

    triple = s[i:i+3]
    if triple == '"""' or triple == "'''":
        return endSeq(s, triple, i + 3)
    if i < len(s) and (s[i] == '"' or s[i] == "'"):
        return endSeq(s, s[i], i + 1)

    # not positioned on a quote: report it and fall back to the same
    # three-character guess that findNodeEnd uses for unrecognised nodes
    print("illegal: %s %s" % (i, s[i:i+1]))
    return min(i + 3, len(s))


def findNodeEnd(node, s, idxmap):
    """Return the offset in s just past node, caching it as nodeEnd.

    node   -- an AST node or a list of nodes.
    s      -- source text.
    idxmap -- line-start table from buildIndexMap(s).

    Most ends are derived from the end of the last child.  Leaves use the
    length of their text, and bracketed forms use matchParen().  Node types
    not handled here get start + 3.  Returns None for values without
    position information (including an empty list and a Module with an
    empty body).  Raises TypeError when no end is found for a node that
    has a lineno.
    """
    if hasattr(node, 'nodeEnd'):
        return node.nodeEnd

    elif IS(node, list):
        ret = findNodeEnd(node[-1], s, idxmap) if node else None

    elif IS(node, Module):
        ret = findNodeEnd(node.body[-1], s, idxmap) if node.body else None

    elif IS(node, Expr):
        ret = findNodeEnd(node.value, s, idxmap)

    elif IS(node, Str):
        ret = _strLiteralEnd(s, findNodeStart(node, s, idxmap))

    elif IS(node, Name):
        ret = findNodeStart(node, s, idxmap) + len(node.id)

    elif IS(node, Attribute):
        ret = endSeq(s, node.attr, findNodeEnd(node.value, s, idxmap))

    elif IS(node, (FunctionDef, Lambda, ClassDef)):
        # addMissingNames(node, s, idxmap)
        # ret = findNodeEnd(node.nameName, s, idxmap)
        ret = findNodeEnd(node.body, s, idxmap)

    elif IS(node, Call):
        ret = matchParen(s, '(', ')', findNodeEnd(node.func, s, idxmap))

    elif IS(node, Yield):
        if node.value is not None:
            ret = findNodeEnd(node.value, s, idxmap)
        else:
            # bare "yield" has no value to measure
            ret = findNodeStart(node, s, idxmap) + len('yield')

    elif IS(node, Return):
        if node.value is not None:
            ret = findNodeEnd(node.value, s, idxmap)
        else:
            ret = findNodeStart(node, s, idxmap) + len('return')

    elif IS(node, Print):
        ret = startSeq(s, '\n', findNodeStart(node, s, idxmap))

    elif IS(node, (For, While, If, IfExp)):
        if node.orelse != []:
            ret = findNodeEnd(node.orelse, s, idxmap)
        else:
            ret = findNodeEnd(node.body, s, idxmap)

    elif IS(node, (Assign, AugAssign)):
        ret = findNodeEnd(node.value, s, idxmap)

    elif IS(node, BinOp):
        ret = findNodeEnd(node.right, s, idxmap)

    elif IS(node, BoolOp):
        ret = findNodeEnd(node.values[-1], s, idxmap)

    elif IS(node, Compare):
        ret = findNodeEnd(node.comparators[-1], s, idxmap)

    elif IS(node, UnaryOp):
        ret = findNodeEnd(node.operand, s, idxmap)

    elif IS(node, Num):
        ret = findNodeStart(node, s, idxmap) + len(str(node.n))

    elif IS(node, (List, Subscript)):
        ret = matchParen(s, '[', ']', findNodeStart(node, s, idxmap))

    elif IS(node, Tuple):
        if node.elts:
            ret = findNodeEnd(node.elts[-1], s, idxmap)
        else:
            # empty tuple "()": there is no last element to measure.
            # Presumably the node starts at or before its '(' -- verify.
            ret = matchParen(s, '(', ')', findNodeStart(node, s, idxmap))

    elif IS(node, Dict):
        ret = matchParen(s, '{', '}', findNodeStart(node, s, idxmap))

    elif IS(node, TryExcept):
        if node.orelse != []:
            ret = findNodeEnd(node.orelse, s, idxmap)
        elif node.handlers != []:
            ret = findNodeEnd(node.handlers, s, idxmap)
        else:
            ret = findNodeEnd(node.body, s, idxmap)

    elif IS(node, ExceptHandler):
        ret = findNodeEnd(node.body, s, idxmap)

    elif IS(node, Pass):
        ret = findNodeStart(node, s, idxmap) + len('pass')

    elif IS(node, Break):
        ret = findNodeStart(node, s, idxmap) + len('break')

    elif IS(node, Continue):
        ret = findNodeStart(node, s, idxmap) + len('continue')

    elif IS(node, Global):
        ret = startSeq(s, '\n', findNodeStart(node, s, idxmap))

    elif IS(node, Import):
        # only the keyword is covered
        ret = findNodeStart(node, s, idxmap) + len('import')

    elif IS(node, ImportFrom):
        ret = findNodeStart(node, s, idxmap) + len('from')

    else:
        # print "[findNodeEnd] unrecognized node:", node, "type:", type(node)
        start = findNodeStart(node, s, idxmap)
        if start is not None:
            ret = start + 3
        else:
            ret = None

    if ret is None and hasattr(node, 'lineno'):
        raise TypeError("got None for node that has lineno", node)

    if IS(node, AST) and ret is not None:
        node.nodeEnd = ret

    return ret




#-------------------------------------------------------------
#                    adding missing Names
#-------------------------------------------------------------

def addMissingNames(node, s, idxmap):
    """Attach Name nodes for identifiers the AST only stores as strings.

    Depending on the node type, attributes are added (nameName, varargName,
    kwargName, attrName, opsName, opName, nameNames) and listed in the
    node's _fields so that nodeFields() returns them.  A node is processed
    only once: it is marked with the attribute 'extraAttribute'.

    node   -- an AST node, or a list whose elements are processed in turn.
    s      -- source text.
    idxmap -- line-start table from buildIndexMap(s).
    """
    if hasattr(node, 'extraAttribute'):
        return

    if IS(node, list):
        for n in node:
            addMissingNames(n, s, idxmap)
        # a plain list cannot carry the marker attribute
        return

    if IS(node, ClassDef):
        start = findNodeStart(node, s, idxmap) + len('class')
        node.nameName = str2Name(s, start, idxmap)
        node._fields += ('nameName',)

    elif IS(node, FunctionDef):
        start = findNodeStart(node, s, idxmap) + len('def')
        node.nameName = str2Name(s, start, idxmap)
        node._fields += ('nameName',)

        args = node.args
        tail = None
        if args.vararg is not None or args.kwarg is not None:
            # The *name / **name parameters follow the last positional
            # parameter.  When that parameter has a default value, search
            # after the default so its text is not mistaken for a name.
            if args.defaults:
                tail = findNodeEnd(args.defaults[-1], s, idxmap)
            elif len(args.args) > 0:
                tail = findNodeEnd(args.args[-1], s, idxmap)
            else:
                tail = findNodeEnd(node.nameName, s, idxmap)

        if args.vararg is not None:
            node.varargName = str2Name(s, tail, idxmap)
            if node.varargName is not None:
                # the **name, if any, comes after the *name
                tail = node.varargName.nodeEnd
        else:
            node.varargName = None
        node._fields += ('varargName',)

        if args.kwarg is not None:
            node.kwargName = str2Name(s, tail, idxmap)
        else:
            node.kwargName = None
        node._fields += ('kwargName',)

    elif IS(node, Attribute):
        start = findNodeEnd(node.value, s, idxmap)
        name = str2Name(s, start, idxmap)
        node.attrName = name
        node._fields = ('value', 'attrName')  # remove attr for node size accuracy

    elif IS(node, Compare):
        node.opsName = convertOps(node.ops, s,
                                  findNodeStart(node, s, idxmap), idxmap)
        node._fields += ('opsName',)

    elif IS(node, (BoolOp, BinOp, UnaryOp, AugAssign)):
        # search for the operator after the left operand when there is one
        if hasattr(node, 'left'):
            start = findNodeEnd(node.left, s, idxmap)
        else:
            start = findNodeStart(node, s, idxmap)
        ops = convertOps([node.op], s, start, idxmap)
        node.opName = ops[0]
        node._fields += ('opName',)

    elif IS(node, Import):
        nameNames = []
        pos = findNodeStart(node, s, idxmap) + len('import')
        name = str2Name(s, pos, idxmap)
        # collect identifiers until the end of the line
        while name is not None and pos < len(s) and s[pos] != '\n':
            nameNames.append(name)
            pos = name.nodeEnd
            name = str2Name(s, pos, idxmap)
        node.nameNames = nameNames
        node._fields += ('nameNames',)

    node.extraAttribute = True



#-------------------------------------------------------------
#              utilities used by improve AST functions
#-------------------------------------------------------------

# find a sequence in a string s, returning the start point
def startSeq(s, pat, start):
    """Return the index at which pat first occurs in s at or after start.

    When pat does not occur in that range the length of s is returned,
    so the result can always be used as a slice bound.
    """
    pos = s.find(pat, start)
    if pos < 0:
        return len(s)
    return pos



# find a sequence in a string s, returning the end point
def endSeq(s, pat, start):
    """Return the index just past the first occurrence of pat in s.

    The search begins at start.  When pat is not found the length of s
    is returned.
    """
    pos = s.find(pat, start)
    if pos < 0:
        return len(s)
    return pos + len(pat)



# find matching close paren from start
def matchParen(s, open, close, start):
    """Find the end of the first balanced open/close group at or after start.

    s     -- the text to scan
    open  -- the opening delimiter (a single character)
    close -- the closing delimiter (a single character)
    start -- index at which to begin looking for the opening delimiter

    Returns the index just past the delimiter that closes the first
    opening delimiter found.  Returns len(s) when no opening delimiter
    exists at or after start, or when the group is never closed.
    """
    size = len(s)

    # The bound has to be tested before indexing; otherwise a text without
    # any opening delimiter runs off the end of s and raises IndexError.
    while start < size and s[start] != open:
        start += 1
    if start >= size:
        return size

    depth = 1                   # number of groups currently open
    i = start + 1
    while depth > 0 and i < size:
        ch = s[i]
        if ch == open:
            depth += 1
        elif ch == close:
            depth -= 1
        i += 1
    return i



# build table for lineno <-> index conversion
def buildIndexMap(s):
    """Return the list of offsets in s at which each line begins.

    Entry 0 is always 0; one further entry is added for the position
    following every newline character.
    """
    idxmap = [0]
    for pos, ch in enumerate(s):
        if ch == '\n':
            idxmap.append(pos + 1)
    return idxmap



# convert (line, col) to offset index
def mapIdx(idxmap, line, col):
    """Translate a 1-based line number and a column into an offset.

    idxmap is the table of line start offsets (as built by buildIndexMap).
    """
    lineStart = idxmap[line - 1]
    return lineStart + col



# convert offset index into (line, col)
def mapLineCol(idxmap, idx):
    """Translate an offset into a (line, col) pair.

    line is the 1-based count of entries in idxmap that are <= idx,
    counted from the front until the first larger entry; col is the
    distance from that line's start offset.
    """
    line = 0
    while line < len(idxmap) and idxmap[line] <= idx:
        line += 1
    return (line, idx - idxmap[line - 1])



# convert string to Name
def str2Name(s, start, idxmap):
    """Scan s from start for the next identifier and wrap it in a Name node.

    Characters rejected by isAlpha are skipped first; the following run of
    characters accepted by isAlpha becomes the identifier.  The returned
    Name carries nodeStart/nodeEnd offsets into s and the lineno/col_offset
    computed from idxmap.  Returns None when no identifier is found before
    the end of s.
    """
    size = len(s)

    # skip whatever separates start from the identifier
    startIdx = start
    while startIdx < size and not isAlpha(s[startIdx]):
        startIdx += 1

    # consume the identifier itself
    endIdx = startIdx
    while endIdx < size and isAlpha(s[endIdx]):
        endIdx += 1

    id1 = s[startIdx:endIdx]
    if not id1:
        return None

    name = Name(id1, None)
    name.nodeStart = startIdx
    name.nodeEnd = endIdx
    name.lineno, name.col_offset = mapLineCol(idxmap, startIdx)
    return name



def convertOps(ops, s, start, idxmap):
    """Locate the source text of each operator in ops and return Name nodes.

    ops    -- operator nodes, in source order
    s      -- the source text
    start  -- offset in s at which to begin scanning
    idxmap -- table of line start offsets used to compute lineno/col_offset

    The operators are looked up in opsMap and searched for one after the
    other, each search resuming right after the previous match.  The
    result holds one Name per operator found; it is shorter than ops when
    the end of s is reached first.
    """
    syms = [opsMap[type(op)] for op in ops]
    found = []
    size = len(s)
    pos = start
    while pos < size and len(found) < len(syms):
        sym = syms[len(found)]          # next operator still to be located
        if s.startswith(sym, pos):
            opName = Name(sym, None)
            opName.nodeStart = pos
            opName.nodeEnd = pos + len(sym)
            opName.lineno, opName.col_offset = mapLineCol(idxmap, pos)
            found.append(opName)
            pos = opName.nodeEnd
        else:
            pos += 1
    return found


# lookup table for operators for convertOps
# Maps each operator node class to the text it is written as in source code.
# convertOps() indexes this table directly, so an operator class missing
# here makes it raise KeyError; the table therefore covers the identity,
# power, floor-division, bitwise and shift operators as well.
opsMap = {
    # compare:
    Eq     : '==',
    NotEq  : '<>',
    Lt     : '<',
    LtE    : '<=',
    Gt     : '>',
    GtE    : '>=',
    In     : 'in',
    NotIn  : 'not in',
    Is     : 'is',
    IsNot  : 'is not',

    # BoolOp
    Or  : 'or',
    And : 'and',
    Not : 'not',

    # BinOp
    Add      : '+',
    Sub      : '-',
    Mult     : '*',
    Div      : '/',
    FloorDiv : '//',
    Mod      : '%',
    Pow      : '**',

    # bitwise and shift operators
    BitOr  : '|',
    BitAnd : '&',
    BitXor : '^',
    LShift : '<<',
    RShift : '>>',

    # UnaryOp
    USub   : '-',
    UAdd   : '+',
    Invert : '~',
}






#-------------------------------------------------------------
#                        HTML generation
#-------------------------------------------------------------


#-------------------- types and utilities ----------------------

class Tag:
    """A piece of markup to be inserted into the text at a given offset.

    tag   -- the markup text itself
    idx   -- offset in the source text where the markup is inserted
    start -- for a closing tag, the offset of its opening tag; -1 otherwise
    """

    def __init__(self, tag, idx, start=-1):
        self.tag, self.idx, self.start = tag, idx, start

    def __repr__(self):
        return ":".join(("tag", str(self.tag), str(self.idx)))



# escape for HTML
def escape(s):
    """Return s with the characters special to HTML replaced by entities.

    Handles '&', double quote, single quote, '<' and '>', which makes the
    result safe both as element content and inside a quoted attribute
    value.
    """
    # '&' has to go first, otherwise the ampersands of the entities
    # produced by the following replacements would be escaped again.
    s = s.replace('&', '&amp;')
    s = s.replace('"', '&quot;')
    s = s.replace("'", '&#39;')
    s = s.replace("<", '&lt;')
    s = s.replace(">", '&gt;')
    return s



# State of the unique id generator: the last number handed out and the
# ids already assigned, keyed by node.
uidCount = -1
uidHash = {}


def clearUID():
    """Forget all assigned ids and restart numbering from 0."""
    global uidCount, uidHash
    uidCount = -1
    uidHash = {}


def uid(node):
    """Return the id string assigned to node, assigning a new one if needed.

    Ids are consecutive decimal numbers starting at '0'; asking again for
    the same node returns the same id until clearUID() is called.
    """
    global uidCount

    # `in` replaces the deprecated dict.has_key()
    if node in uidHash:
        return uidHash[node]

    uidCount += 1
    uidHash[node] = str(uidCount)
    return uidHash[node]



def lineId(lineno):
    """Return the HTML element id used for the given line number."""
    prefix = 'L'
    return prefix + str(lineno)


def qs(s):
    """Return the string s wrapped in single quotes."""
    quote = "'"
    return ''.join([quote, s, quote])



#-------------------- main HTML generating function ------------------

def genHTML(text, changes, side):
    """Render one side of the diff as the beginning of an HTML page.

    text    -- the source text of this side
    changes -- the changes that concern this side
    side    -- 'left' or 'right'

    Line, change and keyword tags are generated and merged into the text,
    which is then placed inside a <pre> element after the page header.

    The closing </body> and </html> tags are deliberately not part of the
    result: diff() appends a statistics block and then closes the page.
    """
    ltags = lineTags(text)
    ctags = changeTags(text, changes, side)
    ktags = keywordTags(side)
    body = applyTags(text, ltags + ctags + ktags, side)

    page = [
        '<html>\n',
        '<head>\n',
        '<META http-equiv="Content-Type" content="text/html; charset=utf-8">\n',
        '<LINK href="diff.css" rel="stylesheet" type="text/css">\n',
        '<script type="text/javascript" src="nav.js"></script>\n',
        '</head>\n',
        '<body>\n',
        '<pre>\n',
        body,
        '</pre>\n',
    ]
    return ''.join(page)



# put the tags generated by changeTags into the text and create HTML
def applyTags(s, tags, side):
    """Merge tags into the text s and return the resulting HTML string.

    The tags are ordered by insertion offset; among tags sharing an offset
    the one with the larger start value comes first.  The text between
    tags is passed through escape() one character at a time.  A tag whose
    offset lies beyond the end of s is emitted after all of the text that
    precedes it.  The side argument is not used.
    """
    ordered = sorted(tags, key=lambda t: (t.idx, -t.start))
    size = len(s)
    pieces = []
    pos = 0
    for t in ordered:
        stop = min(t.idx, size)
        if stop > pos:
            pieces.extend(escape(ch) for ch in s[pos:stop])
            pos = stop
        pieces.append(t.tag)

    # whatever text follows the last tag
    pieces.extend(escape(ch) for ch in s[pos:])
    return ''.join(pieces)




#--------------------- tag generation functions ----------------------

def changeTags(s, changes, side):
    """Build the opening and closing tags that mark changes on one side.

    s       -- the source text (not used)
    changes -- the changes to mark
    side    -- 'left' selects each change's orig node, anything else its
               cur node

    Changes whose selected node has no lineno attribute are ignored.  For
    a function or class definition only the leading 'def' / 'class'
    keyword is marked; every other node is marked from its start to its
    end.  A change that has both an orig and a cur node becomes a link
    (<a>), a pure deletion or insertion becomes a <span>.
    """
    tags = []
    for r in changes:
        key = r.orig if side == 'left' else r.cur
        if not hasattr(key, 'lineno'):
            continue

        start = nodeStart(key)
        if IS(key, FunctionDef):
            end = start + len('def')
        elif IS(key, ClassDef):
            end = start + len('class')
        else:
            end = nodeEnd(key)

        # identity tests replace the obsolete `<> None` comparisons
        if r.orig is not None and r.cur is not None:
            # <a ...> for change and move
            opening, closing = linkTagStart(r, side), "</a>"
        else:
            # <span ...> for deletion and insertion
            opening, closing = spanStart(r), '</span>'

        tags.append(Tag(opening, start))
        # the closing tag remembers where its opening tag sits
        tags.append(Tag(closing, end, start))

    return tags



def lineTags(s):
    """Build the tags that wrap every line of s in a numbered <div>.

    Each line gets an opening <div class="line"> carrying the line's id,
    followed by a <span class="lineno"> holding the line number, both at
    the line's first character.  The <div> is closed at the newline that
    ends the line, or at the end of the text for an unterminated last
    line.
    """
    out = []
    lineno = 1
    atLineStart = True      # True while no character of the line was seen
    for pos, ch in enumerate(s):
        if atLineStart:
            out.append(Tag('<div class="line" id="L' + str(lineno) + '">', pos))
            out.append(Tag('<span class="lineno">' + str(lineno) + ' </span>', pos))
        atLineStart = (ch == '\n')
        if atLineStart:
            out.append(Tag('</div>', pos))
            lineno += 1

    # Close the last line only if it is actually open.  An empty text, or
    # one ending in a newline, has no open <div> left, and closing anyway
    # would leave an unbalanced </div> in the page.
    if not atLineStart:
        out.append(Tag('</div>', len(s)))
    return out



def keywordTags(side):
    """Build the tags that highlight leading keywords on one side.

    The nodes are taken from allNodes1 when side is 'left' and from
    allNodes2 otherwise.  A node whose type appears in keywordMap, and
    whose source text begins with the mapped keyword, gets that keyword
    wrapped in <span class="keyword">.
    """
    nodes = allNodes1 if side == 'left' else allNodes2
    tags = []
    for node in nodes:
        kw = keywordMap.get(type(node))
        if kw is None:
            continue
        start = nodeStart(node)
        if src(node).startswith(kw):
            tags.append(Tag('<span class="keyword">', start))
            tags.append(Tag('</span>', start + len(kw), start))
    return tags


def spanStart(diff):
    """Return the opening <span> tag for a deletion or an insertion.

    The class is "deletion" when the change has no cur node and
    "insertion" otherwise; the escaped description of the change becomes
    the title attribute.
    """
    cls = "deletion" if diff.cur is None else "insertion"
    text = escape(describeChange(diff))
    return ''.join(['<span class="', cls, '" title="', text, '">'])



def linkTagStart(diff, side):
    """Return the opening <a> tag for a change or a move.

    The tag's id is the uid of this side's node.  Its class is "change"
    when the cost of the change is positive and "move" otherwise, its
    title is the escaped description of the change, and its onclick
    handler calls highlight() with the uids and the line ids of this
    side's node and of its counterpart.
    """
    if side == 'left':
        me, other = diff.orig, diff.cur
    else:
        me, other = diff.cur, diff.orig

    text = escape(describeChange(diff))
    cls = "change" if diff.cost > 0 else "move"

    myId = uid(me)
    otherId = uid(other)
    handlerArgs = ",".join([qs(myId),
                            qs(otherId),
                            qs(lineId(me.lineno)),
                            qs(lineId(other.lineno))])

    return ''.join(['<a id="', myId, '" ',
                    ' class="', cls, '" ',
                    ' title="', text, '" ',
                    'onclick="highlight(', handlerArgs, ')">'])


# Maps statement node types to the keyword their source text begins with.
# keywordTags() looks nodes up here by type and highlights the keyword when
# the node's source really starts with it.
keywordMap = {
    FunctionDef : 'def',
    ClassDef    : 'class',
    For         : 'for',
    While       : 'while',
    If          : 'if',
    With        : 'with',
    Return      : 'return',
    Yield       : 'yield',
    Global      : 'global',
    Raise       : 'raise',
    Pass        : 'pass',
    # both forms of the try statement start with the same keyword
    TryExcept   : 'try',
    TryFinally  : 'try',
    }




# human readable description of node

def describeNode(node):
    """Return a human readable description of node.

    The description is always a string.  When node has a lineno attribute
    the text " (line N)" is appended, after removing any such annotation
    already present in the description.
    """

    def code(s):
        # quote a piece of program text
        return "'" + s + "'"

    def short(node):
        # describe node without any line information
        if IS(node, Module):
            ret = "module"
        elif IS(node, Import):
            ret = "import statement"
        elif IS(node, Name):
            ret = code(node.id)
        elif IS(node, Attribute):
            ret = code(short(node.value) + "." + short(node.attrName))
        elif IS(node, FunctionDef):
            ret = "function " + code(node.name)
        elif IS(node, ClassDef):
            ret = "class " + code(node.name)
        elif IS(node, Call):
            ret = "call to " + code(short(node.func))
        elif IS(node, Assign):
            ret = "assignment"
        elif IS(node, If):
            ret = "if statement"
        elif IS(node, While):
            ret = "while loop"
        elif IS(node, For):
            ret = "for loop"
        elif IS(node, Yield):
            ret = "yield"
        elif IS(node, TryExcept) or IS(node, TryFinally):
            ret = "try statement"
        elif IS(node, Compare):
            ret = "comparison " + src(node)
        elif IS(node, Return):
            # a bare `return` has no value to describe
            if node.value is None:
                ret = "return"
            else:
                ret = "return " + short(node.value)
        elif IS(node, Print):
            # The destination part is computed on its own: a conditional
            # expression binds more loosely than '+', so writing it inline
            # without parentheses drops either "print " or the values.
            if node.dest is not None:
                dest = short(node.dest) + ", "
            else:
                dest = ""
            ret = "print " + dest + printList(node.values)
        elif IS(node, Expr):
            ret = "expression " + short(node.value)
        elif IS(node, Num):
            ret = str(node.n)
        elif IS(node, Str):
            if len(node.s) > 20:
                ret = "string " + code(node.s[:20]) + "..."
            else:
                ret = "string " + code(node.s)
        elif IS(node, Tuple):
            ret = "tuple (" + src(node) + ")"
        elif IS(node, BinOp):
            ret = (short(node.left) + " " +
                   node.opName.id + " " + short(node.right))
        elif IS(node, BoolOp):
            ret = src(node)
        elif IS(node, UnaryOp):
            ret = node.opName.id + " " + short(node.operand)
        elif IS(node, Pass):
            ret = "pass"
        elif IS(node, list):
            # joined into one string so that callers can concatenate it
            ret = ", ".join(map(short, node))
        else:
            ret = str(type(node))
        return ret

    ret = short(node)
    if hasattr(node, 'lineno'):
        # The parentheses are escaped so that only a literal "(line N)"
        # annotation is removed, not arbitrary text containing "line N".
        ret = re.sub(r" *\(line [0-9]+\)", '', ret)
        return ret + " (line " + str(node.lineno) + ")"
    else:
        return ret




# describe a change in a human readable fashion
def describeChange(diff):
    """Return a human readable sentence describing the change diff.

    A change without a cur node is reported as a deletion and one without
    an orig node as an insertion, each prefixed with "wrap " when the
    change is marked as a frame.  Otherwise the change is reported as
    renamed (the node names differ), moved (zero cost but different
    lines), unchanged (zero cost) or changed; all of these except
    "unchanged" carry a similarity note.
    """
    ratio = diff.similarity()
    if ratio == 1.0:
        sim = " (unchanged)"
    else:
        sim = " (similarity %.1f%%)" % (ratio * 100)

    wrap = "wrap " if diff.isFrame else ""

    if diff.cur is None:
        return wrap + describeNode(diff.orig) + " deleted"
    if diff.orig is None:
        return wrap + describeNode(diff.cur) + " inserted"

    if nodeName(diff.orig) != nodeName(diff.cur):
        verb = " renamed to "
    elif diff.cost == 0 and diff.orig.lineno != diff.cur.lineno:
        verb = " moved to "
    elif diff.cost == 0:
        return describeNode(diff.orig) + " unchanged"
    else:
        verb = " changed to "

    return describeNode(diff.orig) + verb + describeNode(diff.cur) + sim





#-------------------------------------------------------------
#                     main HTML based command
#-------------------------------------------------------------

def _writeHtmlPage(outname, html, report):
    """Write one side's page: html, then the stats block and closing tags."""
    with open(outname, 'w') as outfile:
        outfile.write(html)
        outfile.write('<div class="stats"><pre class="stats">')
        outfile.write(report)
        outfile.write('</pre></div>')
        outfile.write('</body>\n')
        outfile.write('</html>\n')


def diff(file1, file2, move=True):
    """Diff two Python source files and write the result as HTML pages.

    file1 -- path of the original ("left") file
    file2 -- path of the modified ("right") file
    move  -- when true, the changes are post-processed by closure()

    Three files are written into the current directory, named after the
    base names of the inputs: <base1>.html, <base2>.html and the frameset
    <base1>-<base2>.html showing the two side by side.  Progress messages
    and a summary report are printed along the way.
    """

    import time
    print("\nJob started at %s, %s\n" % (time.ctime(), time.tzname[0]))
    startTime = time.time()
    checkpoint(startTime)

    cleanUp()

    # base files names
    baseName1 = pyFileName(file1)
    baseName2 = pyFileName(file2)

    # get AST of file1; `with` closes the file even if reading fails
    with open(file1, 'r') as f1:
        lines1 = f1.read()
    node1 = parse(lines1)
    improveAST(node1, lines1, file1, 'left')

    # get AST of file2
    with open(file2, 'r') as f2:
        lines2 = f2.read()
    node2 = parse(lines2)
    improveAST(node2, lines2, file2, 'right')

    print("[parse] finished in %s. Now start to diff." % sec2min(checkpoint()))

    # get the changes
    (changes, cost) = diffNode(node1, node2, nil, nil, 0, False)

    print ("\n[diff] processed %d nodes in %s."
           % (stat.diffCount, sec2min(checkpoint())))

    if move:
        (changes, cost) = closure((changes, cost))

        print("\n[closure] finished in %s." % sec2min(checkpoint()))

    #---------------------- print final stats ---------------------
    size1 = nodeSize(node1)
    size2 = nodeSize(node2)
    total = size1 + size2

    report = ""
    report += ("\n--------------------- summary -----------------------") + "\n"
    report += ("- total changes (chars):  %d" % cost)                  + "\n"
    report += ("- total code size:        %d (left: %d  right: %d)"
               % (total, size1, size2))                                + "\n"
    report += ("- total moved pieces:     %d" % stat.moveCount)        + "\n"
    report += ("- percentage of change:   %.1f%%"
               % (div(cost, total) * 100))                             + "\n"
    report += ("-----------------------------------------------------")   + "\n"

    print(report)

    #---------------------- generate HTML ---------------------
    # write left file: only changes that have an original node
    leftChanges = filterlist(lambda p: p.orig is not None, changes)
    html1 = genHTML(lines1, leftChanges, 'left')
    _writeHtmlPage(baseName1 + '.html', html1, report)

    # write right file: only changes that have a current node
    rightChanges = filterlist(lambda p: p.cur is not None, changes)
    html2 = genHTML(lines2, rightChanges, 'right')
    _writeHtmlPage(baseName2 + '.html', html2, report)

    # write frame file
    framename = baseName1 + "-" + baseName2 + ".html"
    with open(framename, 'w') as framefile:
        framefile.write('<frameset cols="50%,50%">\n')
        framefile.write('<frame name="left" src="' + baseName1 + '.html">\n')
        framefile.write('<frame name="right" src="' + baseName2 + '.html">\n')
        framefile.write('</frameset>\n')

    dur = time.time() - startTime
    print("\n[summary] Job finished at %s, %s" %
          (time.ctime(), time.tzname[0]))
    print("\n\tTotal duration: %s" % sec2min(dur))




def cleanUp():
    """Reset the module-level state before a new diff run.

    Clears the string distance cache and the uid table, empties the node
    sets of both sides and zeroes the counters kept on stat.
    """
    global allNodes1, allNodes2

    clearStrDistCache()
    clearUID()

    allNodes1, allNodes2 = set(), set()

    for counter in ('diffCount', 'moveCount', 'moveSavings'):
        setattr(stat, counter, 0)



def sec2min(s):
    """Format a duration given in seconds for display.

    Durations below one minute are shown in seconds, everything else in
    minutes; both with one decimal place.
    """
    if s < 60:
        return "%.1f seconds" % s
    return "%.1f minutes" % div(s, 60)



# time of the most recent checkpoint, or None before the first one
lastCheckpoint = None


def checkpoint(init=None):
    """Measure the time elapsed since the previous checkpoint.

    With init given, the checkpoint is set to that time and None is
    returned.  Without it, the number of seconds since the previous
    checkpoint is returned and the checkpoint is moved to the current
    time.  If no checkpoint has been set yet, the current time is used,
    so the first such call returns 0.
    """
    import time
    global lastCheckpoint

    if init is not None:
        lastCheckpoint = init
        return None

    now = time.time()
    if lastCheckpoint is None:
        # never initialised: start measuring from now instead of failing
        # on the subtraction below
        lastCheckpoint = now
    dur = now - lastCheckpoint
    lastCheckpoint = now
    return dur




#-------------------------------------------------------------
#                      text-based interfaces
#-------------------------------------------------------------

## text-based main command
def printDiff(file1, file2):
    """Diff two files and print the changes as plain text.

    Prints a header line with the two file names and the total cost, then
    every change ordered by the start of its original node, then a footer.
    """
    (m, c) = diffFile(file1, file2)
    header = ["----------", str(file1), "<<<", str(c), ">>>", str(file2),
              "-----------"]
    print(" ".join(header))

    # NOTE(review): the sort key reads d.orig, which other code in this
    # file treats as possibly None -- confirm nodeStart() copes with that.
    ms = sorted(pylist(m), key=lambda d: nodeStart(d.orig))
    banner = ["\n-------------------- changes(", str(len(ms)),
              ")---------------------- "]
    print(" ".join(banner))
    for m0 in ms:
        print(m0)

    print("\n-------------------  end  ----------------------- ")




def diffFile(file1, file2):
    """Parse both files, diff their trees and return closure()'s result."""
    left = parseFile(file1)
    right = parseFile(file2)
    initial = diffNode(left, right, nil, nil, 0, False)
    return closure(initial)




# printing support for debugging use
def iter_fields(node):
    """Iterate over all existing fields, excluding 'ctx'.

    Yields a (name, value) pair for every name listed in node._fields
    that is actually present on node.
    """
    for field in node._fields:
        if field == 'ctx':
            continue
        # Only the lookup is guarded, so that an AttributeError can only
        # mean "this field is not set on the node".
        try:
            value = getattr(node, field)
        except AttributeError:
            continue
        yield field, value


def dump(node, annotate_fields=True, include_attributes=False):
    """Return a formatted dump of the tree rooted at node.

    The fields of each node are those produced by this file's
    iter_fields().  With annotate_fields they are written as name=value,
    otherwise only the values are written.  With include_attributes the
    names listed in each node's _attributes are appended as well.

    Raises TypeError when node is not an AST instance.
    """
    def render(item):
        if isinstance(item, AST):
            pairs = [(name, render(value))
                     for name, value in iter_fields(item)]
            if annotate_fields:
                parts = ['%s=%s' % pair for pair in pairs]
            else:
                parts = [text for _, text in pairs]
            out = '%s(%s' % (item.__class__.__name__, ', '.join(parts))
            if include_attributes and item._attributes:
                # separate the attributes from the fields, if any
                out += ', ' if pairs else ' '
                out += ', '.join('%s=%s' % (a, render(getattr(item, a)))
                                 for a in item._attributes)
            return out + ')'
        if isinstance(item, list):
            return '[%s]' % ', '.join(render(x) for x in item)
        return repr(item)

    if not isinstance(node, AST):
        raise TypeError('expected AST, got %r' % node.__class__.__name__)
    return render(node)

def printList(ls):
    """Return a compact string for the list ls.

    None and the empty list give the empty string, a single element is
    shown without brackets, anything else is shown via str().
    """
    if ls is None or ls == []:
        return ""
    if len(ls) == 1:
        return str(ls[0])
    return str(ls)




# for debugging use
def printAst(node):
    """Return a short debugging representation of node.

    Nodes with a lineno attribute get "@<lineno>" appended, nodes that
    only have a nodeStart attribute get "%<start>"; in both cases
    "@<digits>" annotations already present in the text are removed
    first.
    """
    if IS(node, Module):
        text = "module:" + str(node.body)
    elif IS(node, Name):
        text = str(node.id)
    elif IS(node, Attribute):
        # prefer the Name recorded for the attribute, if there is one
        attr = node.attrName if hasattr(node, 'attrName') else node.attr
        text = str(node.value) + "." + str(attr)
    elif IS(node, FunctionDef):
        fname = node.nameName if hasattr(node, 'nameName') else node.name
        text = "fun:" + str(fname)
    elif IS(node, ClassDef):
        text = "class:" + str(node.name)
    elif IS(node, Call):
        text = "call:" + str(node.func) + ":(" + printList(node.args) + ")"
    elif IS(node, Assign):
        text = ("(" + printList(node.targets)
                + " <- " + printAst(node.value) + ")")
    elif IS(node, If):
        text = ("if " + str(node.test) + ":" + printList(node.body)
                + ":" + printList(node.orelse))
    elif IS(node, Compare):
        text = (str(node.left) + ":" + printList(node.ops)
                + ":" + printList(node.comparators))
    elif IS(node, Return):
        text = "return " + repr(node.value)
    elif IS(node, Print):
        if node.dest is not None:
            dest = str(node.dest) + ", "
        else:
            dest = ""
        text = "print(" + dest + printList(node.values) + ")"
    elif IS(node, Expr):
        text = "expr:" + str(node.value)
    elif IS(node, Num):
        text = "num:" + str(node.n)
    elif IS(node, Str):
        text = 'str:"' + str(node.s) + '"'
    elif IS(node, BinOp):
        text = str(node.left) + " " + str(node.op) + " " + str(node.right)
    elif IS(node, Add):
        text = '+'
    elif IS(node, Mult):
        text = '*'
    elif IS(node, NotEq):
        text = '<>'
    elif IS(node, Eq):
        text = '=='
    elif IS(node, Pass):
        text = "pass"
    elif IS(node, list):
        text = printList(node)
    else:
        text = str(type(node))

    if hasattr(node, 'lineno'):
        return re.sub("@[0-9]+", '', text) + "@" + str(node.lineno)
    if hasattr(node, 'nodeStart'):
        return re.sub("@[0-9]+", '', text) + "%" + str(nodeStart(node))
    return text


def installPrinter():
    """Make printAst the __repr__ of every class in the ast module.

    The AST base class itself is left untouched.
    """
    import inspect
    import ast

    classes = [obj for _, obj in inspect.getmembers(ast)
               if inspect.isclass(obj)]
    for cls in classes:
        if cls != AST:
            cls.__repr__ = printAst

# Runs when the module is loaded: from here on the classes of the ast
# module (other than AST itself) are displayed through printAst.
installPrinter()

# demo
# diff('demos/demo1.py', 'demos/demo2.py')