# -*- Mode: Python -*-

#
# analysis on the lambda tree - inlining, simplification, etc...
#

import nodes
import itypes

from pdb import set_trace as trace

is_a = isinstance

class UnboundVariableError (Exception):
    pass

# XXX this file needs a lot of work now.  much of the nastier stuff in here
#  has now been obsoleted by the typing phase.

# something to think about.
# fix and let* are *very* close now.  In fact, once the node tree leaves
#  this file they become identical.  So now the question is, can we push
#  that transformation even earlier?  It might simplify some of this code,
#  and may open up some interesting transformations that could flatten
#  the lexical depth of the output...

class analyzer:

    """identify the definition and use of variables (and functions)."""

    def __init__ (self, context):
        self.node_counter = 0
        self.context = context
        self.vars = context.var_dict
        self.constants = {}
        self.inline = not context.noinline
        self.verbose = context.verbose
        self.inline_multiplier = {}

    def analyze (self, root):
        # find aliases
        self.find_aliases (root)
        # perform simple transformations
        root = self.optimize_nvcase (root)
        root = self.transform (root, 0)
        root = self.transform (root, 1)
        self.find_recursion (root)
        if self.verbose:
            print 'calls:'
            self.print_calls (root)
        self.find_applications (root)
        self.escape_analysis (root)
        if self.inline:
            # XXX this is already being done in typing, let's combine them.
            self.call_graph = self.build_call_graph (root)
            root = self.find_inlines (root)
            # transform again
            root = self.transform (root, 1)
            # trim again
            self.find_applications (root)
            if self.verbose:
                print 'after inlining, then pruning again'
            root = self.prune_fixes (root)
            # repeat this with new nodes...
            self.find_recursion (root)
            # re-do escape analysis
            self.escape_analysis (root)


        # mark leaf expressions
        self.find_leaves (root)


        for node in root:
            node.fix_attribute_names()
            if node.is_a ('function'):
                node.calls = self.get_fun_calls (node)

        return root

    def transform (self, node, stage):
        name = 'transform_%d_%s' % (stage, node.kind)
        probe = getattr (self, name, None)
        if probe:
            node = probe (node)
        new_subs = [self.transform (sub, stage) for sub in node.subs]
        node = nodes.node (node.kind, node.params, new_subs, node.type)
        if node.is_a ('fix'):
            # update function slots in every vardef
            names = node.get_names()
            inits  = node.subs[:-1]
            for i in range (len (names)):
                if inits[i].is_a ('function'):
                    names[i].function = inits[i]
        return node

    def transform_0_primapp (self, node):
        if node.name == '&vcase':
            return self.transform_pvcase (node)
        else:
            return node

    def transform_1_conditional (self, node):
        # (if #t x y) => x
        [test_exp, then_exp, else_exp] = node.subs
        if test_exp.is_a ('literal') and test_exp.params[0] == 'bool':
            if test_exp.params[1] == 'true':
                return then_exp
            else:
                return else_exp
        else:
            return node

    # XXX any reason the same wouldn't work for <fix>?
    def transform_1_let_splat (self, node):
        # coalesce cascading let*
        names = node.params
        inits = node.subs[:-1]
        body = node.subs[-1]
        # this is generated often by vcase: (let (x <init>) x)
        if len(names) == 1 and body.is_a ('varref') and body.params == names[0].name:
            return inits[0]
        elif body.is_a ('let_splat'):
            names2 = body.params
            inits2 = body.subs[:-1]
            body2  = body.subs[-1]
            return nodes.let_splat (
                names + names2,
                [self.transform (x, 1) for x in inits + inits2],
                self.transform (body2, 1),
                type=body2.type
                )
        else:
            # search for any let* in the inits...
            #   (let* ((x (let* ((a ...) (b ...)) <body2>))) <body1>)
            #     => (let* ((a ...) (b ...) (x <body2>)) <body1>) ???
            for i in range (len (inits)):
                if inits[i].is_a ('let_splat'):
                    n2 = inits[i]
                    # insert this let* above this variable
                    names2 = n2.params
                    inits2 = n2.subs[:-1]
                    body2  = n2.subs[-1]
                    return self.transform (
                        nodes.let_splat (
                            names[:i] + names2 + names[i:],
                            inits[:i] + inits2 + [body2] + inits[i+1:],
                            body,
                            type=body.type
                            ),
                        1
                        )
            else:
                return node

    def find_aliases (self, root):
        # find aliases - e.g., "(define THING 0)" or "(define PLUS +)"
        # this is a bit of a hack - this should probably be folded into
        #   a more general mechanism (like the inliner, duh).
        for node in root:
            if node.is_a ('fix'):
                for i in range (len (node.names)):
                    name = node.names[i]
                    init = node.inits[i]
                    if init.is_a ('literal'):
                        # strings are a special case - they're not 'simple' literals.
                        if not name.assigns and init.ltype != 'string':
                            name.alias = init
                    elif init.is_a ('varref'):
                        # neither is assigned to, should be safe
                        if not name.assigns and not init.var.assigns:
                            name.alias = init

    def transform_0_varref (self, node):
        var = self.vars[node.name]
        if var.alias is not None:
            #print 'alias %r => %r' % (var, var.alias)
            return var.alias
        else:
            return node

    def transform_1_fix (self, node):
        # coalesce cascading <fix>
        # (fix (a b c) (fix (d e f) ...))
        # => (fix (a b c d e f) ...)
        names = node.params
        inits = node.subs[:-1]
        body = node.subs[-1]
        if body.is_a ('fix'):
            names2 = body.params
            inits2 = body.subs[:-1]
            body2  = body.subs[-1]
            result = nodes.fix (
                names + names2,
                [self.transform (x, 1) for x in inits + inits2],
                self.transform (body2, 1),
                type=body2.type
                )
            result.fix_attribute_names()
            return result
        else:
            return node

    def transform_1_sequence (self, node):
        if len (node.subs) == 1:
            # (begin x) => x
            return node.subs[0]
        else:
            # (begin a0 a1 a2 ... (begin b1 b2) ...)
            # => (begin a0 a1 a2 ... b1 b2 ...)
            # this has no real effect, but feels good, doesn't it?
            subs = []
            for sub in node.subs:
                if sub.is_a ('sequence'):
                    subs.extend (sub.subs)
                else:
                    subs.append (sub)
            return nodes.sequence (subs)

    def transform_pvcase (self, node, val=None):
        # vcase statements (including those generated by the match compiler) come into
        #   analyze.py as %vcase prim calls, and are here transformed into a single node
        #   which implements the n-way branch (as opposed to a mess of embedded lambda's).
        label, arity = node.name_params
        success, failure, value = node.subs
        if val is None:
            val = value
        if failure.body.is_a ('primapp') and failure.body.name == '&vcase':
            vcase = self.transform_pvcase (failure.body, val)
        elif failure.body.is_a ('primapp') and failure.body.name == '%vfail':
            vcase = nodes.pvcase (val, [], [])
        else:
            # since <failure> cannot bind any variables, we just beta reduce it here.
            vcase = nodes.pvcase (val, [], [failure.body])
        # filter out don't-care variable bindings
        n = len (success.formals)
        arity = n
        formals = []
        kept = []
        for i in range (n):
            f = success.formals[i]
            if not f.name.startswith ('_'):
                formals.append (f)
                kept.append (i)
        # ugh, always a bad idea to edit nodes in place.
        success.formals = formals
        success.params[1] = formals
        alt_formals = (label, n, success.formals)
        # don't trigger this for variant records!
        inits = []
        if arity > 1:
            for i in kept:
                inits.append (nodes.primapp ('&vget', [val], (label, arity, i) ))
        elif arity == 0:
            inits = []
        else:
            if 0 in kept:
                inits.append (nodes.primapp ('&vget', [val], (label, arity, 0)))
            else:
                inits = []
        for x in inits:
            x.fix_attribute_names()
        clause = nodes.application (success, inits)
        vcase.params.insert (0, alt_formals)
        vcase.subs.insert (1, clause)
        return vcase

    def replace (self, orig_node, fun):
        # apply replacement-fun() to all of <node>
        node = fun (orig_node)
        new_subs = []
        size = 1
        for sub in node.subs:
            new_sub = self.replace (sub, fun)
            new_subs.append (new_sub)
            size += new_sub.size
        node.subs = new_subs
        # update the size
        node.size = size
        # catch updates to extra meta-data
        # XXX disgusting, all of it
        if node.is_a ('fix'):
            names = node.get_names()
            inits  = node.subs[:-1]
            for i in range (len (names)):
                if inits[i].is_a ('function') and names[i].function.serial != inits[i].serial:
                    old = names[i].function
                    new = inits[i]
                    names[i].function = new
        return node

    def note_funcall (self, name):
        if self.calls.has_key (name):
            self.calls[name] = self.calls[name] + 1
        else:
            self.calls[name] = 1

    def get_fun_calls (self, name):
        mult = self.inline_multiplier.get (name, 1)
        return mult * self.calls.get (name, 0)

    def find_recursion (self, exp):

        self.calls = {}

        def lookup_fun (fun, fenv):
            while fenv:
                entry, fenv = fenv
                if fun is entry:
                    return True
            return False

        def search (exp, fenv):
            if exp.is_a ('function'):
                fenv = (exp, fenv)
            elif exp.is_a ('application'):
                if exp.get_rator().is_a ('varref'):
                    ref = exp.get_rator()
                    name = ref.params
                    var = self.vars[name]
                    if var.function:
                        fun = var.function
                        if lookup_fun (fun, fenv):
                            # mark both the function and the application as recursive
                            fun.params[2] = True
                            exp.params = True
                        else:
                            exp.params = False
                        exp.function = fun
                        self.note_funcall (name)
                    else:
                        exp.function = None
                else:
                    exp.function = None
            for sub in exp.subs:
                search (sub, fenv)
        search (exp, None)

    # XXX shouldn't be needed, use <context.dep_graph> instead.
    # YYY not necessarily - dep_graph records dependencies on things other than funcalls.
    def build_call_graph (self, root):
        call_graph = {}
        def search (exp, this_fun):
            if exp.is_a ('application') and exp.get_rator().is_a ('varref'):
                ref = exp.get_rator()
                name = ref.params
                this_fun.add (name)
            elif exp.is_a ('function') and exp.params[0]:
                name = exp.params[0]
                # i.e., a named function
                this_fun = set()
                call_graph[name] = this_fun
            for sub in exp.subs:
                search (sub, this_fun)
        call_graph['top'] = set()
        search (root, call_graph['top'])
        return call_graph

    # XXX use context.dep_graph, or context.scc_graph, which have all the cycles for us already.
    # XXX this function is called by the inliner, which is trying to decide if this is a
    #     recursive call or not.  HOWEVER, it really wants to be able to distinguish between
    #     a recursive call, and a call to a recursive function.
    def is_recursive (self, name):
        # this is used by the inliner to decide whether to inline a small
        #   function - rather than computing the full transitive closure,
        #   we'll check only candidate functions...
        class FoundIt:
            pass

        def search (name, needle, seen):
            seen.add (name)
            try:
                for callee in self.call_graph[name]:
                    if callee == needle:
                        raise FoundIt
                    if callee not in seen:
                        search (callee, needle, seen)
            except KeyError:
                # XXX for now, ignore unknown functions
                #  [i.e., pretend they can't be recursive...]
                pass

        try:
            search (name, name, set())
        except FoundIt:
            return True
        else:
            return False

    def print_calls (self, root):
        from pprint import pprint as pp
        pp (self.calls)

    def lookup_var (self, node):
        name = node.params
        return self.vars[name]

    # ad-hoc 'tree shaker'
    # we only want to descend into code that's actually called.
    # so rather than walk every function in a <fix>, we start
    # from the outermost body, and follow every chain of funcalls
    # from there.

    def get_initial_expressions (self, node):
        # collect all the expressions that will execute when this node is evaluated.
        # specifically, this is the first step of the tree shaker.
        nodes = []
        if node.one_of ('fix', 'let_splat'):
            # initial expressions consist of the body, and any non-lambda <inits>
            nodes.append (node.get_body())
            inits = node.subs[:-1]
            for init in inits:
                if not init.is_a ('function'):
                    # XXX we should really check that this variable is actually *used*
                    nodes.append (init)
            return nodes
        else:
            return [node]

    def find_applications (self, root):
        # XXX this method really needs a different name - it's more like
        #   'walk applications for tree shaking..'
        to_scan = {}
        # look at the body of root - find all referenced (named) functions
        initial_expressions = self.get_initial_expressions (root)
        for exp in initial_expressions:
            for node in exp:
                if node.one_of ('varref', 'varset'):
                    var = self.lookup_var (node)
                    if var.function:
                        fun = var.function
                        fun.params[0] = var.name # alpha conversion
                        to_scan[var] = fun
                    else:
                        to_scan[var] = None
        #print 'find_applications, to_scan=', to_scan
        # find all (named) functions referenced by those in <to_scan>
        seen = to_scan.copy()
        pass_num = 1
        while len(to_scan):
            to_scan_2 = {}
            #print 'pass #%d: %r' % (pass_num, to_scan.keys())
            for name, fun in to_scan.iteritems():
                if fun:
                    for node in fun.get_body():
                        if node.one_of ('varref', 'varset'):
                            var = self.lookup_var (node)
                            if var.function:
                                fun = var.function
                                fun.params[0] = var.name # alpha conversion
                                if not seen.has_key (var):
                                    to_scan_2[var] = fun
                                    seen[var] = fun
                            else:
                                seen[var] = None
                else:
                    seen[name] = None
            pass_num += 1
            to_scan = to_scan_2
        # ok, now <seen> has every (named) called function?
        # we can now start removing uncalled funs from <fix>
        pruned = []
        for node in root:
            if node.is_a ('fix'):
                # Warning: this edits the node in place
                keep = []
                names = node.params
                funs = node.subs[:-1]
                body = node.subs[-1]
                for i in range (len (names)):
                    if seen.has_key (names[i]):
                        keep.append ((names[i], funs[i]))
                    else:
                        pruned.append (names[i])
                node.params = [x[0] for x in keep]
                node.subs   = [x[1] for x in keep] + [body]
        if self.verbose:
            print 'pruned: ', pruned
            print 'kept:   ', seen.keys()
        # trim the global variable map
        for prune in pruned:
            del self.vars[prune.name]

    def prune_fixes (self, root):
        # now prune empty fixes
        def prune_fix (node):
            if node.one_of ('fix', 'let_splat') and not node.get_names():
                return prune_fix (node.get_body())
            else:
                return node
        return self.replace (root, prune_fix)

    inline_threshold = 13

    def find_inlines (self, root):

        def replacer (node):
            if node.is_a ('application'):
                rator = node.get_rator()
                if rator.is_a ('varref'):
                    name = rator.params
                    var = self.lookup_var (rator)
                    fun = var.function
                    # (<varref xxx> ...) doesn't always refer to a known
                    #  fun, in this case calls == 0...
                    calls = self.get_fun_calls (name)
                    # don't inline functions starting with magical '^' character
                    # XXX eventually this will be replaced with some
                    #     kind of compile-time-environment mechanism
                    if (not name.startswith ('^')
                        and calls > 0
                        and ((fun.size <= self.inline_threshold or ((calls == 1) and not fun.escapes))
                             and not self.is_recursive (name))
                        ):
                        if calls > 1:
                            # set the inline multiplier for funs called by this one.
                            self.set_multiplier (name, calls)
                        node.function = fun
                        result = self.inline_application (node)
                        if result.is_a ('application'):
                            # sneaky!
                            return replacer (result)
                        else:
                            return result
                    else:
                        return node
                elif rator.is_a ('function'):
                    node.function = rator
                    # XXX this isn't *always* a good idea, because
                    #   we might duplicate the args.  This needs to be smarter. [could
                    #   we just put the smarts in let_splat and turn this into that?]
                    #  *or* we can look at the size of the args and the ref-count of
                    #   each variable, and in some cases turn it into a let (hoping for
                    #   let-reg).
                    result = self.inline_application (node)
                    if result.is_a ('application'):
                        #print 'inlining lambda...'
                        #node.pprint()
                        return replacer (result)
                    else:
                        return result
                else:
                    return node
            else:
                return node

        # now call the replacer
        return self.replace (root, replacer)

    def set_multiplier (self, name, calls):
        # when we inline <name>, each function that it calls must have its call-count
        #  raised by a factor of <calls>.
        for callee in self.call_graph[name]:
            # only record the multiplier the first time <name> is inlined.
            if not self.inline_multiplier.has_key (callee):
                self.inline_multiplier[callee] = calls

    def assigned (self, var):
        return len (self.context.var_dict[var.name].assigns)

    def safe_nvget_inline (self, rands):
        r0 = rands[0]
        if r0.is_a ('primapp') and r0.name.startswith ('%nvget/'):
            # make sure the variable is not assigned to...
            if self.assigned (r0.args[0]) == 0:
                return True
        return False

    rename_counter = 0

    def inline_application (self, node):
        # ok, we've decided to inline this node.
        # now we pick which of the two kinds of inlining we'll use.
        # 1) if the arguments are all simple (lit or varref), then we inline textually.
        # 2) if any of the arguments are complex, then we translate to let*.
        # 3) if a complex argument is only referred to once, treat it like a simple arg.
        #
        # XXX might we consider a primapp a simple arg? (say, depending on its size?)
        #
        simple = []
        complex = []
        rator = node.get_rator()
        rands = node.get_rands()
        fun = node.function
        # alpha convert a copy of the function
        body = self.instantiate (fun)
        name, formals, recursive, type = fun.params
        assert (len(formals) == len (rands))
        for i in range (len (rands)):
            arg = rands[i]
            formal = formals[i]
            if arg.is_a ('varref'):
                if self.assigned (arg) or self.assigned (formal):
                    complex.append (i)
                else:
                    simple.append (i)
            elif arg.is_a ('literal'):
                simple.append (i)
            # ok, this just fails with the t_stack.scm, because the field selection primapp
            #   hides the reference to an assigned variable.  think about how important this
            #   is and try to get it back?
            #elif len(formal.refs) == 1:
            #    # it's a complex arg, referred to only once.
            #    simple.append (i)
            # XXX because the case of field selection is so important (otherwise *every* vcase
            #     expression will allocate), I'm going to special case it here.
            elif len(formal.refs) == 1 and self.safe_nvget_inline (rands):
                simple.append (i)
            else:
                complex.append (i)
        if self.verbose:
            print 'inline: size=%3d name=%r simple=%r complex=%r calls=%d' % (fun.size, name, simple, complex, self.get_fun_calls (name))
        # substitute each simple arg in the body
        if simple:
            substs = [ (formals[i], rands[i]) for i in simple ]
        else:
            substs = []
        
        if not complex:
            result = self.substitute (body, substs)
        else:
            # generate new names for the complex args
            names = []
            inits = []
            for i in complex:
                # propagate types as well
                name = '%s_i%d' % (formals[i].name, analyzer.rename_counter)
                var = nodes.vardef (name)
                var.type = formals[i].type
                self.vars[name] = var
                names.append (var)
                inits.append (rands[i])
                varref = nodes.varref (names[-1].name)
                varref.type = names[-1].type
                varref.var = var
                substs.append ((formals[i], varref))
                analyzer.rename_counter += 1
            body = self.substitute (body, substs)
            result = nodes.let_splat (names, inits, body)
            result.type = body.type
        return result

    def substitute (self, body, substs):
        def replacer (node):
            # XXX consider this - set! will work when replacing with a
            #     variable, but what if it's a constant?  can this happen?
            if node.one_of ('varref', 'varset'):
                for k, v in substs:
                    if k.name == node.params:
                        # a match
                        if node.is_a ('varset'):
                            return nodes.varset (v.name, node.value)
                        else:
                            return v
                else:
                    return node
            else:
                return node
        return self.replace (body, replacer)

    inline_counter = 0

    def instantiate (self, fun):
        # give the body of a function, return a new copy with all fresh, unique
        #   bindings in order to preserve the alpha-converted state of the whole program.
        # first, get all fresh new nodes.  This is a somewhat simpler task than full
        #   alpha conversion - mostly because we know the bindings are already unique.
        fun = fun.deep_copy()
        body = fun.get_body()
        # now, append a unique modifier to every locally bound variable
        vars = []
        suffix = '_i%d' % (analyzer.inline_counter,)
        analyzer.inline_counter += 1

        def lookup_var (name, lenv):
            while lenv:
                rib, lenv = lenv
                for x in rib:
                    if x.name == name:
                        return x
            return False

        def rename (exp, lenv):
            if exp.binds():
                defs = exp.get_names()
                vars.extend (defs)
                lenv = (defs, lenv)
            elif exp.one_of ('varref', 'varset'):
                name = exp.params
                probe = lookup_var (name, lenv)
                if probe:
                    exp.params += suffix
                    # XXX is this a hack?  The reason we have this code here
                    #   is because the deep_copy() code doesn't copy the <assigns>
                    #   attribute of a vardef.  However, I'm not sure it *should*.
                    if exp.is_a ('varset'):
                        probe.assigns.append (exp)
            for sub in exp.subs:
                rename (sub, lenv)

        rename (body, None)
        # go back and rename all the vardefs
        for vd in vars:
            vd.name += suffix

        # add the new names to the global table
        for vd in vars:
            self.vars[vd.name] = vd

        return body
        
    def escape_analysis (self, root):
        
        # for each variable, we need to know if it might potentially
        #  escape.  a variable 'escapes' when it is referenced while free
        #  inside a function that escapes (i.e., any function that is
        #  varref'd outside of the operator position).

        escapes = set()

        def find_escaping_functions (node, parent):
            if node.is_a ('function'):
                # any function outside a fix (i.e., a lambda) is by
                #  definition an escaping one (because we reduce
                #    ((lambda () ...) ...)) => (let* ...)
                if not parent or not parent.is_a ('fix'):
                    escapes.add (node)
                    node.escapes = True
            elif node.is_a ('varref'):
                if not (parent.is_a ('application') and parent.subs[0] is node):
                    # function referenced in non-rator position
                    var = self.lookup_var (node)
                    if var.function:
                        var.function.escapes = True
                        escapes.add (var.function)
            for sub in node.subs:
                find_escaping_functions (sub, node)
                
        find_escaping_functions (root, None)

        # now we've found all the escaping functions.  now grep through them
        #   for escaping variables.  we do this by building an environment
        #   only below that function, anything that fails lookup is free.

        def lookup (name, lenv):
            while lenv:
                rib, lenv = lenv
                for v in rib:
                    if v == name:
                        return v
            return False

        def find_escaping_variables (node, lenv):
            if node.binds():
                names = [x.name for x in node.get_names()]
                lenv = (names, lenv)
            elif node.one_of ('varref', 'varset'):
                name = node.params
                if not lookup (name, lenv):
                    # reference to a free variable.  flag it as escaping.
                    var = self.lookup_var (node)
                    var.escapes = True
                    if self.verbose:
                        print '%r escapes' % (var,)
            for sub in node.subs:
                find_escaping_variables (sub, lenv)

        for fun in escapes:
            find_escaping_variables (fun, None)
            
            
    def find_leaves (self, exp):

        # descend the node tree, marking nodes as 'leaf' (or not) on the way up.

        def search (exp):
            if exp.is_a ('application') and not exp.recursive:
                # XXX need to do better here: I think what we want to distinguish are tail calls.
                is_leaf = False
            else:
                is_leaf = True
            for sub in exp.subs:
                sub.leaf = search (sub)
                is_leaf = is_leaf and sub.leaf
            return is_leaf

        exp.leaf = search (exp)

    def optimize_nvcase (self, root):

        # Sometimes the match compiler will output nvcase expressions
        #   for the same variable embedded inside each other.  Keep track
        #   of which alts have already been tested, and eliminate the leaf
        #   nvcase expressions when possible.  This has the nice side-effect
        #   of eliminating erroneous %%match-error calls.

        def lookup (name0, fat_env):
            result = []
            while fat_env:
                (name1, tags), fat_env = fat_env
                if name0 == name1:
                    result.extend (tags)
            return result

        def search (exp, fat_env):
            fatbar = False
            if exp.is_a ('primapp') and exp.name == '%%fatbar':
                fatbar = True
            elif exp.is_a ('nvcase') and exp.value.is_a ('varref'):
                # only trigger this when exp.val is a varref! [which cannot
                # happen with the match compiler, only a manual nvcase will
                # fail this test]
                # ok, which alts are examined at this level?
                name = exp.value.name
                dt = self.context.datatypes[exp.vtype]
                if len (exp.tags) < len (dt.tags):
                    # this nvcase is not exhaustive.  but have we already looked at the
                    # others?
                    already = lookup (name, fat_env)
                    if len(already) + len (exp.tags) == len (dt.tags):
                        # ok, with the upstream nvcases this one *is* exhaustive.
                        # so we can get rid of the else clause, and if there's only
                        # a single alt to test, we can get rid of the nvcase too, leaving
                        # only its alt body.
                        if len (exp.tags) == 1:
                            # remove the whole nvcase node.
                            assert (len (exp.subs) == 3)
                            assert (exp.subs[2].is_a ('primapp') and exp.subs[2].name in ('%%match-error', '%%fail'))
                            exp = exp.subs[1]
                            #print 'simplified nvcase completely'
                        else:
                            # just delete the %%match-error/%%fail
                            assert (exp.subs[-1].is_a ('primapp') and exp.subs[-1].name in ('%%match-error', '%%fail'))
                            exp.size -= exp.subs[-1].size
                            del exp.subs[-1]
                    else:
                        # ok, still not exhaustive.  extend fat_env with this new tag.
                        fat_env = ((name, exp.tags), fat_env)
                else:
                    # this nvcase is exhaustive.  no need to extend the fat_env
                    pass
            else:
                pass
            # fatbar is tricky here, because it can represent a sequence of tests,
            #   but in such a way that the test performed is *not* a direct ancestor
            #   of later tests... therefore when maintaining fat_env, we treat fatbar
            #   specially, and preserve the interior version of fat_env for the second
            #   test.  [theoretically this hack could be avoided if we had a variant of
            #   fatbar that correctly maintained the parent/child relationship between
            #   earlier and later tests... but this would require code duplication, the
            #   elimination of which is the whole *purpose* of fatbar]
            new_subs = []
            size = 1
            for sub in exp.subs:
                new_sub, fat_env2 = search (sub, fat_env)
                if fatbar:
                    # if we are in a fatbar, preserve the value of fat_env for the second branch.
                    fat_env = fat_env2
                new_subs.append (new_sub)
                size += new_sub.size
            exp.subs = new_subs
            exp.size = size
            return exp, fat_env
        
        root2, fat_env = search (root, None)
        return root2