# -*- Mode: Python; tab-width: 4 -*-

import grammar
import sys

# ribcage environment

class environment:
	"""Chained ("ribcage") environment mapping variable names to values.

	A frame holds two parallel lists, names and their values, plus an
	optional parent frame.  The parent is consulted when a name is not
	bound in this frame.
	"""

	def __init__ (self, symbols=None, values=None, extending=None):
		# Build fresh lists per instance so frames never share a default list.
		self.data = [
			[] if symbols is None else symbols,
			[] if values is None else values,
			extending,
			]

	def extend (self, symbols, values):
		"""Return a new child frame binding symbols to values, with self as parent."""
		return environment (symbols, values, self)

	def __getitem__ (self, key):
		"""Return the value bound to key in this frame or the nearest enclosing one.

		Raises ValueError when key is bound in no frame of the chain.
		"""
		names, vals, parent = self.data
		if key not in names:
			if parent is None:
				raise ValueError("Unbound Variable: %s" % key)
			return parent[key]
		return vals[names.index (key)]

	def __setitem__ (self, key, value):
		"""Rebind key to value in the nearest frame that already binds it.

		Never creates a new binding.  Raises ValueError when key is bound
		in no frame of the chain.
		"""
		names, vals, parent = self.data
		if key not in names:
			if parent is None:
				raise ValueError("Unbound Variable: %s" % key)
			parent[key] = value
		else:
			vals[names.index (key)] = value

class symbol_generator:
	"""Hand out fresh identifiers '_g0', '_g1', ... in sequence."""

	def __init__ (self):
		# How many symbols have been issued; also the suffix of the next one.
		self.count = 0

	def next (self):
		"""Return the next unused symbol and advance the counter."""
		current = self.count
		self.count += 1
		return '_g%d' % current


