Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: Basic matrix operations until QR and SVD algorithms #172

Merged
merged 10 commits into from
Feb 16, 2024
195 changes: 109 additions & 86 deletions spork/math.janet
Original file line number Diff line number Diff line change
Expand Up @@ -392,10 +392,10 @@
(def cells @[])
(forever
(set (cells x)
(*
bc
(math/pow p x)
(math/pow (- 1 p) (- t x))))
(*
bc
(math/pow p x)
(math/pow (- 1 p) (- t x))))
(+= cp (cells x))
(++ x)
(set bc (/ (* bc (+ (- t x) 1)) x))
Expand Down Expand Up @@ -1046,16 +1046,16 @@
(scalar c 1))

(defn trans
"Returns a new transposed matrix from `m`."
"Tansposes a list of row vectors."
[m]
(def [c r] (size m))
(def res (array/new c))
(for i 0 r
(def cr (array/new c))
(for j 0 c
(array/push cr (get-in m [j i])))
(array/push res cr))
res)
(map array ;m))

(defn row->col
"Transposes a row vector `xs` to col vector. Returns `xs` if it has higher dimensions."
[xs]
(case (type (xs 0))
:number (map array xs)
:array xs))

(defn sop
```
Expand Down Expand Up @@ -1089,35 +1089,41 @@
:number (sop m + a)
:array (mop m + a)))

(defn dot
"Computes dot product of matrices or vectors `x` and `y`."
[mx my]
(def [rx cx] (size mx))
(def [ry cy] (size my))
(assert (= cx ry) "matrices do not have right sizes for dot product")
(def res (array/new cy))
(for r 0 rx
(def cr (array/new cx))
(for c 0 cy
(var s 0)
(for rr 0 ry
(+= s (* (get-in mx [r rr]) (get-in my [rr c]))))
(array/push cr s))
(array/push res cr))
res)
(defn dot
"Dot product between two row vectors."
[v1 v2]
(apply + (map * v1 v2)))

(defn dot-fast
"Fast dot product between two row vectors of equal size."
[v1 v2]
(var t 0)
(for i 0 (length v1)
(+= t (* (get v1 i) (get v2 i))))
t)

(defn matmul
"Matrix multiplication between matrices `ma` and `mb`. Does not mutate."
[ma mb]
(map (fn [row-a]
(map (fn [col-b]
(apply + (map * row-a col-b)))
(trans mb)))
ma))

(defn mul
```
Multiply matrix `m` with `a` which can be matrix or vector.
Matrix `m` is mutated.
Multiply matrix `m` with `a` which can be matrix or a list.
Mutates `m`. A list `a` will be converted to column vector
then multiplifed from the right as `x * a`.
```
[m a]
(case (type a)
:number
(sop m * a)
:array
(if (number? (a 0))
(dot m (seq [x :in a] @[x]))
(matmul m (row->col a))
(mop m * a))))

(defn minor
Expand Down Expand Up @@ -1355,82 +1361,95 @@
(array/concat res (factor-pollard x))))
res)

(defn matmul [a b]
(let [transpose (fn [m] (apply map array m))
b-t (transpose b)]
(map (fn [row-a]
(map (fn [col-b]
(apply + (map * row-a col-b)))
b-t))
a)))

(defn dot [v1 v2]
(apply + (map * v1 v2)))

(defn scale [v k]
(defn scale
"Scale a vector `v` by a number `k`."
[v k]
(map (fn [x] (* x k)) v))

(defn subtract [v1 v2]
(defn subtract
"Elementwise subtract vector `v2` from `v1`."
[v1 v2]
(map - v1 v2))

(defn copy [xs] (if (= :ta/view (type xs)) (:slice xs) (array/slice xs)))
(defn copy
"Deep copy an array or view `xs`."
[xs]
(if (= :ta/view (type xs)) (:slice xs) (array/slice xs)))

(defn sign [x]
(if (>= x 0) 1 -1))
(defn sign
"Sign function."
[x] (cmp x 0))

(defn trans-m [m] (apply map array m))
(defn trans-v [xs] (map array xs))
(defn outer
"Outer product of vectors `v1` and `v2`."
[v1 v2]
(matmul (map array v1) (array v2)))

(defn outer [xs] (matmul (map array xs) (array xs)))

(defn unit-e [n k]
(defn unit-e
"Unit vector of `n` dimensions along dimension `k`."
[n k]
(update-in
(zero n) [k] (fn [x] 1)))

(defn normalize-v [xs]
(defn normalize-v
"Returns normalized vector of `xs` by Euclidian (L2) norm."
[xs]
(map |(/ $0 (math/sqrt (dot xs xs))) xs))

(defn rbind [m1 m2]

(defn join-rows
"Stack vertically rows of two matrices."
[m1 m2]
(array/concat m1 m2))

(defn cbind [m1 m2]
(let [tm1 (trans-m m1)
tm2 (trans-m m2)]
(trans-m (rbind tm1 tm2))))
(defn join-cols
"Stack horizontally columns of two matrices."
[m1 m2]
(map join-rows m1 m2))

(defn squeeze
"Concatenate a list of rows into a single row. Does not mutate `m`."
[m]
(array/concat @[] ;m))

(defn flipud
"Flip a matrix upside-down."
[m]
(reverse m))

(defn flipud [m] (reverse m))
(defn fliplr [m] (-> m
trans-m
flipud
trans-m))
(defn fliplr
"Flip a matrix leftside-right."
[m]
(map reverse m))

(defn expand-m
"Returns a new transposed matrix from `m`."
"Embeds a matrix `m` inside an identity matrix of size n."
[n m]
(let [I (ident n)
left (rbind I (zero n (rows m)))
right (rbind (zero (cols m) n) m)]
(cbind left right)))
left (join-rows I (zero n (rows m)))
right (join-rows (zero (cols m) n) m)]
(join-cols left right)))

(defn slice-m
[m rslice cslice]
(-> m (array/slice ;rslice)
trans-m
(array/slice ;cslice)
trans-m))
"Slice a matrix `m` by rows and columns."
[m rslice cslice]
(-> m
(array/slice ;rslice)
trans
(array/slice ;cslice)
trans))


(defn- qr1
(defn qr1
"Transform using Householder reflections by one step."
[m]
(let [x ((trans-m m) 0) # take first column
(let [x ((trans m) 0) # take first column
k 0
a (* -1 (sign (x k)) (math/sqrt (dot x x)))
e1 (unit-e (length x) 0)
u (subtract x (map |(* $ a) e1)) # (mul e1 a)
v (normalize-v u)
I (ident (length u))
Q (mop I - (sop (outer v) * 2))
Q (mop I - (sop (outer v v) * 2))
Qm (matmul Q m)
m^ (slice-m Qm [1] [1])]
{:Q Q
Expand Down Expand Up @@ -1475,22 +1494,26 @@
(var res (qr R1))
(set Q1 (res :Q))
(set R1 (res :R))
(var res^ (qr (trans-m R1)))
(var res^ (qr (trans R1)))
(set Q2 (res^ :Q))
(set R2 (res^ :R))
(set R1 (trans-m R2))
(set R1 (trans R2))
(set U (matmul U Q1))
(set V (matmul V Q2))))
{:U U
:S R1
:V V})

(defn m-approx=
"Evaluates a matrix (list of row vectors) for equivalence within epsilon."
[m1 m2]
# TODO:
(let [v1 (apply array/concat m1)
v2 (apply array/concat m2)
"Compares two matrices of equal size for equivalence within epsilon."
[m1 m2 &opt tolerance]
(let [v1 (squeeze m1)
v2 (squeeze m2)
b (map approx-eq v1 v2)]
(and (= (length v1) (length v2))
(every? b))))
(every? b))))

(let [m3 @[@[1 2 3] @[4 5 6] @[7 8 9]]]
(assert (m-approx= (matmul m3 (ident (rows m3)))
m3)
"matmul identity left: this test succeeds here but fails in suite-math.janet"))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any idea why?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed. Caused by stateful computation due to join-rows not being idempotent (array/concat mutates first argument).

Loading