forked from ryanbressler/CloudForest
-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathevaluator_test.go
127 lines (109 loc) · 2.86 KB
/
evaluator_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
package CloudForest
import (
"fmt"
"strings"
"testing"
"github.com/bmizerany/assert"
)
// setupCategorical parses the irislibsvm fixture, grows a 10-tree forest
// predicting feature targeti, and returns that forest together with a
// one-row sample copied from the fixture's first row.
//
// In the sample, the column at index targeti is built as a categorical
// feature and every other column as a numeric feature.
func setupCategorical() (*Forest, *FeatureMatrix) {
	irisreader := strings.NewReader(irislibsvm)
	fm := ParseLibSVM(irisreader)
	targeti := 0
	cattarget := fm.Data[targeti]
	config := &ForestConfig{
		// Size the bootstrap from the target column itself so this stays
		// correct if targeti changes.
		NSamples: cattarget.Length(),
		MTry:     3,
		NTrees:   10,
		LeafSize: 1,
	}
	sample := &FeatureMatrix{
		Data: make([]Feature, len(fm.Map)),
		Map:  make(map[string]int),
	}
	for k, v := range fm.Map {
		var feature Feature
		// Compare against targeti rather than a literal 0 so the categorical
		// column always tracks the chosen target.
		if v == targeti {
			feature = NewDenseCatFeature(k)
		} else {
			feature = NewDenseNumFeature(k)
		}
		sample.Map[k] = v
		sample.Data[v] = feature
		sample.Data[v].Append(fm.Data[v].GetStr(0))
	}
	model := GrowRandomForest(fm, cattarget, config)
	return model.Forest, sample
}
// setupNumeric parses the boston_housing ARFF fixture, grows a 20-tree
// forest (depth-limited to 4) that predicts the "class" attribute, and
// returns it together with a one-row sample copied from the fixture's
// first row. Every sample column, including the target, is numeric.
func setupNumeric() (*Forest, *FeatureMatrix) {
	matrix := ParseARFF(strings.NewReader(boston_housing))
	target := matrix.Data[matrix.Map["class"]]

	sample := &FeatureMatrix{
		Data: make([]Feature, len(matrix.Map)),
		Map:  make(map[string]int),
	}
	for name, col := range matrix.Map {
		sample.Data[col] = NewDenseNumFeature(name)
		sample.Data[col].Append(matrix.Data[col].GetStr(0))
		sample.Map[name] = col
	}

	model := GrowRandomForest(matrix, target, &ForestConfig{
		NSamples: target.Length(),
		MTry:     4,
		NTrees:   20,
		LeafSize: 1,
		MaxDepth: 4,
		InBag:    true,
	})
	return model.Forest, sample
}
// TestEvaluator checks that both flattened forest representations
// (piecewise and contiguous) produce the same numeric prediction as
// Forest.Predict for the first sample row, compared to 4 decimal places.
func TestEvaluator(t *testing.T) {
	forest, sample := setupNumeric()
	want := fmt.Sprintf("%.4f", forest.Predict(sample)[0])

	piecewise := NewPiecewiseFlatForest(forest)
	got := fmt.Sprintf("%.4f", piecewise.EvaluateNum(sample)[0])
	assert.Equal(t, want, got)

	contiguous := NewContiguousFlatForest(forest)
	got = fmt.Sprintf("%.4f", contiguous.EvaluateNum(sample)[0])
	assert.Equal(t, want, got)
}
// TestCatEvaluator checks that the piecewise flat forest agrees with
// Forest.PredictCat on the first sample row, and that the contiguous flat
// forest agrees with the piecewise one.
func TestCatEvaluator(t *testing.T) {
	forest, sample := setupCategorical()
	expected := forest.PredictCat(sample)[0]

	piecewise := NewPiecewiseFlatForest(forest)
	fromPiecewise := piecewise.EvaluateCat(sample)[0]
	assert.Equal(t, expected, fromPiecewise)

	contiguous := NewContiguousFlatForest(forest)
	fromContiguous := contiguous.EvaluateCat(sample)[0]
	assert.Equal(t, fromPiecewise, fromContiguous)
}
// BenchmarkPredict measures Forest.Predict on the one-row sample from
// setupNumeric. Growing the forest is excluded from the timed region.
//
// Recorded result (may predate excluding setup from timing):
// BenchmarkPredict-8 100000 12542 ns/op
func BenchmarkPredict(b *testing.B) {
	forest, sample := setupNumeric()
	// The timer is already running when a benchmark begins, so StartTimer
	// would be a no-op here; ResetTimer discards the setup time instead.
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		forest.Predict(sample)
	}
}
// BenchmarkFlatForest measures PiecewiseFlatForest.EvaluateNum on the
// one-row sample from setupNumeric. Growing the forest and building the
// flat representation are excluded from the timed region.
//
// Recorded result (may predate excluding setup from timing):
// BenchmarkFlatForest-8 2000000 821 ns/op
func BenchmarkFlatForest(b *testing.B) {
	forest, sample := setupNumeric()
	pw := NewPiecewiseFlatForest(forest)
	// The timer is already running when a benchmark begins, so StartTimer
	// would be a no-op here; ResetTimer discards the setup time instead.
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		pw.EvaluateNum(sample)
	}
}
// BenchmarkContiguousForest measures ContiguousFlatForest.EvaluateNum on
// the one-row sample from setupNumeric. Growing the forest and building the
// contiguous representation are excluded from the timed region.
//
// Recorded result (may predate excluding setup from timing):
// BenchmarkContiguousForest-8 5000000 339 ns/op
func BenchmarkContiguousForest(b *testing.B) {
	forest, sample := setupNumeric()
	ct := NewContiguousFlatForest(forest)
	// The timer is already running when a benchmark begins, so StartTimer
	// would be a no-op here; ResetTimer discards the setup time instead.
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		ct.EvaluateNum(sample)
	}
}