-
Notifications
You must be signed in to change notification settings - Fork 2
/
parser_test.go
66 lines (61 loc) · 1.76 KB
/
parser_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
package parser_test
import (
"testing"
G "gorgonia.org/gorgonia"
"github.com/gorgonia/parser"
"gorgonia.org/tensor"
)
// σ adds a logistic-sigmoid node over a to a's graph and returns it.
// Construction errors are turned into a panic by G.Must.
//
// NOTE(review): not referenced anywhere in this file; remove it or use it.
func σ(a *G.Node) *G.Node {
	out, err := G.Sigmoid(a)
	return G.Must(out, err)
}
// TestParse registers the operands of an LSTM-style gate expression
// (Wf, hₜ₋₁, xₜ, bf) with a parser. It then parses and evaluates several
// equations over them and checks each result against an expected
// 2-element float32 vector.
func TestParse(t *testing.T) {
	g := G.NewGraph()

	// Wf is a 2×2 matrix of ones; hₜ₋₁, xₜ and bf are length-2 vectors of ones.
	wfT := tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float32{1, 1, 1, 1}))
	wf := G.NewMatrix(g, tensor.Float32, G.WithName("wf"), G.WithShape(2, 2), G.WithValue(wfT))
	htprevT := tensor.New(tensor.WithBacking([]float32{1, 1}), tensor.WithShape(2))
	htprev := G.NewVector(g, tensor.Float32, G.WithName("ht-1"), G.WithShape(2), G.WithValue(htprevT))
	xtT := tensor.New(tensor.WithBacking([]float32{1, 1}), tensor.WithShape(2))
	xt := G.NewVector(g, tensor.Float32, G.WithName("xt"), G.WithShape(2), G.WithValue(xtT))
	bfT := tensor.New(tensor.WithBacking([]float32{1, 1}), tensor.WithShape(2))
	bf := G.NewVector(g, tensor.Float32, G.WithName("bf"), G.WithShape(2), G.WithValue(bfT))

	p := parser.NewParser(g)
	p.Set(`Wf`, wf)
	p.Set(`hₜ₋₁`, htprev)
	p.Set(`xₜ`, xt)
	p.Set(`bf`, bf)

	// Absolute tolerance for comparing float32 outputs. Sigmoid and tanh
	// results can differ in the last ulp across math implementations, so
	// exact equality is brittle.
	const eps = 1e-6

	cases := []struct {
		equation string
		expected []float32
	}{
		// With all-ones operands, Wf·hₜ₋₁ = Wf·xₜ = [2 2], so the sum is [5 5].
		{`1*Wf·hₜ₋₁+ Wf·xₜ+ bf`, []float32{5, 5}},
		// sigmoid(5)
		{`σ(1*Wf·hₜ₋₁+ Wf·xₜ+ bf)`, []float32{0.9933072, 0.9933072}},
		// tanh(5)
		{`tanh(1*Wf·hₜ₋₁+ Wf·xₜ+ bf)`, []float32{0.9999092, 0.9999092}},
	}

	for _, tc := range cases {
		t.Run(tc.equation, func(t *testing.T) {
			result, err := p.Parse(tc.equation)
			if err != nil {
				t.Fatalf("Parse(%q): %v", tc.equation, err)
			}

			// Every case adds nodes to the same graph g, so each run also
			// re-evaluates the expressions parsed by earlier cases.
			machine := G.NewLispMachine(g, G.ExecuteFwdOnly())
			defer func() {
				if err := machine.Close(); err != nil {
					t.Errorf("closing machine: %v", err)
				}
			}()
			if err := machine.RunAll(); err != nil {
				t.Fatalf("RunAll: %v", err)
			}

			val := result.Value()
			if val == nil {
				t.Fatal("result node has no value after RunAll")
			}
			got, ok := val.Data().([]float32)
			if !ok {
				t.Fatalf("result data is %T, want []float32", val.Data())
			}
			// Fatal, not Fail: the element loop below would index past a short slice.
			if len(got) != len(tc.expected) {
				t.Fatalf("got %v (%d elements), want %v", got, len(got), tc.expected)
			}
			for i, want := range tc.expected {
				if d := got[i] - want; d > eps || d < -eps {
					t.Errorf("element %d = %v, want %v (±%g)", i, got[i], want, eps)
				}
			}
		})
	}
}