# -*- Mode: Python; tab-width: 4 -*-

import grammar
import string
import sys

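# "Ribcage" environment: a chain of frames, each holding parallel lists of
# names and values plus a link to the frame it extends.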

class environment:
	"""A "ribcage" environment frame.

	Each frame holds two parallel lists -- variable names and their
	values -- plus a link to the frame it extends (None for the outermost
	frame).  Lookups and assignments that miss in this frame are
	delegated to the extended frame.
	"""

	# Number of frames created so far; interpreter.read_eval_print resets
	# it before each evaluation and reports it afterwards.
	counter = 0

	def __init__ (self, symbols=None, values=None, extending=None):
		# Copy the incoming sequences: callers pass lists that live in the
		# program's syntax tree (a closure's formals), and define() appends
		# to this frame's lists -- sharing them would mutate the program.
		if symbols is None:
			symbols = []
		if values is None:
			values = []
		self.data = [list (symbols), list (values), extending]
		environment.counter = environment.counter + 1

	def extend (self, symbols, values):
		"""Return a new frame binding <symbols> to <values> on top of this one."""
		return environment (symbols, values, self)

	def __getitem__ (self, key):
		"""Return the value bound to <key>, searching outward through frames.

		Raises ValueError if <key> is unbound in every frame.
		"""
		symbols, values, extending = self.data
		if key in symbols:
			return values[symbols.index (key)]
		if extending is None:
			raise ValueError ("Unbound Variable: %s" % (key,))
		return extending[key]

	def __setitem__ (self, key, value):
		"""Rebind the innermost existing binding of <key> to <value>.

		Raises ValueError if <key> is unbound in every frame.
		"""
		symbols, values, extending = self.data
		if key in symbols:
			values[symbols.index (key)] = value
		elif extending is None:
			raise ValueError ("Unbound Variable: %s" % (key,))
		else:
			extending[key] = value

	def define (self, key, value):
		"""Bind <key> to <value> in this frame.

		Redefining a name already bound in this frame replaces its value
		(appending a duplicate would be invisible, because lookups use the
		first matching name).
		"""
		symbols, values, extending = self.data
		if key in symbols:
			values[symbols.index (key)] = value
		else:
			symbols.append (key)
			values.append (value)


class primitive_environment (environment):
	"""Outermost, read-only frame that exposes Python's builtin namespace.

	Looking up a name yields ['python_builtin', <object>], the tag that
	interpreter.apply_proc dispatches on.  Assigning or defining here is
	an error.
	"""

	def __getitem__ (self, key):
		"""Return ['python_builtin', obj] for the builtin named <key>.

		Raises ValueError (like every other frame) if no such builtin exists.
		"""
		try:
			import __builtin__ as builtins		# Python 2
		except ImportError:
			import builtins						# Python 3
		try:
			return ['python_builtin', getattr (builtins, key)]
		except AttributeError:
			# Report a miss exactly as environment.__getitem__ does, so an
			# unbound variable always surfaces as the same exception type.
			raise ValueError ("Unbound Variable: %s" % (key,))

	def __setitem__ (self, key, value):
		raise SystemError ("attempt to assign to primitive environment")

	# Defining a new name in the builtin frame is refused the same way.
	define = __setitem__

class symbol_generator:
	"""Hand out a series of generated names: '_g0', '_g1', '_g2', ..."""

	def __init__ (self):
		# number embedded in the next name to be produced
		self.count = 0

	def next (self):
		"""Return the next name in the series and advance the counter."""
		current = self.count
		self.count = current + 1
		return '_g' + str (current)


# These transformations could be added directly to the parser, but I
# prefer keeping them separate: everything stays modular, so I can write
# another transformer (say, for CPS) and plug it into the output of this one.

class transformer:
	"""Rewrite derived expression forms into primitive ones.

	The primitive expression types are:

	  literal, variable, call, lambda, if, set!

	(tags 'lit', 'varref', 'app', 'proc', 'conditional', 'varassign', plus
	'definition').  Everything else can be constructed from these [and
	call/cc is a primitive procedure!].  The expansions are fairly direct
	translations of those in section 7.3 of R^4RS.
	"""

	# ========================================
	# Program Transformation
	# ========================================

	# Shared by every transformer instance: supplies fresh names for the
	# procedures introduced by the 'while' and 'letrecproc' expansions.
	gensym = symbol_generator()

	def expand_expression (self, exp):
		"""Return <exp> with every derived form rewritten into primitive ones.

		<exp> is an abstract-syntax list tagged by its first element.
		Sub-expressions are expanded recursively, left to right; 'lit' and
		'varref' nodes are returned as-is (not copied).

		Raises ValueError for an unknown tag.
		"""
		tag = exp[0]
		if tag in ('lit', 'varref'):
			return exp
		elif tag in ('varassign', 'definition'):
			# [<tag> <var> <exp>]
			return [tag, exp[1], self.expand_expression (exp[2])]
		elif tag == 'app':
			# ['app' <operator_exp> [<operand_exp> ... ]]
			operator = self.expand_expression (exp[1])
			operands = [self.expand_expression (rand) for rand in exp[2]]
			return ['app', operator, operands]
		elif tag == 'conditional':
			# ['conditional' <test_exp> <then_exp> <else_exp>]
			return ['conditional'] + [self.expand_expression (part) for part in exp[1:]]
		elif tag == 'proc':
			# ['proc' <vars> <exp>]
			return ['proc', exp[1], self.expand_expression (exp[2])]
		elif tag == 'let':
			return self.expand_let (exp)
		elif tag == 'letrecproc':
			return self.expand_letrecproc (exp)
		elif tag == 'letcont':
			return self.expand_letcont (exp)
		elif tag == 'compound':
			return self.expand_compound (exp)
		elif tag == 'while':
			return self.expand_while (exp)
		elif tag == 'list':
			# ['list' [<exp> ...]]  =>  __list__(<exp> ...)
			return ['app',
					['varref', '__list__'],
					[self.expand_expression (elt) for elt in exp[1]]]
		elif tag == 'array_ref':
			# ['array_ref' <array_exp> <index_exp>]  =>  __array_ref__(<array_exp>, <index_exp>)
			return ['app',
					['varref', '__array_ref__'],
					[self.expand_expression (exp[1]),
					 self.expand_expression (exp[2])]]
		else:
			raise ValueError ("Unknown Expression Type: %s" % (exp,))

	def expand_let (self, exp):
		# ['let' [ ['decl' <var_1> <init_1>] ...] <body> ]
		# =>
		#  ['app' ['proc' [<var_1> ... <var_n>] <body> ] [ <init_1> ... <init_n> ]]
		#
		# <init_1..n> and <body> are sub-expressions and must be expanded
		# (inits first, then the body, to keep gensym numbering stable).
		decls = exp[1]
		symbols = [decl[1] for decl in decls]
		inits = [self.expand_expression (decl[2]) for decl in decls]
		body = self.expand_expression (exp[2])
		return ['app', ['proc', symbols, body], inits]

	def expand_letcont (self, exp):
		# letcont <var> in <body>
		# =>
		#  callcc (proc (<var>) <body>)
		#
		# NOTE(review): the variable is exp[1] and the body exp[3]; exp[2]
		# is presumably the 'in' keyword kept by the parser -- confirm
		# against the grammar.
		return ['app',
				['varref', 'callcc'],
				[['proc', [exp[1]], self.expand_expression (exp[3])]]]

	def expand_compound (self, exp):
		# ['compound' [<exp_1> ... <exp_n>]]
		# =>
		# ['app'
		#    ['proc' ['ignore' 'thunk'] ['app' ['varref' 'thunk'] [] ] ]
		#    [<exp_1>
		#     ['proc' [] ['compound' [<exp_2> ... <exp_n>]]]]]
		#
		# <exp_1> and a thunk wrapping the rest are passed as operands (the
		# interpreter evaluates operands left to right), then the sequencing
		# procedure calls the thunk, so the expressions run in order.
		body = exp[1]
		if not body:
			raise ValueError ("Empty compound expression")
		if len (body) == 1:
			return self.expand_expression (body[0])
		sequencer = ['proc', ['ignore', 'thunk'], ['app', ['varref', 'thunk'], []]]
		rest = ['proc', [], ['compound', body[1:]]]
		return self.expand_expression (['app', sequencer, [body[0], rest]])

	def expand_while (self, exp):
		# while <test_exp> do <body_exp>
		# =>
		# letrecproc
		#   <gensym> = proc()
		#     if <test_exp>
		#       then
		#         begin
		#           <body_exp>
		#           <gensym>()
		#         end
		#       else
		#         0
		#   in <gensym>()
		#
		# NOTE(review): the test is exp[1] and the body exp[3]; exp[2] is
		# presumably the 'do' keyword kept by the parser -- confirm.
		loop_name = self.gensym.next()
		loop_proc = ['procdecl', loop_name, [],
					 ['conditional',
					  exp[1],										# <test_exp>
					  ['compound',
					   [exp[3],										# <body_exp>
						['app', ['varref', loop_name], []]]],		# <gensym>()
					  ['lit', 0]]]
		return self.expand_letrecproc (
			['letrecproc',
			 [loop_proc],
			 ['app', ['varref', loop_name], []]]					# <gensym>()
			)

	def expand_letrecproc (self, exp):
		# TODO: the inner let isn't necessary; the procs could be assigned
		# to p1..pn directly.
		#
		# letrecproc ::= ['letrecproc' <procdecls> <body>]
		# procdecls  ::= [ ['procdecl' <var> [<formal> ...] <proc_body>] ... ]
		#
		# letrecproc
		#    p1 (f1) = b1;
		#    ...
		#    pn (fn) = bn;
		# in
		#    <body>
		# ==>
		# let
		#    p1 = 0
		#    ...
		#    pn = 0
		# in
		#   begin
		#     let
		#       g1 = proc (f1) b1
		#       ...
		#       gn = proc (fn) bn
		#     in
		#       p1 := g1
		#       ...
		#       pn := gn
		#   <body>
		#   end
		decls = exp[1]
		symbols = [decl[1] for decl in decls]
		formals = [decl[2] for decl in decls]
		bodies = [decl[3] for decl in decls]
		gensyms = [self.gensym.next() for decl in decls]
		body = exp[2]

		empty_decls = [['decl', sym, ['lit', 0]] for sym in symbols]
		proc_decls = [['decl', gen, ['proc', frm, bod]]
					  for gen, frm, bod in zip (gensyms, formals, bodies)]
		assignments = [['varassign', sym, ['varref', gen]]
					   for sym, gen in zip (symbols, gensyms)]
		return self.expand_expression (
			['let', empty_decls,
			 ['compound', [
				 ['let', proc_decls, ['compound', assignments]],
				 body
				 ]]
			 ]
			)


class interpreter:
	"""Parse, expand and evaluate programs in continuation-passing style.

	Source text is parsed with the grammar, expanded to primitive forms
	by a transformer, then evaluated by eval_exp.  Continuations are
	tagged lists (e.g. ['final-valcont'], ['proc-valcont', rands, env, k])
	built by eval_exp and consumed by apply_continuation; because they are
	ordinary data, call/cc can hand one to a user procedure to re-enter.
	"""

	grammar = grammar.make_grammar()
	transform = transformer()

	def initial_environment (self):
		"""Return a fresh top-level environment.

		It binds the primitive operators and call/cc, and extends a
		primitive_environment so other names fall through to Python builtins.
		"""
		prim_names = [
			'+', '-', '*',
			'add1', 'sub1',
			'minus', 'equal', 'greater', 'lesser', 'zero',
			'print', 'error', '__list__', '__array_ref__'
			]
		prim_procs = [['prim-op', name] for name in prim_names]

		# ==============================
		# add the magikal call/cc
		# ==============================
		# [it can't be added as a true primitive because apply_proc()
		#  wraps the call with apply-continuation()...]

		prim_names.append ('callcc')
		prim_procs.append (['callcc-proc'])

		return environment (
			prim_names,
			prim_procs,
			primitive_environment() # give access to python builtins
			)

	# ========================================
	# Read / Eval / Print
	# ========================================

	def _report_error (self, stage):
		# Print '<stage> Error', the traceback and the exception currently
		# being handled.  Must be called from inside an except clause.
		exc_type, exc_value, exc_tb = sys.exc_info()
		sys.stdout.write ('%s Error\n' % stage)
		import tb
		tb.printtb (exc_tb)
		sys.stdout.write ('%s %s\n' % (exc_type, exc_value))

	def eval (self, exp, env):
		"""Parse, expand and evaluate the source string <exp> in <env>.

		Returns the computed value, or None if any stage fails; failures
		are reported on stdout rather than raised.
		"""
		# The bare excepts are deliberate: this is the REPL's top-level
		# boundary, and catching everything (KeyboardInterrupt included)
		# returns control to the prompt.

		# ====================
		# parse
		# ====================
		try:
			exp = self.grammar.DoParse1 (exp)
		except:
			self._report_error ('Parse')
			return None

		# ====================
		# expand
		# ====================
		try:
			exp = self.transform.expand_expression (exp)
			sys.stdout.write ('expanded to:\n')
			import pprint
			pprint.pprint (exp)
			sys.stdout.write ('=>\n')
		except:
			self._report_error ('Expansion')
			return None

		# ====================
		# evaluate
		# ====================
		try:
			return self.eval_exp (exp, env, ['final-valcont'])
		except:
			self._report_error ('Evaluation')
			return None

	def print_value (self, value):
		"""Return the REPL display string for <value>."""
		if value is None:
			return '<undefined>'
		return repr (value)

	def read_eval_print (self, env=None):
		"""Run an interactive read-eval-print loop until end of input.

		A line starting with '-' begins multi-line input: the rest of that
		line and every following line up to a blank one are joined with
		spaces and evaluated as one expression.
		"""
		if env is None:
			env = self.initial_environment()
		try:
			read_line = raw_input			# Python 2
		except NameError:
			read_line = input				# Python 3
		while 1:
			sys.stdout.write ('--> ')
			sys.stdout.flush()

			try:
				exp = read_line()
				if exp and exp[0] == '-':
					# multi-line input
					lines = [exp[1:]]
					while 1:
						sys.stdout.write ('... ')
						sys.stdout.flush()
						line = read_line()
						if not line:
							break
						lines.append (line)
					exp = ' '.join (lines)
			except EOFError:
				sys.stdout.write ('\n')
				break

			if exp:
				environment.counter = 0
				sys.stdout.write (self.print_value (self.eval (exp, env)) + '\n')
				sys.stdout.write ('[%d environments]\n' % environment.counter)

	repl = read_eval_print

	def eval_exp (self, exp, env, k):
		"""Evaluate primitive expression <exp> in <env>, passing the value to <k>."""
		tag = exp[0]
		if tag == 'lit':
			return self.apply_continuation (k, exp[1])
		elif tag == 'varref':
			return self.apply_continuation (k, env[exp[1]])
		elif tag == 'app':
			# evaluate the operator; its continuation then evaluates the operands
			return self.eval_exp (exp[1], env, ['proc-valcont', exp[2], env, k])
		elif tag == 'conditional':
			return self.eval_exp (exp[1], env, ['test-valcont', exp[2], exp[3], env, k])
		elif tag == 'proc':
			# procedure definition simply creates a closure:
			# ['closure' <formals> <body> <env>]
			return self.apply_continuation (k, ['closure', exp[1], exp[2], env])
		elif tag == 'varassign':
			# flag 0: assign an existing binding
			return self.eval_exp (exp[2], env, ['assign-valcont', exp[1], env, 0, k])
		elif tag == 'definition':
			# flag 1: create a binding in <env>
			return self.eval_exp (exp[2], env, ['assign-valcont', exp[1], env, 1, k])
		else:
			raise SyntaxError ("Invalid Abstract Syntax")

	# ========================================
	# Application
	# ========================================

	def eval_rands (self, rands, env, k):
		"""Evaluate operand expressions left to right; pass the list of values to <k>."""
		if not rands:
			return self.apply_continuation (k, [])
		return self.eval_exp (rands[0], env, ['first-valcont', rands, env, k])

	def apply_proc (self, proc, args, k):
		"""Apply procedure value <proc> to the list <args>, continuing with <k>.

		Raises ValueError if <proc> is not a procedure.
		"""
		tag = proc[0]
		if tag == 'prim-op':
			return self.apply_continuation (k, self.apply_prim_op (proc[1], args))
		elif tag == 'python_builtin':
			# The builtin's result must still be handed to <k>; returning it
			# directly would discard the rest of the computation.
			return self.apply_continuation (k, proc[1] (*args))
		elif tag == 'closure':
			# proc := formals body environment
			ignore, formals, body, env = proc
			return self.eval_exp (body, env.extend (formals, args), k)
		elif tag == 'continuation':
			# invoking a captured continuation abandons <k>
			return self.apply_continuation (proc[1], args[0])
		elif tag == 'callcc-proc':
			# call the argument with the current continuation reified
			return self.apply_proc (args[0], [['continuation', k]], k)
		else:
			raise ValueError ("Invalid Procedure")

	def apply_prim_op (self, prim_op, args):
		"""Compute primitive operator <prim_op> on the list <args> and return the result."""
		if prim_op == '+':
			return args[0] + args[1]
		elif prim_op == '-':
			return args[0] - args[1]
		elif prim_op == '*':
			return args[0] * args[1]
		elif prim_op == 'add1':
			return args[0] + 1
		elif prim_op == 'sub1':
			return args[0] - 1
		elif prim_op == 'minus':
			return - args[0]
		elif prim_op == 'equal':
			return args[0] == args[1]
		elif prim_op == 'greater':
			return args[0] > args[1]
		elif prim_op == 'lesser':
			return args[0] < args[1]
		elif prim_op == 'zero':
			return args[0] == 0
		elif prim_op == 'print':
			sys.stdout.write ('%s\n' % (args,))
			return None
		elif prim_op == 'error':
			# String exceptions are invalid in modern Python; raise a real one.
			raise RuntimeError ('user error raised')
		elif prim_op == '__list__':
			return args
		elif prim_op == '__array_ref__':
			return (args[0])[args[1]]
		else:
			raise ValueError ("Invalid primitive operator: %s" % (prim_op,))

	def apply_continuation (self, k, val):
		"""Resume continuation <k> with the value <val>."""
		tag = k[0]
		if tag == 'final-valcont':
			return val
		elif tag == 'proc-valcont':
			# <val> is the evaluated operator; now evaluate the operands
			ignore, rands, env, next_k = k
			return self.eval_rands (rands, env, ['all-argcont', val, next_k])
		elif tag == 'all-argcont':
			# <val> is the list of operand values
			ignore, proc, next_k = k
			return self.apply_proc (proc, val, next_k)
		elif tag == 'test-valcont':
			ignore, then_exp, else_exp, env, next_k = k
			if val:
				return self.eval_exp (then_exp, env, next_k)
			return self.eval_exp (else_exp, env, next_k)
		elif tag == 'first-valcont':
			# <val> is the first operand's value; evaluate the rest
			ignore, rands, env, next_k = k
			return self.eval_rands (rands[1:], env, ['rest-argcont', val, next_k])
		elif tag == 'rest-argcont':
			# <val> is the list of the remaining operand values
			ignore, first, next_k = k
			return self.apply_continuation (next_k, [first] + val)  # (cons first rest)
		elif tag == 'assign-valcont':
			ignore, var, env, bind, next_k = k
			if bind:
				env.define (var, val)
			else:
				try:
					env[var] = val
				except ValueError:
					raise ValueError ("Unbound Variable: %s" % (var,))
			return self.apply_continuation (next_k, None)
		else:
			raise ValueError ("Unknown Continuation Type: %s" % (k,))
			

if __name__ == '__main__':
	i = interpreter()
	# Usage banner; written with sys.stdout.write so it runs unchanged
	# under Python 2 and Python 3.
	sys.stdout.write ("""This interpreter has call/cc.  Try the following sequence:
	define escape = 0
	define savecont = proc(x) escape := x
	define p = proc (x) begin print(x); if equal(x,10) then callcc(savecont) else 0; if equal(x,0) then 1 else p(sub1(x)) end
	
	<escape> will hold a continuation
	<savecont> will save a continuation into <escape>
	<p> is a simple recursive function that will count from <x> to 0,
	 pausing at 10 to save a copy of the current continuation.

	This continuation can be invoked later with a dummy argument
""")
	i.repl()
