;; -*- Mode: Irken -*-

;; note: some `(logand 1 ...)` are here because irken's native int is an i63.

;; Boolean negation of a 0/1 control value.
;; The xor flips the low bit; the mask throws away everything above it
;; so the result is always exactly 0 or 1.
(define (NOT ctl)
  (let ((flipped (logxor 1 ctl)))
    (logand flipped 1)))

;; Branch-free select: yields `x` when ctl is 1 and `y` when ctl is 0.
(define (MUX ctl x y)
  (let ((mask (- ctl))            ;; 0 -> no bits set, 1 -> -1 (every bit set)
        (diff (logxor x y)))
    (logxor y (logand mask diff))))

;; Returns 1 when x and y differ, 0 when they are equal.
;; `q | -q` has bit 31 set whenever q is a non-zero value that fits in
;; 32 bits; that bit is extracted and masked down to 0/1.
(define (NEQ x y)
  (let ((q (logxor x y))
        (spread (logior q (- q))))
    (logand (>> spread 31) 1)))

;; Returns 1 when x and y are equal, 0 otherwise (the complement of NEQ).
(define (EQ x y)
  (let ((differ (NEQ x y)))
    (NOT differ)))

;; Returns 1 when x > y, 0 otherwise.
;; Only bit 31 of the combined expression is inspected, so the answer
;; is the one 32-bit arithmetic would give.
(define (GT x y)
  (let ((z (- y x))
        (xy (logxor x y))
        (xz (logxor x z))
        (sel (logxor z (logand xy xz))))
    (logand 1 (>> sel 31))))

;; x >= y  <=>  not (y > x); result is 0 or 1.
(define (GE x y)
  (let ((y-above-x (GT y x)))
    (NOT y-above-x)))
;; x < y  <=>  y > x; result is 0 or 1.
(define (LT x y)
  (let ((y-above-x (GT y x)))
    y-above-x))
;; x <= y  <=>  not (x > y); result is 0 or 1.
(define (LE x y)
  (let ((x-above-y (GT x y)))
    (NOT x-above-y)))

;; Three-way comparison: 1 when x > y, -1 when x < y, 0 when equal.
(define (CMP x y)
  (let ((above (GT x y))
        (below (GT y x)))
    ;; at most one of the two flags is set, so the `or` just picks it
    (logior above (- below))))

