forked from rossng/depennd
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Tensor.idr
61 lines (49 loc) · 1.78 KB
/
Tensor.idr
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
module Tensor
import Data.Fin
import Data.Vect
import Matrix
||| The type of a tensor whose dimensions are given by the `Vect n Nat`
||| argument, with elements of type `a`. It is built as nested `Vect`s, so
||| `tensor [2, 3] Int` reduces to `Vect 2 (Vect 3 Int)`, and `tensor [] a`
||| is just a scalar `a`.
|||
||| NOTE: since this is a type-level function, the functions below bind the
||| shape implicit explicitly (`{s = ...}`) rather than relying on inference.
||| Original author's note: "Not sure how to make this work cleanly".
tensor : Vect n Nat -> Type -> Type
tensor [] a = a
tensor (m :: ms) a = Vect m (tensor ms a)
||| Apply `f` to every scalar element of a tensor, preserving its shape `s`.
|||
||| The shape is matched explicitly. For a non-empty shape the value is a
||| `Vect` of sub-tensors of shape `xs`, so we map `tmap` over it with the
||| tail shape passed down. An outer dimension of 0 needs no special case,
||| because `map` over an empty `Vect` yields `[]`.
tmap : (t -> u) -> tensor s t -> tensor s u
tmap {s = []} f v = f v
tmap {s = (_ :: xs)} f v = map (tmap {s = xs} f) v
--tmap {s=[]} f v = f v
--tmap {s=(x::xs)} f vs = tmap f vs
-- tzipWith : (f : a -> b -> c) -> (tensor s a) -> (tensor s b) -> tensor s c
-- tzipWith {s=[]} f x y = f x y
-- tzipWith {s=[l]} f x y = zipWith f x y
-- tzipWith {s=(l::ls)} f (x::xs) y = ?hole
-- (tzipWith {s=ls} f x y) :: (tzipWith {s=((l-1)::ls)} f xs ys)
infixl 9 #*
||| Matrix-vector product. A `tensor [y, x] a` is `y` rows of length `x`.
||| Each entry of the length-`y` result is the dot product of one row with
||| the length-`x` vector.
(#*) : Num a => tensor [y,x] a -> tensor [x] a -> tensor [y] a
mat #* vec = map (\row => sum (zipWith (*) row vec)) mat
infixl 9 #+
||| Elementwise sum of two tensors of the same shape `s`.
|||
||| At shape `[]` both operands are scalars and are added directly. Otherwise
||| both are `Vect`s of sub-tensors of shape `xs`, which are added pairwise.
||| The tail shape is passed explicitly to the recursive call.
(#+) : Num a => tensor s a -> tensor s a -> tensor s a
(#+) {s = []} t1 t2 = t1 + t2
(#+) {s = (_ :: xs)} t1 t2 = zipWith ((#+) {s = xs}) t1 t2
-- data Index : Vect n Nat -> Type where
-- Here : Index []
-- At : Fin m -> Index ms -> Index (m :: ms)
--
-- index : Index ms -> tensor ms a -> a
-- index Here a = a
-- index (At k i) v = index i $ index k v
||| A family of layer types indexed by an input shape and an output shape.
||| Both shapes are `Vect n Nat` with the same rank `n`.
interface Layer (layer : Vect n Nat -> Vect n Nat -> Type) where
  ||| Run a layer on a `Double` tensor of input shape `i`, producing a
  ||| `Double` tensor of output shape `o`.
  runLayer : tensor i Double
          -> layer i o
          -> tensor o Double
--
||| A fully connected layer from a rank-1 input shape `i` to a rank-1 output
||| shape `o`. With `i = [x]` and `o = [y]`, `biases` has type
||| `Vect y Double` and `weights` has type `Vect y (Vect x Double)`, i.e. one
||| row of `x` weights per output element.
data FullyConnected : Vect n Nat -> Vect n Nat -> Type where
  MkFullyConnected : {i: Vect 1 Nat}
                  -> {o: Vect 1 Nat}
                  -> (biases : tensor o Double)
                  -> (weights : tensor (head o :: i) Double)
                  -> FullyConnected i o
-- Layer FullyConnected where
-- runLayer input (MkFullyConnected {i=[x]} {o=[y]} biases weights) = weights #* input
--
-- lyr : FullyConnected [2] [3]
-- lyr = MkFullyConnected {i=[2]} {o=[3]} [1,2,3] [[1,1],[1,1],[1,1]]