# -*- Mode: Python; tab-width: 4 -*-

import grammar
import sys

# ribcage environment

class environment:
	"""A 'ribcage' environment: one frame of parallel symbol/value lists,
	chained to an enclosing environment that is searched on a miss.

	The frame is stored in ``self.data`` as ``[symbols, values, extending]``,
	where ``extending`` is the enclosing environment (anything supporting
	``[key]`` lookup), or None for the outermost frame.
	"""

	def __init__ (self, symbols=None, values=None, extending=None):
		"""Create a frame binding symbols[i] to values[i] on top of `extending`.

		The given sequences are copied, so later assignments into this frame
		never modify lists owned by the caller (e.g. a closure's formal
		parameter list taken from the syntax tree).
		"""
		if symbols is None:
			symbols = []
		if values is None:
			values = []
		self.data = [list (symbols), list (values), extending]

	def extend (self, symbols, values):
		"""Return a new environment binding `symbols` to `values` whose
		lookups fall back to this one."""
		return environment (symbols, values, self)

	def __getitem__ (self, key):
		"""Return the value bound to `key`, searching this frame first and
		then each enclosing environment.

		Raises ValueError if `key` is not bound anywhere in the chain.
		"""
		symbols, values, extending = self.data
		if key in symbols:
			return values[symbols.index (key)]
		if extending is None:
			raise ValueError ("Unbound Variable: %s" % (key,))
		return extending[key]

	def __setitem__ (self, key, value):
		"""Bind `key` to `value` in this frame only.

		An existing binding in this frame is overwritten; otherwise a new
		binding is added here (shadowing any binding of `key` in an
		enclosing environment, which is left untouched).
		"""
		symbols, values, extending = self.data
		if key in symbols:
			values[symbols.index (key)] = value
		else:
			symbols.append (key)
			values.append (value)

