# -*- Mode: Python; tab-width: 4 -*-

import grammar
import sys

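# "Ribcage" environment: each frame holds a list of symbols, a parallel list
# of values, and a link to the enclosing frame (None at the top level).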

class environment:
	"""A chained ("ribcage") environment mapping symbols to values.

	Each frame holds a list of symbols, a parallel list of values, and the
	enclosing environment ('extending'), which is None for the outermost
	frame.  Lookup and assignment search the innermost frame first and then
	walk outward; an unbound symbol raises ValueError.
	"""

	def __init__ (self, symbols=None, values=None, extending=None):
		# Copy the incoming lists: callers pass in lists they keep using
		# (e.g. a closure's formal-parameter list, reused on every
		# application), and define() appends to this frame's lists -- those
		# appends must not leak back into the caller's list.
		self.data = [list (symbols or []), list (values or []), extending]

	def extend (self, symbols, values):
		"""Return a new environment whose innermost frame binds symbols to
		values and whose enclosing environment is this one."""
		return environment (symbols, values, self)

	def _binding (self, key):
		# Return (values, index) for the innermost frame binding key, walking
		# outward iteratively; raise ValueError if key is unbound everywhere.
		env = self
		while env is not None:
			symbols, values, extending = env.data
			if key in symbols:
				return values, symbols.index (key)
			env = extending
		raise ValueError ("Unbound Variable: %s" % (key,))

	def __getitem__ (self, key):
		"""Return the value bound to key; raise ValueError if unbound."""
		values, index = self._binding (key)
		return values[index]

	def __setitem__ (self, key, value):
		"""Overwrite the innermost existing binding of key.

		Never creates a binding: raises ValueError if key is unbound
		(use define() to create one).
		"""
		values, index = self._binding (key)
		values[index] = value

	def define (self, key, value):
		"""Bind key to value in this (innermost) frame.

		If key is already bound in this frame its value is replaced;
		otherwise a new binding is appended.  Enclosing frames are untouched.
		"""
		symbols, values, extending = self.data
		if key in symbols:
			values[symbols.index (key)] = value
		else:
			symbols.append (key)
			values.append (value)

class symbol_generator:
	"""Hands out fresh identifiers '_g0', '_g1', '_g2', ... in order.

	Intended for naming temporaries introduced by program transformations.
	"""

	def __init__ (self):
		# number embedded in the name returned by the next call to next()
		self.count = 0

	def next (self):
		"""Return a new symbol name and advance the counter."""
		fresh = '_g%d' % (self.count,)
		self.count += 1
		return fresh


# These transformations could be added directly to the parser, but keeping
# them in a separate pass keeps everything modular: another transformer
# (say, one that converts to CPS) can be plugged into the output of this one.

class transformer:

	# ========================================
	# Program Transformation
	# ========================================

	# These are fairly direct translations of the derived-expression
	# expansions in section 7.3 of R^4RS.

	# The primitive expression types are:
	#
	#   literal, variable, call, lambda, if, set!
	#
	# Everything else can be constructed from these
	# [and call/cc is a primitive procedure!].
	#
	# Sub-expression lists are built with list comprehensions rather than
	# map(), so the resulting AST always consists of real lists (map()
	# returns a lazy iterator on Python 3, which cannot be concatenated
	# or indexed).

	gensym = symbol_generator()

	def expand_expression (self, exp):
		"""Recursively rewrite the derived expressions in exp into primitives.

		exp is an AST node: a list whose first element is its type tag.
		The result contains only 'lit', 'varref', 'varassign', 'definition',
		'app', 'conditional' and 'proc' nodes.  Raises ValueError for an
		unrecognized tag.
		"""
		tag = exp[0]
		if tag in ('lit', 'varref'):
			return exp
		elif tag in ('varassign', 'definition'):
			# [<tag> <var> <exp>]
			return [tag, exp[1], self.expand_expression (exp[2])]
		elif tag == 'app':
			# ['app' <operator_exp> [<operand_exp> ...]]
			return ['app',
					self.expand_expression (exp[1]),
					[self.expand_expression (rand) for rand in exp[2]]]
		elif tag == 'conditional':
			# ['conditional' <test_exp> <then_exp> <else_exp>]
			return ['conditional'] + [self.expand_expression (sub) for sub in exp[1:]]
		elif tag == 'proc':
			# ['proc' <vars> <exp>]
			return ['proc', exp[1], self.expand_expression (exp[2])]
		elif tag == 'let':
			return self.expand_let (exp)
		elif tag == 'letrecproc':
			return self.expand_letrecproc (exp)
		elif tag == 'compound':
			return self.expand_compound (exp)
		else:
			raise ValueError ("Unknown Expression Type: %s" % (exp,))

	def expand_let (self, exp):
		# ['let' [['decl' <var_1> <init_1>] ...] <body>]
		# =>
		# ['app' ['proc' [<var_1> ... <var_n>] <body>] [<init_1> ... <init_n>]]
		#
		# <init_1..n> and <body> are sub-expressions and are expanded too
		# (inits first, then body).
		decls = exp[1]
		symbols = [decl[1] for decl in decls]
		inits   = [self.expand_expression (decl[2]) for decl in decls]
		body    = self.expand_expression (exp[2])
		return ['app', ['proc', symbols, body], inits]

	def expand_compound (self, exp):
		# ['compound' [<exp_1> ... <exp_n>]]
		# =>
		# ['app'
		#    ['proc' ['ignore' 'thunk'] ['app' ['varref' 'thunk'] []]]
		#    [<exp_1>
		#     ['proc' [] ['compound' [<exp_2> ... <exp_n>]]]]]
		#
		# <exp_1> is passed as the first operand and the rest of the sequence
		# is delayed in a thunk that the procedure body then calls.  The
		# rewritten form is expanded again, which unrolls the remaining
		# expressions one at a time.  A single expression expands to itself.
		exps = exp[1]
		if not exps:
			raise ValueError ("Empty compound expression: %s" % (exp,))
		if len (exps) == 1:
			return self.expand_expression (exps[0])
		first, rest = exps[0], exps[1:]
		return self.expand_expression (
			['app',
			 ['proc', ['ignore', 'thunk'], ['app', ['varref', 'thunk'], []]],
			 [first, ['proc', [], ['compound', rest]]]])

	def expand_letrecproc (self, exp):
		# ['letrecproc' <procdecls> <body>]
		# where each procdecl is [<tag> <var> <formals> <proc_body>]
		# (only indices 1-3 are read here).
		#
		# letrecproc
		#    p1 (f1) = b1;
		#    ...
		#    pn (fn) = bn;
		# in
		#    <body>
		# ==>
		# let
		#    p1 = 0
		#    ...
		#    pn = 0
		# in
		#   begin
		#     let
		#       g1 = proc (f1) b1
		#       ...
		#       gn = proc (fn) bn
		#     in
		#       begin p1 := g1; ... pn := gn end;
		#     <body>
		#   end
		#
		# The gi are fresh names from gensym.  The procedures are created
		# while every pi is already in scope and are then stored into the pi,
		# so the procedures can call each other recursively.
		decls = exp[1]
		symbols = [decl[1] for decl in decls]
		formals = [decl[2] for decl in decls]
		bodies  = [decl[3] for decl in decls]
		gensyms = [self.gensym.next () for decl in decls]
		body = exp[2]

		empty_decls = [['decl', sym, ['lit', 0]] for sym in symbols]
		proc_decls = [
			['decl', gen, ['proc', formal, proc_body]]
			for gen, formal, proc_body in zip (gensyms, formals, bodies)
			]
		assignments = [
			['varassign', sym, ['varref', gen]]
			for sym, gen in zip (symbols, gensyms)
			]
		return self.expand_expression (
			['let', empty_decls,
			 ['compound', [
					['let', proc_decls, ['compound', assignments]],
					body
					]
			  ]
			 ]
			)

class interpreter:
	"""Continuation-passing interpreter with first-class continuations.

	eval() parses a source string with the grammar module, rewrites derived
	forms into primitive ones with a transformer, and evaluates the result
	with eval_exp().  Every evaluation step receives an explicit
	continuation k -- a tagged list such as ['final-valcont'] or
	['proc-valcont', <rands>, <env>, <k>] -- which apply_continuation()
	interprets.  Because continuations are plain data, callcc can hand the
	current one to user code as ['continuation', k].
	"""

	grammar = grammar.make_grammar()
	transform   = transformer()

	def initial_environment (self):
		"""Return a fresh top-level environment binding the primitives."""
		prim_names = [
			'+', '-', '*',
			'add1', 'sub1',
			'minus', 'equal', 'greater', 'lesser', 'zero',
			'print'
			]
		prim_procs = [['prim-op', name] for name in prim_names]

		# ==============================
		# add the magical call/cc
		# ==============================
		# [it can't be added as a true primitive because apply_proc()
		#  passes a primitive's result to apply_continuation(), while
		#  call/cc needs the continuation k itself...]

		prim_names.append ('callcc')
		prim_procs.append (['callcc-proc'])

		return environment (prim_names, prim_procs, None)

	def _report_error (self, label):
		# Report the exception currently being handled: <label>, its
		# traceback, then its type and value.  Must be called from inside an
		# except clause.  sys.exc_info() replaces the deprecated (removed in
		# Python 3) sys.exc_type / exc_value / exc_traceback globals.
		exc_type, exc_value, exc_tb = sys.exc_info ()
		sys.stdout.write ('%s\n' % label)
		try:
			import tb
		except ImportError:
			# NOTE(review): 'tb' is not a standard-library module; fall back
			# to the stdlib traceback printer when it is unavailable instead
			# of letting the ImportError escape and kill the REPL.
			import traceback
			traceback.print_tb (exc_tb, None, sys.stdout)
		else:
			tb.printtb (exc_tb)
		sys.stdout.write ('%s %s\n' % (exc_type, exc_value))

	def eval (self, exp, env):
		"""Parse, expand and evaluate the source string exp in env.

		The expanded AST is printed before evaluation.  An error in any
		phase is reported on stdout and None is returned.
		"""
		# The handlers below are deliberately broad: any failure caused by a
		# bad input line is reported and the REPL keeps running.

		# ====================
		# parse
		# ====================

		try:
			exp = self.grammar.DoParse1 (exp)
		except:
			self._report_error ('Parse Error')
			return None

		# ====================
		# expand
		# ====================

		try:
			exp = self.transform.expand_expression (exp)
			sys.stdout.write ('expanded to:\n')
			import pprint
			pprint.pprint (exp)
			sys.stdout.write ('=>\n')
		except:
			self._report_error ('Expansion Error')
			return None

		# ====================
		# evaluate
		# ====================

		try:
			return self.eval_exp (exp, env, ['final-valcont'])
		except:
			self._report_error ('Evaluation Error')
			return None

	def print_value (self, value):
		"""Return a printable representation of an interpreter value."""
		if isinstance (value, list) and value:
			# closures, primitives and continuations are tagged lists
			return '<%s at %x>' % (value[0], id (value))
		elif value is None:
			return '<undefined>'
		else:
			return repr (value)

	def read_eval_print (self, env=None):
		"""Read lines from stdin and evaluate each one until end of input."""
		if env is None:
			env = self.initial_environment ()
		try:
			read_line = raw_input
		except NameError:
			# Python 3: input() reads a line without evaluating it
			read_line = input
		while True:
			sys.stdout.write ('--> ')
			sys.stdout.flush ()

			try:
				line = read_line ()
			except EOFError:
				sys.stdout.write ('\n')
				break

			if line:
				sys.stdout.write (self.print_value (self.eval (line, env)) + '\n')

	repl = read_eval_print

	def eval_exp (self, exp, env, k):
		"""Evaluate the primitive AST node exp in env and pass its value to k."""
		tag = exp[0]
		if tag == 'lit':
			return self.apply_continuation (k, exp[1])
		elif tag == 'varref':
			return self.apply_continuation (k, env[exp[1]])
		elif tag == 'app':
			# evaluate the operator first; proc-valcont then evaluates operands
			return self.eval_exp (exp[1], env, ['proc-valcont', exp[2], env, k])
		elif tag == 'conditional':
			return self.eval_exp (exp[1], env, ['test-valcont', exp[2], exp[3], env, k])
		elif tag == 'proc':
			# procedure definition simply creates a closure:
			# ['closure' <formals> <body> <env>]
			return self.apply_continuation (k, ['closure', exp[1], exp[2], env])
		elif tag == 'varassign':
			# k travels inside assign-valcont so the right-hand side is
			# evaluated in proper CPS: a continuation captured or invoked
			# there sees (or abandons) the rest of the computation correctly.
			return self.eval_exp (exp[2], env, ['assign-valcont', exp[1], env, 0, k])
		elif tag == 'definition':
			# same as varassign, but may create a binding (bind flag = 1)
			return self.eval_exp (exp[2], env, ['assign-valcont', exp[1], env, 1, k])
		else:
			raise SyntaxError ("Invalid Abstract Syntax")


	# ========================================
	# Application
	# ========================================

	def eval_rands (self, rands, env, k):
		"""Evaluate operands left to right; pass the list of values to k."""
		if not rands:
			return self.apply_continuation (k, [])
		return self.eval_exp (rands[0], env, ['first-valcont', rands, env, k])

	def apply_proc (self, proc, args, k):
		"""Apply the procedure value proc to the list args, continuing with k."""
		tag = proc[0]
		if tag == 'prim-op':
			return self.apply_continuation (k, self.apply_prim_op (proc[1], args))
		elif tag == 'closure':
			# ['closure' <formals> <body> <env>]
			_tag, formals, body, env = proc
			return self.eval_exp (body, env.extend (formals, args), k)
		elif tag == 'continuation':
			# ['continuation' <saved_k>]: abandon k, resume saved_k with args[0]
			sys.stdout.write ('applying continuation %s  to %s\n' % (proc[1], args[0]))
			return self.apply_continuation (proc[1], args[0])
		elif tag == 'callcc-proc':
			# call args[0] with the current continuation reified as a value
			return self.apply_proc (args[0], [['continuation', k]], k)
		else:
			raise ValueError ("Invalid Procedure")

	def apply_prim_op (self, prim_op, args):
		"""Return the result of primitive prim_op applied to the list args.

		'print' writes args to stdout and returns None.  Raises ValueError
		for an unknown operator name.
		"""
		if prim_op == '+':
			return args[0] + args[1]
		elif prim_op == '-':
			return args[0] - args[1]
		elif prim_op == '*':
			return args[0] * args[1]
		elif prim_op == 'add1':
			return args[0] + 1
		elif prim_op == 'sub1':
			return args[0] - 1
		elif prim_op == 'minus':
			return - args[0]
		elif prim_op == 'equal':
			return args[0] == args[1]
		elif prim_op == 'greater':
			return args[0] > args[1]
		elif prim_op == 'lesser':
			return args[0] < args[1]
		elif prim_op == 'zero':
			return args[0] == 0
		elif prim_op == 'print':
			sys.stdout.write ('%s\n' % (args,))
			return None
		else:
			# raising a bare string is itself a TypeError on Python >= 2.6
			raise ValueError ("Invalid primitive operator: %s" % (prim_op,))

	def apply_continuation (self, k, val):
		"""Deliver val to the continuation k and carry on the computation."""
		tag = k[0]
		if tag == 'final-valcont':
			# end of the computation: val is the result of eval_exp
			return val
		elif tag == 'proc-valcont':
			# ['proc-valcont' <rands> <env> <k>]: val is the operator
			return self.eval_rands (k[1], k[2], ['all-argcont', val, k[3]])
		elif tag == 'all-argcont':
			# ['all-argcont' <proc> <k>]: val is the list of argument values
			return self.apply_proc (k[1], val, k[2])
		elif tag == 'test-valcont':
			# ['test-valcont' <then> <else> <env> <k>]: val is the test result
			if val:
				return self.eval_exp (k[1], k[3], k[4])
			else:
				return self.eval_exp (k[2], k[3], k[4])
		elif tag == 'first-valcont':
			# ['first-valcont' <rands> <env> <k>]: val is the first operand
			return self.eval_rands (
				k[1][1:], # (cdr rands)
				k[2], # env
				['rest-argcont', val, k[3]]
				)
		elif tag == 'rest-argcont':
			# ['rest-argcont' <first> <k>]: val is the list of remaining values
			return self.apply_continuation (k[2], [k[1]] + val) # (cons first rest)
		elif tag == 'assign-valcont':
			# ['assign-valcont' <var> <env> <bind> <k>]: store val into var,
			# creating a binding when bind is true, then continue with k.
			var, env, bind, next_k = k[1:]
			try:
				env[var] = val
			except ValueError:
				if not bind:
					raise ValueError ("Unbound Variable: %s" % var)
				env.define (var, val)
			# an assignment/definition produces no value
			return self.apply_continuation (next_k, None)
		else:
			raise ValueError ("Unknown Continuation Type: %s" % (k,))
			

if __name__ == '__main__':
	# Print a short call/cc walkthrough, then start the interactive loop.
	# print(...) with a single string prints identically on Python 2 and 3.
	i = interpreter()
	print ("""This interpreter has call/cc.  Try the following sequence:
	define escape = 0
	define savecont = proc(x) escape := x
	define p = proc (x) begin print(x); if equal(x,10) then callcc(savecont) else 0; if equal(x,0) then 1 else p(sub1(x)) end
	
	<escape> will hold a continuation
	<savecont> will save a continuation into <escape>
	<p> is a simple recursive function that will count from <x> to 0,
	 pausing at 10 to save a copy of the current continuation.

	This continuation can be invoked later with a dummy argument""")
	i.repl()
