;; -*- Mode: Irken -*-
(include "lib/counter.scm")
(include "lib/stack.scm")
(datatype field-pattern
(:t symbol pattern)
)
(datatype pattern
(:literal sexp)
(:variable symbol)
(:constructor symbol symbol (list pattern))
(:record (list field-pattern))
)
(datatype rule
(:t (list pattern) sexp))
(define rule->code (rule:t _ code) -> code)
(define rule->pats (rule:t pats _) -> pats)
(define match-error (sexp (sexp:symbol '%match-error) (sexp:bool #f)))
(define match-fail (sexp (sexp:symbol '%fail) (sexp:bool #f)))
(define (compile-pattern context expander exp)
(define (parse-pattern exp)
(define parse-field-pattern
(field:t name pat) -> (field-pattern:t name (kind pat)))
(define kind
(sexp:symbol s) -> (pattern:variable s)
(sexp:record fields) -> (pattern:record (map parse-field-pattern fields))
(sexp:list l)
-> (match l with
() -> (pattern:constructor 'list 'nil '())
((sexp:symbol 'quote) (sexp:symbol s)) -> (pattern:literal (sexp:symbol s))
((sexp:cons dt alt) . args) -> (pattern:constructor dt alt (map kind args))
((sexp:symbol '.) last) -> (kind last)
(hd . tl) -> (pattern:constructor 'list 'cons (LIST (kind hd) (kind (sexp:list tl))))
_ -> (error1 "malformed pattern" l))
x -> (pattern:literal x))
(kind exp))
;; (p0 p1 p2 -> r0 ...)
(define (parse-match expander body)
(let loop ((patterns '())
(rules '())
(l body))
(match l with
() -> (reverse rules)
((sexp:symbol '->) code . tl)
-> (loop '() (list:cons (rule:t (reverse patterns) (expander code)) rules) tl)
(pat . tl)
-> (loop (list:cons (parse-pattern pat) patterns) rules tl))))
;; XXX redo with format after writing <sexp-repr>
(define (dump-pat p)
(define ps print-string)
(define dump-field
(field-pattern:t name vpat)
-> (begin (print name)
(ps "=")
(dump-pat vpat)))
(match p with
(pattern:literal exp)
-> (begin (ps "L") (unread exp))
(pattern:variable var)
-> (print var)
(pattern:constructor dt alt args)
-> (begin (ps "(") (print dt) (ps ":") (print alt) (ps " ")
(for-each (lambda (x) (dump-pat x) (ps " ")) args)
(ps ")"))
(pattern:record fpats)
-> (begin (ps "{")
(for-each (lambda (fp) (dump-field fp) (ps " ")) fpats)
(ps "}"))
_ -> (error1 "NYI" p)))
(define pattern->kind
(pattern:literal _) -> 'literal
(pattern:variable _) -> 'variable
(pattern:constructor _ _ _ ) -> 'constructor
(pattern:record _) -> 'record
)
;; pull the first pattern out of each rule
(define remove-first-pat
(rule:t (pat . pats) code)
-> (rule:t pats code)
_ -> (error "remove-first-pat: empty pats?"))
(define first-pattern-kind
(rule:t (pat0 . pats) _) -> (pattern->kind pat0)
_ -> (error "empty pattern list?"))
(define (compare-first-patterns a b)
(eq? (first-pattern-kind a)
(first-pattern-kind b)))
(define (compile-match vars rules default)
(match vars rules with
;; the 'empty rule'
() () -> default
() (rule . _) -> (rule->code rule)
_ _ ->
;; group the rules by kind of first pattern
;; reverse the rules so we compile inside-out, starting with <default>
(let ((groups (pack (reverse rules) compare-first-patterns)))
(let loop ((l groups)
(default default))
(match l with
() -> default
(group . rest)
-> (loop rest
;; choose a rule to apply
(match (first-pattern-kind (car group)) with
'literal -> (constant-rule vars group default)
'variable -> (variable-rule vars group default)
'constructor -> (constructor-rule vars group default)
'record -> (record-rule vars group default)
_ -> (impossible)))
)))))
(define (fatbar e1 e2)
(cond ((eq? e1 match-fail) e2)
((eq? e2 match-fail) e1)
(else
(sexp1 '%fatbar (LIST (sexp:bool #f) e1 e2)))))
(define (subst var0 pat code)
(match pat with
(pattern:variable var1)
;; record a subst to be applied during node building (unless it's a wildcard pattern)
-> (if (not (eq? var0 '_))
(sexp (sexp:symbol 'let_subst)
(sexp (sexp:symbol var1)
(sexp:symbol var0)) code)
code)
_ -> (impossible)
))
;; if every rule begins with a variable, we can remove that column
;; from the set of patterns and substitute the var within each body
(define (variable-rule vars rules default)
(let ((var0 (car vars))
(rules0 (map (lambda (rule)
(match rule with
(rule:t pats code)
-> (rule:t (cdr pats) (subst var0 (car pats) code))))
rules)))
(compile-match (cdr vars) rules0 default)))
(define pattern->literal
(pattern:literal exp) -> exp
_ -> (error "not a literal pattern"))
(define (first-literal=? r0 r1)
(match r0 r1 with
(rule:t pats0 _) (rule:t pats1 _)
-> (sexp=? (pattern->literal (car pats0))
(pattern->literal (car pats1)))))
(define (constant-rule vars rules default)
;; group runs of the same literal together
(let loop ((groups (pack rules first-literal=?))
(default default))
(match groups with
() -> default
(rules0 . groups)
-> (let ((lit (pattern->literal (car (rule->pats (car rules0)))))
(comp-fun
(match lit with
(sexp:string _) -> (sexp:symbol 'string=?)
_ -> (sexp:symbol 'eq?))))
(loop groups
(fatbar (sexp (sexp:symbol 'if)
(sexp comp-fun (sexp:symbol (car vars)) lit)
(compile-match (cdr vars) (map remove-first-pat rules0) match-fail)
match-fail)
default))))))
(define (record-rule vars rules default)
(error "NYI"))
;; sort a collection <l> into lists with matching <p>
;; <p> must return an eq?-compatible object. returns an alist of stacks.
(define (collect p l)
(let loop ((acc (alist/make))
(l l))
(match l with
() -> acc
(hd . tl)
-> (let ((key (p hd)))
(match (alist/lookup acc key) with
(maybe:no) -> (let ((stack (make-stack)))
(stack.push hd)
(loop (alist:entry key stack acc) tl))
(maybe:yes stack) -> (begin (stack.push hd) (loop acc tl)))))))
(define pattern->dt
(pattern:constructor dt _ _) -> dt
_ -> (error "not a constructor pattern"))
(define pattern->alt
(pattern:constructor _ alt _) -> alt
_ -> (error "not a constructor pattern"))
(define pattern->subs
(pattern:constructor _ _ subs) -> subs
_ -> (error "not a constructor pattern"))
(define rule->constructor-dt
(rule:t pats _)
-> (pattern->dt (car pats)))
(define rule->constructor-alt
(rule:t pats _)
-> (pattern->alt (car pats)))
(define (sort-constructor-rules rules)
;; first, make sure we're all on the same datatype
(let ((by-dt (collect rule->constructor-dt rules))
(keys (alist->keys by-dt)))
(if (not (= (length keys) 1))
(error1 "more than one datatype in pattern match" keys)
(collect rule->constructor-alt rules))))
(define (constructor-rule vars rules default)
(let ((dtname (rule->constructor-dt (car rules)))
(alts (sort-constructor-rules rules))
(nalts 0)
(dt (alist/get context.datatypes
(rule->constructor-dt (car rules))
"unknown datatype"))
(default0 (if (sexp=? default match-error) default match-fail))
(cases '())
)
(alist/iterate
(lambda (tag rules-stack)
(let ((alt (dt.get tag))
(vars0 (nthunk alt.arity new-match-var))
(wild (make-vector alt.arity #t))
(rules1 '()))
(set! nalts (+ nalts 1))
(define frob-rule
(rule:t pats code)
-> (let ((subs (pattern->subs (car pats))))
(if (not (= (length subs) alt.arity))
(error1 "arity mistmatch in variant pattern" rules))
(PUSH rules1 (rule:t (append (pattern->subs (car pats)) (cdr pats)) code))
(for-range i alt.arity (set! wild[i] (match (nth subs i) with
(pattern:variable '_) -> #t
_ -> #f)))))
(for-each frob-rule (rules-stack.get))
;; if every pattern has a wildcard for this arg of the constructor,
;; then use '_' rather than the symbol we generated.
(let ((vars1 (map-range i alt.arity (if wild[i] '_ (nth vars0 i)))))
(PUSH cases
;; ((:tag var0 var1 ...) (match ...))
(sexp
(sexp:list
(list:cons (sexp:cons 'nil alt.name) (map sexp:symbol vars1)))
(compile-match (append vars0 (cdr vars)) (reverse rules1) default0))))))
alts)
(let ((result
(if (not (eq? dt.name 'nil))
(begin (if (< nalts (dt.get-nalts))
(PUSH cases (sexp (sexp:symbol 'else) default0)))
(sexp:list (append (LIST (sexp:symbol 'vcase) (sexp:symbol dt.name) (sexp:symbol (car vars)))
(reverse cases))))
(sexp:list (append (LIST (sexp:symbol 'vcase) (sexp:symbol (car vars)))
(reverse cases))))
))
(if (not (eq? default match-error))
(fatbar result default)
result))))
(define dump-rule
(rule:t pats code)
-> (begin (for-each (lambda (p)
(dump-pat p)
(print-string " ")) pats)
(print-string "-> ")
(unread code)))
(define match-counter (make-counter 0))
(define (new-match-var)
(string->symbol (format "m" (int (match-counter.inc)))))
(define nthunk
0 p -> '()
n p -> (list:cons (p) (nthunk (- n 1) p)))
(let ((rules (parse-match expander exp)))
(for-each (lambda (rule)
(newline)
(dump-rule rule)) rules)
(newline)
(let ((npats (length (rule->pats (car rules))))
(vars (nthunk npats new-match-var)))
(:pair vars (compile-match vars rules match-error))))
)