class interpreter:
	"""Tree-walking interpreter with an interactive read-eval-print loop.

	Source text is parsed by the grammar from grammar.make_grammar() into
	list-based abstract syntax of the form [tag, ...].  That syntax is
	evaluated against chained environment frames.  Procedure values are
	lists tagged 'prim-op' or 'closure'.
	"""

	# Parser shared by every instance; built once when the class is defined.
	grammar = grammar.make_grammar()

	def initial_environment (self):
		"""Return a top-level environment binding each primitive name.

		Each name is bound to the procedure value ['prim-op', name].
		apply_proc dispatches these values to apply_prim_op.
		"""
		prims = [
			'+', '-', '*',
			'add1', 'sub1',
			'minus', 'equal', 'greater', 'lesser', 'zero'
			]
		return environment (
			prims,
			[['prim-op', p] for p in prims],
			None
			)

	def eval (self, exp, env):
		"""Parse the source string exp and evaluate it in env.

		Returns the expression's value.  On a parse failure it prints
		'Parse Error'; on an evaluation failure it prints the traceback.
		Both failure cases return None.
		"""
		import traceback
		# Parsing and evaluation get separate handlers, so a runtime error
		# is never reported as a parse error.
		try:
			tree = self.grammar.DoParse1 (exp)
		except Exception:
			print('Parse Error')
			return None
		try:
			return self.eval_exp (tree, env)
		except Exception:
			traceback.print_exc()
			return None

	def print_value (self, value):
		"""Return a display string for an interpreter value.

		Procedure values (lists) show their tag and identity.  None is
		shown as '<undefined>'.  Anything else is shown with repr().
		"""
		if isinstance (value, list):
			return '<%s at %x>' % (value[0], id(value))
		elif value is None:
			return '<undefined>'
		else:
			return repr(value)

	def read_eval_print (self, env=None):
		"""Read lines from stdin, evaluate each, and print the result.

		Every line is evaluated in the same env.  A fresh
		initial_environment() is used when env is None.  Blank lines are
		skipped, and the loop stops at end of input.
		"""
		if env is None:
			env = self.initial_environment()
		while True:
			sys.stdout.write ('--> ')
			sys.stdout.flush()
			try:
				line = raw_input()
			except EOFError:
				# End the prompt line before leaving the loop.
				sys.stdout.write ('\n')
				break
			if line:
				print(self.print_value (self.eval (line, env)))

	repl = read_eval_print

	def eval_exp (self, exp, env):
		"""Evaluate the abstract-syntax node exp in env and return its value.

		Raises SyntaxError for an unrecognised node tag.
		"""
		tag = exp[0]
		if tag == 'lit':
			return exp[1]
		elif tag == 'varref':
			return env[exp[1]]
		elif tag == 'app':
			proc = self.eval_exp (exp[1], env)
			args = self.eval_rands (exp[2], env)
			return self.apply_proc (proc, args)
		elif tag == 'conditional':
			# Any value that compares unequal to 0 counts as true.
			test = self.eval_exp (exp[1], env)
			if test != 0:
				return self.eval_exp (exp[2], env)
			else:
				return self.eval_exp (exp[3], env)
		elif tag == 'let':
			# All declarations are evaluated in the outer env, then the
			# body is evaluated in a new frame binding them.
			symbols = []
			values = []
			for [ignore, var, decl] in exp[1]:
				symbols.append (var)
				values.append (self.eval_exp (decl, env))
			return self.eval_exp (exp[2], env.extend (symbols, values))
		elif tag == 'proc':
			# A procedure expression evaluates to a closure:
			# ['closure', <formals>, <body>, <defining env>]
			return ['closure', exp[1], exp[2], env]
		elif tag == 'varassign':
			env[exp[1]] = self.eval_exp (exp[2], env)
			return None
		elif tag == 'compound':
			# The value of a sequence is the value of its last expression.
			# An empty sequence yields None instead of failing on an unset
			# result.
			result = None
			for sub in exp[1]:
				result = self.eval_exp (sub, env)
			return result
		elif tag == 'letrecproc':
			return self.eval_exp (self.expand_letrecproc (exp), env)
		else:
			raise SyntaxError("Invalid Abstract Syntax")

	# ========================================
	# Program Transformation
	# ========================================

	# Fresh-name source for program transformations, shared by all instances.
	gensym = symbol_generator()

	def expand_letrecproc (self, exp):
		"""Rewrite a letrecproc node into equivalent let/compound/varassign syntax.

		Node shapes, as read by the code below:
		  letrecproc ::= ['letrecproc', procdecls, body]
		  procdecls  ::= [ [<tag>, var, formals, procbody] ... ]

		  letrecproc
		     p1 (f1) = b1;
		     ...
		     pn (fn) = bn;
		  in
		     <body>
		  ==>
		  let
		     p1 = 0
		     ...
		     pn = 0
		  in
		    begin
		      let
		        g1 = proc (f1) b1
		        ...
		        gn = proc (fn) bn
		      in
		        p1 := g1
		        ...
		        pn := gn
		    <body>
		    end

		Each closure is created while p1..pn are already in scope, and those
		names are then assigned.  This lets the procedures refer to one
		another recursively.
		"""
		decls = exp[1]
		symbols = [d[1] for d in decls]
		formals = [d[2] for d in decls]
		bodies  = [d[3] for d in decls]
		gensyms = [self.gensym.next() for d in decls]
		body = exp[2]

		empty_decls = [['decl', s, ['lit', 0]] for s in symbols]
		proc_decls = [['decl', g, ['proc', f, b]]
					  for (g, f, b) in zip (gensyms, formals, bodies)]
		assignments = [['varassign', s, ['varref', g]]
					   for (s, g) in zip (symbols, gensyms)]
		return ['let', empty_decls,
				['compound', [
					['let', proc_decls, ['compound', assignments]],
					body,
					]],
				]

	# ========================================
	# Application
	# ========================================

	def eval_rands (self, rands, env):
		"""Evaluate the operand expressions left to right and return a list of values."""
		return [self.eval_exp (rand, env) for rand in rands]

	def apply_proc (self, proc, args):
		"""Apply the procedure value proc to the already-evaluated args.

		Raises ValueError when proc is not a procedure value.
		"""
		# Reject non-procedure values with a clear error rather than
		# failing on the proc[0] subscript.
		if not isinstance (proc, list) or not proc:
			raise ValueError("Invalid Procedure")
		kind = proc[0]
		if kind == 'prim-op':
			return self.apply_prim_op (proc[1], args)
		elif kind == 'closure':
			# closure ::= ['closure', formals, body, defining env]
			[ignore, formals, body, env] = proc
			return self.eval_exp (body, env.extend (formals, args))
		else:
			raise ValueError("Invalid Procedure")

	def apply_prim_op (self, prim_op, args):
		"""Apply the primitive named prim_op to the argument list args.

		Raises ValueError for an unknown primitive name.
		"""
		if prim_op == '+':
			return args[0] + args[1]
		elif prim_op == '-':
			return args[0] - args[1]
		elif prim_op == '*':
			return args[0] * args[1]
		elif prim_op == 'add1':
			return args[0] + 1
		elif prim_op == 'sub1':
			return args[0] - 1
		elif prim_op == 'minus':
			return - args[0]
		elif prim_op == 'equal':
			return args[0] == args[1]
		elif prim_op == 'greater':
			return args[0] > args[1]
		elif prim_op == 'lesser':
			return args[0] < args[1]
		elif prim_op == 'zero':
			return args[0] == 0
		else:
			# Raise a real exception object: raising a bare string is
			# itself a TypeError on modern Pythons.
			raise ValueError("Invalid primitive operator: %s" % (prim_op,))

if __name__ == '__main__':
	# Run as a script: start an interactive read-eval-print session.
	interpreter().repl()