class interpreter:
	"""Evaluator and read-eval-print loop for a small expression language.

	Source text is parsed by the class-level ``grammar`` object (built by
	``grammar.make_grammar()``) into a nested-list syntax tree whose first
	element is a tag ('lit', 'varref', 'app', ...).  Runtime values are
	plain Python values, or tagged lists:

	  ['prim-op', name]                  -- a primitive operator
	  ['closure', formals, body, env]    -- a user-defined procedure
	"""

	grammar = grammar.make_grammar()

	# Primitive operators: name -> (arity, implementation).  Each
	# implementation receives the whole argument list; apply_prim_op checks
	# the arity first so a wrong call reports a clear error instead of an
	# IndexError (or silently ignoring extra arguments).
	_prim_ops = {
		'+': (2, lambda a: a[0] + a[1]),
		'-': (2, lambda a: a[0] - a[1]),
		'*': (2, lambda a: a[0] * a[1]),
		'add1': (1, lambda a: a[0] + 1),
		'sub1': (1, lambda a: a[0] - 1),
		'minus': (1, lambda a: - a[0]),
		'equal': (2, lambda a: a[0] == a[1]),
		'greater': (2, lambda a: a[0] > a[1]),
		'lesser': (2, lambda a: a[0] < a[1]),
		'zero': (1, lambda a: a[0] == 0),
		}

	def initial_environment (self):
		"""Return a fresh top-level environment binding each primitive
		operator name to its ['prim-op', name] value."""
		# Every name here must have an entry in _prim_ops.
		prims = [
			'+', '-', '*',
			'add1', 'sub1',
			'minus', 'equal', 'greater', 'lesser', 'zero'
			]
		# A real list (not a lazy map object) so the environment can index it.
		return environment (prims, [['prim-op', p] for p in prims], None)

	def eval (self, exp, env):
		"""Parse the source string `exp` and evaluate it in `env`.

		Returns the resulting value.  This is the REPL's error boundary:
		if parsing fails, 'Parse Error' is printed; if evaluation fails, a
		traceback is printed to stdout.  In both cases None is returned.
		"""
		import traceback
		try:
			tree = self.grammar.DoParse1 (exp)
		except Exception:
			print ('Parse Error')
			return None
		try:
			return self.eval_exp (tree, env)
		except Exception:
			# Report and keep the REPL alive rather than aborting the session.
			traceback.print_exc (file=sys.stdout)
			return None

	def print_value (self, value):
		"""Return the REPL's printed representation of `value`."""
		if isinstance (value, list):
			# tagged runtime value (procedure): show its tag and identity
			return '<%s at %x>' % (value[0], id (value))
		if value is None:
			return '<undefined>'
		return repr (value)

	def read_eval_print (self, env=None):
		"""Interactive loop: prompt with '--> ', read a line, evaluate it and
		print the result, until end-of-file.

		Uses a fresh initial environment when `env` is None; blank lines are
		skipped.
		"""
		if env is None:
			env = self.initial_environment()
		try:
			read_line = raw_input
		except NameError:
			# Python 3 renamed raw_input to input.
			read_line = input
		while 1:
			sys.stdout.write ('--> ')
			sys.stdout.flush()

			try:
				exp = read_line()
			except EOFError:
				sys.stdout.write ('\n')
				break

			if exp.strip():
				print (self.print_value (self.eval (exp, env)))

	repl = read_eval_print

	def eval_exp (self, exp, env):
		"""Evaluate the syntax-tree node `exp` in environment `env`.

		Node shapes, by tag (the first element):
		  ['lit', value]
		  ['varref', name]
		  ['app', rator, rands]
		  ['conditional', test, then, else]  -- any test != 0 selects `then`
		  ['let', decls, body]    -- decls: 3-item sequences, [_, name, exp]
		  ['proc', formals, body] -- evaluates to a closure over `env`
		  ['varassign', name, exp] -- performs env[name] = value; yields None
		  ['compound', exps]       -- value of the last exp (None if empty)

		Raises SyntaxError for an unknown tag.
		"""
		tag = exp[0]
		if tag == 'lit':
			return exp[1]
		elif tag == 'varref':
			return env[exp[1]]
		elif tag == 'app':
			proc = self.eval_exp (exp[1], env)
			args = self.eval_rands (exp[2], env)
			return self.apply_proc (proc, args)
		elif tag == 'conditional':
			test = self.eval_exp (exp[1], env)
			if test != 0:
				return self.eval_exp (exp[2], env)
			else:
				return self.eval_exp (exp[3], env)
		elif tag == 'let':
			# All initializers are evaluated in the outer env, then bound together.
			symbols = []
			values = []
			for _tag, var, decl in exp[1]:
				symbols.append (var)
				values.append (self.eval_exp (decl, env))
			return self.eval_exp (exp[2], env.extend (symbols, values))
		elif tag == 'proc':
			# procedure definition simply creates a closure:
			# ['closure' <formals> <body> <env>]
			return ['closure', exp[1], exp[2], env]
		elif tag == 'varassign':
			env[exp[1]] = self.eval_exp (exp[2], env)
			return None
		elif tag == 'compound':
			result = None  # an empty sequence would otherwise leave this unbound
			for sub in exp[1]:
				result = self.eval_exp (sub, env)
			return result
		else:
			raise SyntaxError ("Invalid Abstract Syntax")

	def eval_rands (self, rands, env):
		"""Evaluate each operand expression in `env`, left to right, and
		return the values as a list."""
		return [self.eval_exp (rand, env) for rand in rands]

	def apply_proc (self, proc, args):
		"""Apply the procedure value `proc` to the argument list `args`.

		Raises ValueError if `proc` is not a procedure value, or if a
		closure is called with the wrong number of arguments.
		"""
		if not isinstance (proc, list) or not proc:
			# e.g. applying a number: report it as such, not as an IndexError
			raise ValueError ("Invalid Procedure")
		if proc[0] == 'prim-op':
			return self.apply_prim_op (proc[1], args)
		elif proc[0] == 'closure':
			# proc := formals body environment
			_tag, formals, body, env = proc
			if len (args) != len (formals):
				raise ValueError (
					"Wrong number of arguments: expected %d, got %d"
					% (len (formals), len (args))
					)
			# Copy both lists: `formals` is shared with the closure and the
			# syntax tree, and the new frame's lists grow whenever the body
			# assigns a fresh variable.
			return self.eval_exp (
				body,
				env.extend (list (formals), list (args))
				)
		else:
			raise ValueError ("Invalid Procedure")

	def apply_prim_op (self, prim_op, args):
		"""Apply the primitive operator named `prim_op` to `args`.

		Raises ValueError for an unknown operator or a wrong argument count.
		"""
		try:
			arity, impl = self._prim_ops[prim_op]
		except KeyError:
			raise ValueError ("Invalid primitive operator: %s" % (prim_op,))
		if len (args) != arity:
			raise ValueError (
				"%s expects %d argument(s), got %d"
				% (prim_op, arity, len (args))
				)
		return impl (args)

# We can now do recursion, even without a 'letrec', by using the 'Y combinator'.
#
# Here's the applicative-order Y combinator (the same definition that the
# __main__ block below installs under the name 'Y'):
#
#   Y := proc (f) (proc (x) f(proc(y) (x(x))(y)))(proc (x) f(proc(y) (x(x))(y)))
#
# and a factorial function defined using it:
#
#   fact := Y (proc (g) proc (n) if zero(n) then 1 else *(n,g(-(n,1))))

if __name__ == '__main__':
	# Start an interactive session with the Y combinator predefined, so
	# recursive procedures can be written without 'letrec'.
	i = interpreter()
	e = i.initial_environment()
	i.eval ('Y := proc (f) (proc (x) f(proc(y) (x(x))(y)))(proc (x) f(proc(y) (x(x))(y)))', e)
	# Single-argument print (...) behaves the same on Python 2 and 3.
	print ('The Y combinator has been defined; try this:')
	print ('  fact := Y (proc (g) proc (n) if zero(n) then 1 else *(n,g(-(n,1))))')
	i.repl (e)