;; Number of significant bits in x (0 for x = 0), computed without
;; branching on the value.  Each stage tests whether the remaining value
;; exceeds a threshold, conditionally shifts it down, and contributes
;; the matching power of two to the total.
;; (`let` bindings are evaluated in order, each seeing the previous ones.)
(define (BIT-LENGTH x)
  (let ((nonzero (NEQ x 0))
        (c4 (GT x #xffff))
        (x4 (MUX c4 (>> x 16) x))
        (c3 (GT x4 #x00ff))
        (x3 (MUX c3 (>> x4 8) x4))
        (c2 (GT x3 #x000f))
        (x2 (MUX c2 (>> x3 4) x3))
        (c1 (GT x2 #x0003))
        (x1 (MUX c1 (>> x2 2) x2))
        (c0 (GT x1 #x0001)))
    (+ nonzero (<< c4 4) (<< c3 3) (<< c2 2) (<< c1 1) c0)))

;; Note: these are constant-time only on modern processors.
;; Full product of two words.  For 31-bit operands the product stays
;; below 2^62 and therefore fits the native integer.
(define (MUL31 x y)
  (let ((product (* x y)))
    product))

;; Low 31 bits of the product of x and y.
(define (MUL31-lo x y)
  (let ((product (* x y)))
    (logand product #x7fffffff)))

;; NOTE: in bearssl, this is a byte-level copy. here, it is used only
;;   for copying vectors.
;; Conditional copy of the first `len` elements of `src` into `dst`.
;; Every element of `dst` is rewritten regardless of `ctl`; with ctl = 0
;; it is rewritten with its own value.
(define (CCOPY ctl dst src len)
  (for-range idx len
    (let ((kept dst[idx])
          (incoming src[idx]))
      (set! dst[idx] (MUX ctl incoming kept)))))

;; Returns 1 when every value word of `x` is zero, 0 otherwise.
;; x[0] holds the (encoded) bit length; the value words are
;; x[1] .. x[(x[0] + 31) >> 5].
(define (i31-iszero x)
  (let ((z 0)
        (u (>> (+ x[0] 31) 5)))
    ;; OR all value words together
    (while (> u 0)
      (set! z (logior z x[u]))
      (dec! u))
    ;; The native int is wider than 32 bits, so shifting the complement
    ;; right by 31 leaves the sign-extended upper bits in place (-1 for
    ;; z = 0).  Mask down to a single bit so the result is a proper 0/1
    ;; flag usable as a MUX/CCOPY control.
    (logand 1 (>> (lognot (logior z (- z))) 31))))

;; Word-wise addition of `b` into `a` with carry propagation.
;; The carry out of the top word is always computed and returned; the
;; words of `a` are only replaced by the sum when ctl is 1.
(define (i31-add a b ctl)
  (let ((limit (>> (+ a[0] 63) 5)))     ;; one past the last value word
    (let loop ((u 1) (carry 0))
      (if (>= u limit)
          carry
          (let ((aw a[u])
                (sum (+ aw b[u] carry)))
            (set! a[u] (MUX ctl (logand sum #x7fffffff) aw))
            (loop (+ u 1) (logand 1 (>> sum 31))))))))

;; Word-wise subtraction of `b` from `a` with borrow propagation.
;; The borrow out of the top word is always computed and returned; the
;; words of `a` are only replaced by the difference when ctl is 1.
(define (i31-sub a b ctl)
  (let ((limit (>> (+ a[0] 63) 5)))     ;; one past the last value word
    (let loop ((u 1) (borrow 0))
      (if (>= u limit)
          borrow
          (let ((aw a[u])
                (diff (- aw b[u] borrow)))
            (set! a[u] (MUX ctl (logand diff #x7fffffff) aw))
            ;; a negative difference has bit 31 set: that is the borrow
            (loop (+ u 1) (logand 1 (>> diff 31))))))))

;; Encoded bit length of the `xlen` value words stored in x[1..xlen]:
;; (index of the top non-zero word) << 5, plus the bit length of that word.
;; All words are scanned from the top down; once a non-zero word has been
;; captured the selection flag stays 0 and later words are ignored.
(define (i31-bit-length x xlen)
  (let loop ((k xlen) (top 0) (top-index 0))
    (if (> k 0)
        (let ((idx (- k 1))
              (fresh (EQ top 0)))
          (loop idx
                (MUX fresh x[k] top)
                (MUX fresh idx top-index)))
        (+ (<< top-index 5) (BIT-LENGTH top)))))

;; Stores `bit-len` in the header slot x[0] and clears the
;; (bit-len + 31) >> 5 value words that follow it.
(define (i31-zero x bit-len)
  (let ((nwords (>> (+ bit-len 31) 5)))
    (set! x[0] bit-len)
    (for-range* w 1 (+ nwords 1)
      (set! x[w] 0))))

;; Shifts the value stored in `x` right by `count` bits, in place.
;; x[0] is the encoded bit length; value words are 31 bits wide and live
;; in x[1..len].  `count` is expected to be smaller than 31.
(define (i31-rshift x count)
  (let ((len (>> (+ x[0] 31) 5)))
    ;; nothing to do (and no x[1] to read) for a zero-length integer
    (when (not (= 0 len))
      (let ((r (>> x[1] count)))
        ;; `for-range*` excludes its upper bound, so go up to len + 1:
        ;; every word from 2 through len must donate its low bits to the
        ;; word below it.
        (for-range* u 2 (+ len 1)
          (let ((w x[u]))
            (set! x[(- u 1)] (logand (logior (<< w (- 31 count)) r) #x7FFFFFFF))
            (set! r (>> w count))))
        (set! x[len] r)))))

;; Reduces `a` modulo `m` into `x`.
;;   x[0] is set to m's encoded bit length.
;;   If m has bit length 0 nothing else happens.
;;   If a is shorter than m, a's words are copied and the rest zeroed.
;;   Otherwise the top (mlen - 1) words of a seed x, and the remaining
;;   words are fed in one at a time through i31-muladd-small.
;; Note: `for-range*` excludes its upper bound.
(define (i31-reduce x a m)
  (let ((m-bitlen m[0])
        (mlen (>> (+ 31 m-bitlen) 5)))
    (set! x[0] m-bitlen)
    (when (not (= m-bitlen 0))
      (let ((a-bitlen a[0])
            (alen (>> (+ 31 a-bitlen) 5)))
        (if (< a-bitlen m-bitlen)
            ;; short source: x[1..alen] <- a[1..alen], x[alen+1..mlen] <- 0
            (for-range* i 1 (+ mlen 1)
              (set! x[i] (if (<= i alen) a[i] 0)))
            ;; long source: `skip` low words of a are injected afterwards,
            ;; the mlen - 1 words above them are copied directly.
            (let ((skip (+ 1 (- alen mlen))))
              (for-range* i 1 mlen
                (set! x[i] a[(+ i skip)]))
              (set! x[mlen] 0)
              (let loop ((u skip))
                (when (> u 0)
                  (i31-muladd-small x a[u] m)
                  (loop (- u 1))))))))))

;; smr: replacement for the uses of memmove
;; In-place block move inside vector `v` (stand-in for memmove):
;; the `count` elements starting at index `start` are moved by `amount`
;; positions (negative = towards index 0, positive = towards the end).
;; The copy direction is chosen from the sign of `amount` so that
;; overlapping source and destination ranges are handled correctly.
(define (scoot v amount start count)
  (when (> count 0)
    (match (int-cmp amount 0) with
      ;; moving by zero positions changes nothing
      (cmp:=) -> #u
      ;; scoot to the left
      ;; -2 [....01234...]
      ;; => [..01234xx...]
      (cmp:<)
      -> (for-range i count ;; left to right
           (set! v[(+ i start amount)]
                 v[(+ i start)]))
      ;; scoot to the right
      ;; +2 [..01234.....]
      ;; => [..xx01234...]
      (cmp:>)
      -> (for-range-rev i count ;; right to left
           (set! v[(+ start i amount)]
                 v[(+ start i)]))
      )))

;; version of br_divrem
;; Divides the 64-bit quantity hi:lo (hi * 2^32 + low 32 bits of lo) by
;; `d`, one quotient bit at a time, without branching on the operands.
;; Returns (:tuple quotient remainder).  Expects hi < d; when hi = d,
;; `hi` is forced to 0 first.
;;
;; The native int is wider than 32 bits, so every value that stands for
;; a 32-bit word is truncated explicitly.  Without that, `lo` goes
;; negative after the first subtraction and poisons later iterations
;; (e.g. hi=1 lo=0 d=3).  Only the low 32 bits of `lo` are used, so a
;; caller that passes a wider `lo` whose upper bits merely repeat `hi`
;; gets the same answer.
(define (ctbig/divrem hi lo d)
  (let ((q 0)
        (ch (EQ hi d))
        (cf 0))
    (set! hi (MUX ch 0 hi))
    (set! lo (logand #xffffffff lo))
    ;; k runs from 31 down to 0; the k = 0 pass already performs the
    ;; final reduction, which makes the `cf` step below a no-op in that
    ;; case (it is kept for parity with the reference algorithm).
    (for-range-rev k 32
      (let ((j (- 32 k))
            ;; top bits of hi:lo aligned against d (may need 33 bits)
            (w (logior (<< hi j) (>> lo k)))
            ;; GE only looks at the low 32 bits; (>> hi k) supplies bit 32
            (ctl (logior (GE w d) (>> hi k)))
            (hi2 (>> (- w d) j))
            (lo2 (logand #xffffffff (- lo (<< d k)))))
        (set! hi (MUX ctl hi2 hi))
        (set! lo (MUX ctl lo2 lo))
        (set! q (logior q (<< ctl k)))))
    (set! cf (logior (GE lo d) hi))
    (set! q (logior q cf))
    (:tuple q (MUX cf (logand #xffffffff (- lo d)) lo))
    ))

;; Quotient half of ctbig/divrem.
(define (ctbig/div hi lo d)
  (let ((qr (ctbig/divrem hi lo d)))
    (match qr with
      (:tuple quotient _) -> quotient)))

;; Remainder half of ctbig/divrem.
(define (ctbig/rem hi lo d)
  (let ((qr (ctbig/divrem hi lo d)))
    (match qr with
      (:tuple _ remainder) -> remainder)))

;; Updates `x` in place from `x`, a new low word `z`, and modulus `m`.
;; The words of x are moved up one slot and z becomes x[1]; an estimated
;; multiple q*m is then subtracted and the result corrected by +m or -m.
;; NOTE(review): this looks like x <- (x * 2^31 + z) mod m, i.e. a port
;; of BearSSL's br_i31_muladd_small -- confirm against the reference.
;; Does nothing when m[0] (the modulus bit length) is 0.
(define (i31-muladd-small x z m)
  (when (not (= m[0] 0))
    (if (<= m[0] 31)
        ;; single-word modulus: one double-word remainder does the job.
        ;; NOTE(review): `lo` is not truncated to 32 bits here, so it also
        ;; carries the bits that are passed separately as `hi`.
        (let ((hi (>> x[1] 1))
              (lo (logior (<< x[1] 31) z)))
          (set! x[1] (ctbig/rem hi lo m[1])))
        (let ((mlen (>> (+ 31 m[0]) 5))   ;; number of value words in m
              (mblr (logand 31 m[0]))     ;; low 5 bits of the encoded bit length
              (hi x[mlen])                ;; top word of x before the shift
              (a0 0) (a1 0) (b0 0) (g 0) (q 0))
          ;; a0 / a1: top 31 bits of x before / after the one-word shift;
          ;; b0: top 31 bits of m, aligned the same way.
          (cond ((= mblr 0)
                 (set! a0 x[mlen])
                 (scoot x 1 1 (- mlen 1))   ;; x[2..mlen] <- x[1..mlen-1]
                 (set! x[1] z)
                 (set! a1 x[mlen])
                 (set! b0 m[mlen]))
                (else
                 (set! a0 (logand #x7fffffff (logior (<< x[mlen] (- 31 mblr)) (>> x[(- mlen 1)] mblr))))
                 (scoot x 1 1 (- mlen 1))   ;; x[2..mlen] <- x[1..mlen-1]
                 (set! x[1] z)
                 (set! a1 (logand #x7fffffff (logior (<< x[mlen] (- 31 mblr)) (>> x[(- mlen 1)] mblr))))
                 (set! b0 (logand #x7fffffff (logior (<< m[mlen] (- 31 mblr)) (>> m[(- mlen 1)] mblr))))
                 ))
          ;; quotient estimate from the top words; as above, the `lo`
          ;; argument is passed without truncation to 32 bits.
          (set! g (ctbig/div (>> a0 1) (logior a1 (<< a0 31)) b0))
          ;; q = 0x7fffffff when a0 = b0, else g - 1 clamped at 0
          (set! q (MUX (EQ a0 b0) #x7fffffff (MUX (EQ g 0) 0 (- g 1))))
          (let ((cc 0)      ;; running carry of the q*m subtraction
                (tb 1)      ;; comparison flag of the new x against m, updated per word
                (over 0)
                (under 0))
            ;; x <- x - q*m, word by word (upper bound of for-range* is exclusive)
            (for-range* u 1 (+ 1 mlen)
              (let ((mw m[u])
                    (zl (+ (MUL31 mw q) cc))
                    ;; let bindings run in order: cc is updated here, before zw
                    (_ (set! cc (>> zl 31)))
                    (zw (logand zl #x7fffffff))
                    (xw x[u])
                    (nxw (logand #xffffffff (- xw zw))))
                (inc! cc (>> nxw 31))   ;; bit 31 of the masked difference is the borrow
                (set! nxw (logand nxw #x7fffffff))
                (set! x[u] nxw)
                ;; equal words keep the flag from lower words; otherwise this word decides
                (set! tb (MUX (EQ nxw mw) tb (GT nxw mw)))))
            ;; compare the accumulated carry with the old top word to decide
            ;; which single correction (if any) to apply
            (set! over (GT cc hi))
            (set! under (logand (lognot over) (logior tb (LT cc hi))))
            (i31-add x m over)    ;; x <- x + m when over = 1
            (i31-sub x m under)   ;; x <- x - m when under = 1
            #u
            )))))

;; Returns -(1/x) mod 2^31 for odd x, and 0 for even x.
;; Uses Newton's iteration y <- y * (2 - x*y); each pass doubles the
;; number of correct low bits (2 -> 4 -> 8 -> 16 -> 32).
;;
;; Only the low 31 bits of the result are used, so every intermediate is
;; reduced mod 2^31.  That keeps all products below 2^62 instead of
;; letting them grow to ~2^93 and relying on native-int overflow wrapping.
(define (i31-ninv31 x)
  (let ((xm (logand #x7fffffff x))
        (y (logand #x7fffffff (- 2 xm))))
    (for-range i 4
      (set! y (logand #x7fffffff
                      (* y (logand #x7fffffff (- 2 (* y xm)))))))
    (logand #x7fffffff (MUX (logand 1 x) (- y) 0))
    ))

;; Writes `bit-len` into x[0] and clears value words
;; x[1] .. x[(bit-len + 31) >> 5] (walking from the top word down).
(define (i32-zero x bit-len)
  (set! x[0] bit-len)
  (let loop ((w (>> (+ bit-len 31) 5)))
    (when (> w 0)
      (set! x[w] 0)
      (loop (- w 1)))))

;; Montgomery-style multiplication of `x` by `y` modulo `m` into `d`,
;; with `m0i` as the per-word reduction multiplier.
;; For each word of x, `f` is chosen from the low word and
;; x[u]*y + f*m is accumulated into d, with the result stored one word
;; lower (the low word of each row is discarded).  d[0] is used as
;; scratch during the loop and restored at the end.
(define (i31-montymul d x y m m0i)
  (let ((nwords (>> (+ m[0] 31) 5))
        (dh 0))                       ;; overflow above the top word
    (i32-zero d m[0])
    (for-range u nwords
      (let ((xu x[(+ u 1)])
            (f (MUL31-lo (+ d[1] (MUL31-lo xu y[1])) m0i))
            (carry 0))
        (for-range v nwords
          (let ((z (+ d[(+ v 1)]
                      (MUL31 xu y[(+ v 1)])
                      (MUL31 f m[(+ v 1)])
                      carry)))
            ;; z may exceed the positive range of the native int; the
            ;; mask recovers the carry bits after the arithmetic shift
            (set! carry (logand #xffffffff (>> z 31)))
            (set! d[v] (logand z #x7fffffff))))
        (let ((zh (+ dh carry)))
          (set! d[nwords] (logand #x7fffffff zh))
          (set! dh (>> zh 31)))))
    (set! d[0] m[0])
    ;; subtract m once if there was overflow, or if d >= m (the trial
    ;; subtraction with ctl = 0 only reports the borrow)
    (i31-sub d m (logior (NEQ dh 0) (NOT (i31-sub d m 0))))
    ))

;; Applies (i31-muladd-small x 0 m) once per value word of `m`,
;; updating `x` in place.
(define (i31->monty x m)
  (for-range k (>> (+ m[0] 31) 5)
    (i31-muladd-small x 0 m)))

;; Runs one word-wise reduction pass per value word of `m` over `x`,
;; in place, using `m0i` as the per-word multiplier, then normalises
;; the result into [0, m).
(define (monty->i31 x m m0i)
  (let ((len (>> (+ m[0] 31) 5)))
    (for-range u len
      (let ((f (MUL31-lo x[1] m0i))
            (cc 0))
        (for-range v len
          (let ((z (+ x[(+ v 1)] (MUL31 f m[(+ v 1)]) cc)))
            (set! cc (>> z 31))
            ;; the low word of each row is dropped, hence no store at v = 0
            (when (not (= 0 v))
              (set! x[v] (logand z #x7fffffff)))))
        (set! x[len] (logand #xffffffff cc)) ;; XXX is logand necessary?
        ))
    ;; Final conditional subtraction (as in br_i31_from_monty): the first
    ;; call only computes the borrow, the second subtracts m when there
    ;; was none, i.e. when x >= m.  A no-op for inputs already below m.
    (i31-sub x m (NOT (i31-sub x m 0)))
    #u
    ))

;; `e` is a string in u256 form.
;; Modular exponentiation by square-and-multiply over the bits of `e`
;; (`elen` bytes, last byte first, low bit first within a byte).
;; t1 holds the running square, t2 receives each product, and x is
;; conditionally replaced by the product according to the exponent bit.
;; The multiplication is always performed; only the copy is conditional.
(define (i31-modpow x e elen m m0i t1 t2)
  (let ((nwords (>> (+ m[0] 63) 5))   ;; words including the header; NOT bytes
        (nbits (<< elen 3)))
    (for-range w nwords
      (set! t1[w] x[w]))
    (i31->monty t1 m)
    (i31-zero x m[0])
    (set! x[1] 1)
    (for-range k nbits
      (let ((byte-index (- elen 1 (>> k 3)))
            (ex (char->int (string-ref e byte-index)))
            (ctl (logand 1 (>> ex (logand k 7)))))
        (i31-montymul t2 x t1 m m0i)
        (CCOPY ctl x t2 nwords)
        (i31-montymul t2 t1 t1 m m0i)
        (for-range w nwords
          (set! t1[w] t2[w]))))))

;; Multiply-accumulate: adds a*b into the value words of `d` and sets
;; d[0] to the combined bit length.
;;
;; Bit lengths are *encoded* (word index << 5, plus bits in the top
;; 31-bit word), so they cannot simply be added: the two low 5-bit
;; fields are summed separately and one extra unit is added when that
;; sum reaches 31, i.e. spills into another word.  A plain sum can
;; under-report the word count (e.g. 16 + 16 = 32 -> one word, although
;; a 32-bit product needs two).
(define (i31-mulacc d a b)
  (let ((alen (>> (+ a[0] 31) 5))
        (blen (>> (+ b[0] 31) 5))
        (dl (+ (logand a[0] 31) (logand b[0] 31)))
        (dh (+ (>> a[0] 5) (>> b[0] 5))))
    ;; the last term is 1 when dl >= 31 and 0 otherwise
    (set! d[0] (+ (<< dh 5) dl (logand 1 (>> (lognot (- dl 31)) 31))))
    (for-range u blen
      (let ((f b[(+ 1 u)])
            (cc 0))
        (for-range v alen
          (let ((z (+ d[(+ 1 u v)] (MUL31 f a[(+ 1 v)]) cc)))
            (set! cc (>> z 31))
            (set! d[(+ 1 u v)] (logand z #x7fffffff))))
        (set! d[(+ 1 u alen)] (logand #xffffffff cc)) ;; mask needed?
        ))))

;; we violate the bignum abstraction a bit here because there is
;; [currently] no other way to mask out portions of the lowest limb.
;; Converts a bignum digit vector into i31 form: a vector whose slot 0
;; is the encoded bit length and whose following slots are 31-bit words,
;; least significant first.
;; The low 31 bits of the last digit are peeled off repeatedly, shifting
;; the digits right by 31 each time, until no digits remain.
(define (digits->i31 v)
  (let ((v0 v)
        (limbs (list:nil)))
    ;; each PUSH puts a more significant word at the head of the list
    (while (> (vector-length v0) 0)
      (PUSH limbs (logand v0[(- (vector-length v0) 1)] #x7fffffff))
      (set! v0 (digits-rshift v0 31)))
    (let ((nwords (length limbs))
          (result (make-vector (+ 1 nwords) 0)))
      ;; walk the list once, filling from the top word down; this avoids
      ;; the reverse + repeated `nth` lookups (quadratic in nwords)
      (let loop ((rest limbs) (i nwords))
        (match rest with
          ()        -> #u
          (hd . tl) -> (begin (set! result[i] hd)
                              (loop tl (- i 1)))))
      (set! result[0] (i31-bit-length result nwords))
      result)))

;; Converts a positive bignum to i31 form.
;; Any value that is not a `big:pos` raises :CTBig/ExpectedPositive
;; carrying the offending value.
(define (big->i31 n)
  (match n with
    (big:pos _) -> (let ((digits (big->digits n)))
                     (digits->i31 digits))
    other       -> (raise (:CTBig/ExpectedPositive other))
    ))

;; Rebuilds a bignum from an i31 vector: starting from the last slot and
;; working down to slot 1, the accumulator is shifted left by 31 bits and
;; the word added.  Slot 0 (the bit length) is not consulted; every slot
;; after it is used.
(define (i31->big v)
  (let loop ((i (- (vector-length v) 1))
             (acc (big:zero)))
    (if (< i 1)
        acc
        (loop (- i 1)
              (big-add (big-lshift acc 31) (int->big v[i]))))))