Skip to content

Commit

Permalink
Merge pull request #172 from hyiltiz/matrix
Browse files Browse the repository at this point in the history
Feature: Basic matrix operations until QR and SVD algorithms
  • Loading branch information
bakpakin authored Feb 16, 2024
2 parents c66a18f + 00f3760 commit caf9220
Show file tree
Hide file tree
Showing 2 changed files with 288 additions and 39 deletions.
223 changes: 189 additions & 34 deletions spork/math.janet
Original file line number Diff line number Diff line change
Expand Up @@ -392,10 +392,9 @@
(def cells @[])
(forever
(set (cells x)
(*
bc
(math/pow p x)
(math/pow (- 1 p) (- t x))))
(* bc
(math/pow p x)
(math/pow (- 1 p) (- t x))))
(+= cp (cells x))
(++ x)
(set bc (/ (* bc (+ (- t x) 1)) x))
Expand Down Expand Up @@ -1026,7 +1025,7 @@
[c &opt r]
(def v (array/new-filled c 0))
(if r
(seq [_ :range [0 c]] (array/slice v))
(seq [_ :range [0 r]] (array/slice v))
v))

(defn scalar
Expand All @@ -1046,16 +1045,16 @@
(scalar c 1))

(defn trans
"Returns a new transposed matrix from `m`."
"Tansposes a list of row vectors."
[m]
(def [c r] (size m))
(def res (array/new c))
(for i 0 r
(def cr (array/new c))
(for j 0 c
(array/push cr (get-in m [j i])))
(array/push res cr))
res)
(map array ;m))

(defn row->col
"Transposes a row vector `xs` to col vector. Returns `xs` if it has higher dimensions."
[xs]
(case (type (xs 0))
:number (map array xs)
:array xs))

(defn sop
```
Expand All @@ -1067,7 +1066,8 @@
(if-not (empty? a) |(op $ ;a) op))
(for i 0 (cols m)
(for j 0 (rows m)
(update-in m [i j] opa))))
(update-in m [i j] opa)))
m)

(defn mop
```
Expand All @@ -1077,7 +1077,7 @@
[m op a]
(for i 0 (cols m)
(for j 0 (rows m)
(update-in m [j i] op (get-in a [j i])))))
(update-in m [j i] op (get-in a [j i])))) m)

