;; -*- Mode: Irken -*-
(include "self/lisp_reader.scm")
(include "self/mbe.scm")
(include "self/types.scm")
(include "self/context.scm")
(include "self/match.scm")
;; Scan for datatypes, definitions, etc.,
;; and do transformations that can't be handled by the macro system.
(define (transformer context)
;; NOTE(review): `counter` is initialised here but is not referenced by any
;; other definition visible inside `transformer`; it looks like a leftover -
;; confirm before removing.
(define counter 0)
;; Entry point of the transformer (this is the value `transformer` returns).
;; `exp` must be a (sexp:list ...) holding the top-level forms:
;;   - an empty list is passed through as an empty list,
;;   - a single form is expanded directly,
;;   - several forms have their datatype/defmacro declarations collected
;;     first and are then expanded as a body.
;; Anything else is reported through error1.  The result is finally wrapped
;; with the constructor functions of every known datatype.
;; The input is echoed with "go:" + pp before any work is done.
(define (go exp)
  (print-string "go:")
  (pp 0 exp)
  (newline)
  (wrap-with-constructors
   (match exp with
     (sexp:list ())    -> (sexp:list '())
     (sexp:list (one)) -> (expand one)
     (sexp:list forms) -> (expand-body (find-declarations forms))
     _                 -> (error1 "unexpected s-expression in transformer:" exp)
     )))
;; Build (fix (<name> ...) (<init> ...) <body>).
;; When there is nothing to bind, `body` is returned as-is.
(define (wrap-fix names inits body)
  (if (null? names)
      body
      (sexp (sexp:symbol 'fix)
            (sexp:list names)
            (sexp:list inits)
            body)))
;; Build the s-expression (begin <exp> ...) from a list of s-expressions.
(define (wrap-begin exps)
  (let ((head (sexp:symbol 'begin)))
    (sexp:list (list:cons head exps))))
;; Continuation used by find-definitions.
;; `defs` is a list of (:pair name init); `exps` are the remaining body forms.
;; Every init is expanded (in order), then the body forms are wrapped in a
;; `begin`, expanded, and the whole thing is wrapped in a `fix` binding the
;; collected names.
(define (wrap-definitions defs exps)
  (let loop ((pending defs)
             (names '())
             (inits '()))
    (match pending with
      () -> (wrap-fix (reverse names)
                      (reverse inits)
                      (expand (wrap-begin exps)))
      (def . rest)
      -> (match def with
           (:pair name init)
           -> (loop rest
                    (list:cons (sexp:symbol name) names)
                    (list:cons (expand init) inits))))))
;; Expand a sequence of body forms: definitions are separated from the
;; other expressions and the two groups are handed to wrap-definitions.
(define (expand-body exps)
  (find-definitions
   exps
   (lambda (defs others)
     (wrap-definitions defs others))))
;; Walk the top-level forms, handing (datatype ...) forms to parse-datatype
;; and (defmacro ...) forms to parse-defmacro.  Those declarations are
;; consumed; every other form is returned, in its original order.
(define (find-declarations exps)
  (let loop ((todo exps)
             (kept '()))
    (match todo with
      () -> (reverse kept)
      (form . rest)
      -> (match form with
           (sexp:list ((sexp:symbol 'datatype) . spec))
           -> (begin (parse-datatype spec) (loop rest kept))
           (sexp:list ((sexp:symbol 'defmacro) . spec))
           -> (begin (parse-defmacro spec) (loop rest kept))
           _ -> (loop rest (list:cons form kept))))))
;; Split `exps` into (define ...) forms and everything else.
;; Each define is parsed with parse-define as it is encountered.  When the
;; input is exhausted, `k` is called with the parsed definitions and the
;; remaining expressions, both in source order.
(define (find-definitions exps k)
  (let loop ((defs '())
             (others '())
             (todo exps))
    (match todo with
      () -> (k (reverse defs) (reverse others))
      (form . rest)
      -> (match form with
           (sexp:list ((sexp:symbol 'define) . body))
           -> (loop (list:cons (parse-define body) defs) others rest)
           _ -> (loop defs (list:cons form others) rest)
           ))))
;; Expand the value part of a record field, keeping its name.
(define (expand-field f)
  (match f with
    (field:t name value) -> (field:t name (expand value))))
;; Recursively expand one s-expression.
;; Lists go through maybe-expand (special forms / macros / applications);
;; vectors, records and attribute references have their sub-expressions
;; expanded; every other kind of node is returned unchanged.
(define (expand exp)
  ;;(print-string "expanding... ") (unread exp) (newline)
  (match exp with
    ;; compound nodes
    (sexp:list items)      -> (maybe-expand items)
    (sexp:vector items)    -> (sexp:vector (map expand items))
    (sexp:record fields)   -> (sexp:record (map expand-field fields))
    (sexp:attr target sym) -> (sexp:attr (expand target) sym)
    ;; leaves
    (sexp:symbol _) -> exp
    (sexp:string _) -> exp
    (sexp:char _)   -> exp
    (sexp:bool _)   -> exp
    (sexp:int _)    -> exp
    (sexp:undef)    -> exp
    (sexp:cons _ _) -> exp
    ))
;; Expand the contents of a list form.
;; If the head is a symbol, it is looked up first in transform-table (the
;; handler receives the raw operands), then in context.macros (the macro is
;; applied to the whole form and the result expanded again).  Otherwise the
;; form is treated as an application and its operands are expanded.
(define (maybe-expand l)
  (match l with
    () -> (sexp:list '())
    ((sexp:symbol sym) . rands)
    -> (match (alist/lookup transform-table sym) with
         (maybe:yes handler) -> (handler rands)
         (maybe:no)
         -> (match (alist/lookup context.macros sym) with
              (maybe:yes macro) -> (expand (macro.apply (sexp:list l)))
              (maybe:no) -> (sexp:list (list:cons (car l) (map expand rands)))))
    ;; head is not a symbol: expand every element, head included
    _ -> (sexp:list (map expand l))))
;; Expand the operands of (if <test> <then> [<else>]).
;; A missing else branch is filled in with (sexp:undef).
(define (expand-if rands)
  (match rands with
    (tst then)      -> (sexp1 'if (LIST (expand tst) (expand then) (sexp:undef)))
    (tst then else) -> (sexp1 'if (LIST (expand tst) (expand then) (expand else)))
    _               -> (error1 "malformed <if>" rands)
    ))
;; Expand the operands of (set! <target> <value>).
;;   (set! <lhs>.<attr> <val>)            => (%%record-set <attr> <lhs> <val>)
;;   (set! (%%array-ref <lhs> <idx>) <v>) => (%%array-set <lhs> <v> <idx>)
;;   anything else                        => (set! <operand> ...)
;; maybe-expand hands the *raw* operands to this handler, so every clause is
;; responsible for expanding them.  The fallback clause therefore expands
;; each operand as well; previously it returned them untouched, which left
;; macros and special forms inside the assigned value unprocessed.
(define expand-set!
  ((sexp:attr lhs attr) val)
  -> (sexp1 '%%record-set (LIST (sexp:symbol attr) (expand lhs) (expand val)))
  ((sexp:list ((sexp:symbol '%%array-ref) lhs idx)) val)
  -> (sexp1 '%%array-set (LIST (expand lhs) (expand val) (expand idx)))
  rands
  -> (sexp:list (list:cons (sexp:symbol 'set!) (map expand rands)))
  )
;; Expand the operands of (begin <exp> ...).
;; An empty begin is an error; a begin with a single expression collapses
;; to that (expanded) expression.
(define (expand-begin forms)
  (match forms with
    ()    -> (error "empty BEGIN")
    (one) -> (expand one)
    _     -> (sexp1 'begin (map expand forms))
    ))
;; (quote <ob>): exactly one operand, passed to build-literal with
;; as-list? = #t and backquote? = #f.
(define (expand-quote rands)
  (match rands with
    (one) -> (build-literal one #t #f)
    _     -> (error1 "bad args to QUOTE" rands)))
;; (literal <ob>): exactly one operand, passed to build-literal with
;; as-list? = #f and backquote? = #f.
(define (expand-literal rands)
  (match rands with
    (one) -> (build-literal one #f #f)
    _     -> (error1 "bad args to LITERAL" rands)))
;; (backquote <ob>): exactly one operand, passed to build-literal with
;; as-list? = #t and backquote? = #t.
(define (expand-backquote rands)
  (match rands with
    (one) -> (build-literal one #t #t)
    _     -> (error1 "bad args to BACKQUOTE" rands)))
;; (lambda <formals> <body> ...) => (function lambda <formals> <expanded-body>)
;; The body forms are wrapped in a `begin` before being expanded.
(define (expand-lambda rands)
  (match rands with
    (formals . body)
    -> (let ((body0 (expand (sexp1 'begin body))))
         (exp-function (sexp:symbol 'lambda) formals body0))
    _ -> (error1 "malformed LAMBDA" rands)))
;; (function <name> <formals> <body> ...) => same form with a single,
;; expanded body (the body forms are wrapped in a `begin` first).
(define (expand-function rands)
  (match rands with
    (name formals . body)
    -> (let ((body0 (expand (sexp1 'begin body))))
         (exp-function name formals body0))
    _ -> (error1 "malformed FUNCTION" rands)))
;; Assemble (function <name> <formals> <body>) from its three parts.
(define (exp-function name formals body)
  (let ((parts (LIST name formals body)))
    (sexp1 'function parts)))
;; collect tag/alt pairs from vcase
;;
;; Walks the vcase clauses in `pairs`, accumulating for each clause its tag,
;; its list of formals and its expanded body (body forms are wrapped in a
;; `begin` first).  When the clauses run out - or a final (else ...) clause
;; is reached - `k` is called as (k tags formals alts else?), where else? is
;; (maybe:yes <expanded else body>) or (maybe:no).
;; The three lists are built by consing, so they arrive in reverse clause
;; order, but stay aligned with one another.
;; A clause of any other shape is printed with `unread` and reported through
;; error1.
(define (split-alts pairs k)
  (let scan ((tags '())
             (formals '())
             (bodies '())
             (todo pairs))
    (match todo with
      () -> (k tags formals bodies (maybe:no))
      ;; (((:tag f0 f1 ...) body ...) ...)
      ((sexp:list ((sexp:list ((sexp:cons _ tag) . args)) . code)) . rest)
      -> (scan (list:cons tag tags)
               (list:cons args formals)
               (list:cons (expand (sexp1 'begin code)) bodies)
               rest)
      ;; ((else body ...))
      ((sexp:list ((sexp:symbol 'else) . else-code)))
      -> (k tags formals bodies (maybe:yes (expand (sexp1 'begin else-code))))
      _ -> (begin (unread (car todo)) (error1 "split-alts" todo)))))
;; (vcase type x
;;   ((<select0> <formal0> <formal1> ...) <body0>)
;;   ((<select1> <formal0> <formal1> ...) <body1>)
;;   ...
;;   [(else <body2>)]
;;   )
;; =>
;; (%nvcase type x
;;   (<tag0> <tag1> ...)
;;   ((let ((f0 x.0) (f1 x.1) (f2 x.2)) <body0>) ...)
;;   <else-body>)
;; where x.N stands for (%nvget (type:tag N) x)
;;
;; Build (%nvget (<dt>:<label> <index>) <value>): the form that fetches
;; field number `index` of `value`, for alternative `label` of datatype `dt`.
(define (make-nvget dt label index value)
  (let ((selector (sexp (sexp:cons dt label) (sexp:int index))))
    (sexp (sexp:symbol '%nvget)
          selector
          value)))
;; Handler for (vcase <dt> <value> <alt> ...).
;;
;; The alternatives are split by split-alts into parallel lists of tags,
;; formals and already-expanded bodies, plus an optional else body.  Each
;; body is then wrapped in a `let` that binds every formal (except `_`) to
;; the matching (%nvget (<dt>:<tag> <index>) <value>) form; a body with no
;; formals to bind is used as-is.  The result is
;;   (%nvcase <dt> <expanded value> (<tag> ...) (<alt> ...) <else>)
;; where <else> is the expanded else body, or `match-error` when the vcase
;; has no else clause.
;; Operands that do not start with a symbol followed by a value are reported
;; through error1.
;;
;; NOTE(review): the %nvget forms embed `value` as written (unexpanded),
;; while the %nvcase subject is (expand value); the generated `let` is also
;; built after the bodies were expanded, so it is not itself passed through
;; `expand` here - confirm that a later stage handles both.
(define expand-vcase
((sexp:symbol dt) value . alts)
-> (split-alts
alts
(lambda (tags formals alts ealt?)
(let ((alts0
(map-range
i (length alts)
(let ((alt-formals (nth formals i))
(nformals (length alt-formals))
(tag (nth tags i))
;; one (formal (%nvget ...)) binding per formal; `_` formals are skipped
;; but still count towards the field index j
(binds (let loop ((j 0)
(r '()))
(if (= j nformals)
(reverse r)
(match (nth alt-formals j) with
(sexp:symbol '_) -> (loop (+ j 1) r)
formal -> (loop (+ j 1)
(list:cons
(sexp formal (make-nvget dt tag j value))
r)))))))
(if (not (null? binds))
;; (let ((f0 (%nvget (list:cons 0) value))
;; (f1 (%nvget (list:cons 1) value)))
;; body)
(sexp:list
(append (LIST (sexp:symbol 'let) (sexp:list binds))
(LIST (nth alts i))))
;; body
(nth alts i))))))
;; prints every rewritten alternative (looks like debugging output)
(for-each (lambda (x) (unread x) (newline)) alts0)
(sexp (sexp:symbol '%nvcase)
(sexp:symbol dt)
(expand value)
(sexp:list (map sexp:symbol tags))
(sexp:list alts0)
;; no else clause: fall back to `match-error` (defined outside this view)
(match ealt? with
(maybe:no) -> match-error
(maybe:yes ealt) -> ealt))
)))
x -> (error1 "expand-vcase" x))
;; Placeholder: ignores all three arguments and returns #u.
;; expand-quote, expand-literal and expand-backquote all delegate to this
;; function, so they currently produce #u as well.
;; NOTE(review): presumably meant to build a literal from `ob` - not yet
;; implemented here.
(define (build-literal ob as-list? backquote?)
#u)
;; Handle the operands of (defmacro <name> <in-pat> -> <out-pat> ...).
;; The rules are collected as a list of (:pair in-pat out-pat), turned into
;; a macro with make-macro, and registered in context.macros under <name>.
;; A rule list that is not a sequence of `in -> out` triples, or a form that
;; does not start with a symbol, is reported through error1.
(define (parse-defmacro form)
  ;; <in-pat> -> <out-pat> ...  =>  ((:pair in-pat out-pat) ...)
  (define (parse-rules clauses)
    (match clauses with
      () -> '()
      (in-pat (sexp:symbol '->) out-pat . rest)
      -> (list:cons (:pair in-pat out-pat) (parse-rules rest))
      _ -> (error1 "malformed macro definition:" clauses)))
  (match form with
    ((sexp:symbol name) . rules)
    -> (let ((macro (make-macro name (parse-rules rules))))
         (alist/push context.macros name macro))
    _ -> (error1 "malformed macro definition:" form)
    ))
;; Create a datatype object named `name` over the type variables `tvars`.
;; The object is a record of closures over a private map of alternatives:
;;   get            - look up an alternative by tag (error if absent)
;;   add            - register an alternative and assign it the next index
;;   iterate        - iterate over the registered alternatives
;;   get-nalts      - number of alternatives registered so far
;;   get-scheme     - (:scheme tvars (pred name tvars))
;;   get-alt-scheme - scheme of the constructor for a given tag
;;   get-tvars      - the values of `tvars`
(define (make-datatype tvars name)
  (let ((alts (alist-maker))
        (count 0))
    (define (get tag)
      (alts::get-err tag "no such alt in datatype"))
    (define (add alt)
      (alts::add alt.name alt)
      ;; alternatives are numbered in the order they are added
      (set! alt.index count)
      (set! count (+ 1 count))
      )
    (define (iterate p)
      (alts::iterate p))
    (define (get-nalts) count)
    (define (get-tvars)
      (tvars::values))
    (define (get-scheme)
      (let ((vars (tvars::values)))
        (:scheme vars (pred name vars))))
    (define (get-alt-scheme tag)
      (let ((alt (alts::get-err tag "no such alt in datatype"))
            (dtscheme (pred name (tvars::values))))
        (:scheme (tvars::values) (arrow dtscheme alt.types))))
    { name=name
      get=get
      add=add
      iterate=iterate
      get-nalts=get-nalts
      get-scheme=get-scheme
      get-alt-scheme=get-alt-scheme
      get-tvars=get-tvars }
    ))
;; Wrap `body` in a `fix` that binds one constructor function per
;; alternative of every datatype in context.datatypes.  Each function is
;; bound to the symbol <datatype>:<tag>.  With no datatypes, `body` comes
;; back unchanged (see wrap-fix).
(define (wrap-with-constructors body)
  (let ((ctor-names '())
        (ctor-funs '()))
    (alist/iterate
     (lambda (name dt)
       (dt.iterate
        (lambda (tag alt)
          ;; both lists are pushed in lock-step, so they stay aligned
          (PUSH ctor-names (string->symbol (format (sym dt.name) ":" (sym tag))))
          (PUSH ctor-funs (make-constructor dt.name tag alt.arity)))))
     context.datatypes)
    (wrap-fix (map sexp:symbol ctor-names) ctor-funs body)))
;; Build the constructor function for alternative `tag` of datatype `dt`:
;;   (function <dt>:<tag> (arg0 ... argN-1) (<dt>:<tag> arg0 ... argN-1))
;; where N is `arity`.
(define (make-constructor dt tag arity)
  (let ((formals (map-range i arity (sexp:symbol (string->symbol (format "arg" (int i))))))
        (fun-name (string->symbol (format (sym dt) ":" (sym tag)))))
    (sexp (sexp:symbol 'function)
          (sexp:symbol fun-name)
          (sexp:list formals)
          (sexp:list (list:cons (sexp:cons dt tag) formals)))))
;; Build the record describing one alternative of a datatype.
;; Each type expression is parsed with parse-type* against `tvars`; the
;; arity is the number of types.  `index` starts at 0 and is overwritten
;; when the alternative is added to its datatype.
(define (make-alt tvars tag types)
  (let ((parsed (map (lambda (t) (parse-type* t tvars)) types))
        (count (length parsed)))
    {name=tag
     types=parsed
     arity=count
     index=0}))
;; Handle the operands of (datatype <name> (:<tag> <type> ...) ...).
;; A fresh datatype object is created, every alternative is added to it,
;; and the datatype is registered in context.datatypes under <name>.
;; Malformed alternatives and malformed datatype forms are reported through
;; error1.
(define (parse-datatype form)
  (match form with
    ((sexp:symbol name) . subs)
    -> (let ((tvars (alist-maker))
             (dt (make-datatype tvars name)))
         (for-each
          (lambda (sub)
            (match sub with
              (sexp:list ((sexp:cons 'nil tag) . types))
              -> (dt.add (make-alt tvars tag types))
              _ -> (error1 "malformed alt in datatype" sub)))
          (reverse subs)) ;; preserve user order of alts
         (alist/push context.datatypes name dt)
         )
    _ -> (error1 "malformed datatype" form)
    ))
;; Classify the operands of a (define ...) form and dispatch to the matching
;; parser.  The result is a (:pair name init).
;;   (define name pat ... -> body ...)  pattern-matching definition
;;   (define name value)                definition without formals
;;   (define (name arg ...) body ...)   ordinary function definition
(define (parse-define form)
  (match form with
    ;; (define name ...)
    ((sexp:symbol name) . body)
    -> (if (member? (sexp:symbol '->) body sexp=?)
           ;; pattern-matching expression
           (parse-pattern-matching-define name body)
           ;; normal definition
           (parse-no-formals-define name body))
    ;; (define (name arg ...) ...)
    ((sexp:list ((sexp:symbol name) . formals)) . body)
    -> (parse-normal-definition name formals body)
    _ -> (error1 "malformed <define>" form)))
;; Parse a pattern-matching definition.  compile-pattern returns the formal
;; names together with a body; the body is expanded and both are packaged as
;; (:pair name (function <name> (<formal> ...) <body>)).
(define (parse-pattern-matching-define name body)
  (match (compile-pattern context expand body) with
    (:pair formals compiled)
    -> (let ((fun (sexp (sexp:symbol 'function)
                        (sexp:symbol name)
                        (sexp:list (map sexp:symbol formals))
                        (expand compiled))))
         (:pair name fun))))
;; Parse (define <name> <value>): a definition without formals.
;; Returns (:pair name init), where init is the (unexpanded) value
;; expression; wrap-definitions expands it later.
;;
;; `body` is the list of forms following the name.  It used to be wrapped
;; directly in a sexp:list, which turned (define x 5) into the init `(5)` -
;; an application of 5 - instead of `5`.  Now:
;;   - a single form is used as the init itself,
;;   - several forms are sequenced with `begin` (the last one is the value),
;;   - a missing value is reported as a malformed define.
(define (parse-no-formals-define name body)
  (match body with
    ()    -> (error1 "malformed <define>" name)
    (one) -> (:pair name one)
    _     -> (:pair name (wrap-begin body))))
;; Parse (define (<name> <formal> ...) <body> ...) into
;; (:pair name (function <name> (<formal> ...) <expanded-body>)).
(define (parse-normal-definition name formals body)
  ;; note: expand-body returns one sexp
  (let ((fun (sexp (sexp:symbol 'function)
                   (sexp:symbol name)
                   (sexp:list formals)
                   (expand-body body))))
    (:pair name fun)))
;; Dispatch table consulted by maybe-expand: maps the head symbol of a form
;; to the handler that receives the form's (unexpanded) operands.
;; NOTE(review): expand-quote, expand-literal and expand-backquote are
;; defined above but are not registered here - confirm whether that is
;; intentional.
(define transform-table
(literal
(alist/make
('if expand-if)
('set! expand-set!)
('begin expand-begin)
('lambda expand-lambda)
('function expand-function)
('vcase expand-vcase)
)))
go
)
;; Print a datatype object: its name, then one line per alternative showing
;; the tag and the printed representation of its argument types.
(define (print-datatype dt)
  (define (print-alt tag alt)
    (print-string (format " (:" (sym tag) " " (join type-repr " " alt.types) ")\n")))
  (print-string "(datatype ")
  (printn dt.name)
  (dt.iterate print-alt)
  (print-string " )\n")
  )
;; (define (test-transform)
;; (let ((context (make-context))
;; (transform (transformer context))
;; (tl (sexp:list (read-file sys.argv[1])))
;; (exp0 (transform tl)))
;; (unread exp0)
;; (newline)
;; (print-string "repr (exp0) =>\n")
;; (pp 0 exp0)
;; (print-string (format "\npp-size=" (int (pp-size exp0)) "actual=" (int (string-length (repr exp0))) "\n"))
;; (alist/iterate (lambda (name dt) (print-datatype dt)) context.datatypes)
;; (newline)
;; (alist/iterate (lambda (name macro) (macro.unread)) context.macros)
;; (newline)
;; ))
;; (include "lib/alist2.scm")
;; (test-transform)