(defn add
```
Expand All @@ -1090,34 +1090,40 @@
:array (mop m + a)))

(defn dot
"Computes dot product of matrices or vectors `x` and `y`."
[mx my]
(def [rx cx] (size mx))
(def [ry cy] (size my))
(assert (= cx ry) "matrices do not have right sizes for dot product")
(def res (array/new cy))
(for r 0 rx
(def cr (array/new cx))
(for c 0 cy
(var s 0)
(for rr 0 ry
(+= s (* (get-in mx [r rr]) (get-in my [rr c]))))
(array/push cr s))
(array/push res cr))
res)
"Dot product between two row vectors."
[v1 v2]
(apply + (map * v1 v2)))

(defn dot-fast
"Fast dot product between two row vectors of equal size."
[v1 v2]
(var t 0)
(for i 0 (length v1)
(+= t (* (get v1 i) (get v2 i))))
t)

(defn matmul
"Matrix multiplication between matrices `ma` and `mb`. Does not mutate."
[ma mb]
(map (fn [row-a]
(map (fn [col-b]
(apply + (map * row-a col-b)))
(trans mb)))
ma))

(defn mul
```
Multiply matrix `m` with `a` which can be matrix or vector.
Matrix `m` is mutated.
Multiply matrix `m` with `a` which can be matrix or a list.
Mutates `m`. A list `a` will be converted to column vector
then multiplifed from the right as `x * a`.
```
[m a]
(case (type a)
:number
(sop m * a)
:array
(if (number? (a 0))
(dot m (seq [x :in a] @[x]))
(matmul m (row->col a))
(mop m * a))))

(defn minor
Expand Down Expand Up @@ -1354,3 +1360,152 @@
(if (> x one)
(array/concat res (factor-pollard x))))
res)

(defn scale
"Scale a vector `v` by a number `k`."
[v k]
(map (fn [x] (* x k)) v))

(defn subtract
"Elementwise subtract vector `v2` from `v1`."
[v1 v2]
(map - v1 v2))

(defn copy
"Deep copy an array or view `xs`."
[xs]
(if (= :ta/view (type xs)) (:slice xs) (array/slice xs)))

(defn sign
"Sign function."
[x] (cmp x 0))

(defn outer
"Outer product of vectors `v1` and `v2`."
[v1 v2]
(matmul (map array v1) (array v2)))

(defn unit-e
"Unit vector of `n` dimensions along dimension `k`."
[n k]
(update-in
(zero n) [k] (fn [x] 1)))

(defn normalize-v
"Returns normalized vector of `xs` by Euclidian (L2) norm."
[xs]
(map |(/ $0 (math/sqrt (dot xs xs))) xs))

(defn join-rows
"Stack vertically rows of two matrices."
[m1 m2]
(array/concat @[] m1 m2))

(defn join-cols
"Stack horizontally columns of two matrices."
[m1 m2]
(map join-rows m1 m2))

(defn squeeze
"Concatenate a list of rows into a single row. Does not mutate `m`."
[m]
(array/concat @[] ;m))

(defn flipud
"Flip a matrix upside-down."
[m]
(reverse m))

(defn fliplr
"Flip a matrix leftside-right."
[m]
(map reverse m))

(defn expand-m
"Embeds a matrix `m` inside an identity matrix of size n."
[n m]
(let [I (ident n)
left (join-rows I (zero n (rows m)))
right (join-rows (zero (cols m) n) m)]
(join-cols left right)))

(defn slice-m
"Slice a matrix `m` by rows and columns."
[m rslice cslice]
(-> m
(array/slice ;rslice)
trans
(array/slice ;cslice)
trans))


(defn qr1
"Transform using Householder reflections by one step."
[m]
(let [x ((trans m) 0) # take first column
k 0
a (* -1 (sign (x k)) (math/sqrt (dot x x)))
e1 (unit-e (length x) 0)
u (subtract x (map |(* $ a) e1)) # (mul e1 a)
v (normalize-v u)
I (ident (length u))
Q (mop I - (sop (outer v v) * 2))
Qm (matmul Q m)
m^ (slice-m Qm [1] [1])]
{:Q Q
:m^ m^}))


(defn qr
```
Stable and robust QR decomposition of a matrix.
Decompose a matrix using Householder transformations. O(n^3).
```
[m]
(var m^ m)
(var Qs (seq [i :range [0 (min (- (rows m) 1) (cols m))]]
(def res (qr1 m^))
(set m^ (res :m^))
(def Q^ (expand-m i (res :Q)))
Q^))
(def I (ident (cols Qs)))
(var Q (reduce matmul I Qs))
(var R (reduce matmul I (array/concat (reverse Qs) (array m))))
{:Q Q
:R R})

(defn svd
```
Simple Singular-Value-Decomposition based on repeated QR decomposition. The algorithm converges at O(n^3).
```
[m &opt n-iter]
(def n-iter 100)
(var U (ident (rows m)))
(var V U)
(var Q1 U)
(var Q2 U)
(var R1 m)
(var R2 U)
(var Q1 U)
(loop [i :range [0 n-iter]]
(var res (qr R1))
(set Q1 (res :Q))
(set R1 (res :R))
(var res^ (qr (trans R1)))
(set Q2 (res^ :Q))
(set R2 (res^ :R))
(set R1 (trans R2))
(set U (matmul U Q1))
(set V (matmul V Q2)))
{:U U
:S R1
:V V})

(defn m-approx=
"Compares two matrices of equal size for equivalence within epsilon."
[m1 m2 &opt tolerance]
(let [v1 (squeeze m1)
v2 (squeeze m2)
b (map approx-eq v1 v2)]
(and (= (length v1) (length v2))
(every? b))))
Loading

0 comments on commit caf9220

Please sign in to comment.