forked from ryanbressler/CloudForest
-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathcatballotbox.go
128 lines (111 loc) · 2.81 KB
/
catballotbox.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
package CloudForest
import (
"sort"
"sync"
)
// CatBallot holds the weighted categorical votes cast for a single case.
// Map accumulates vote weight per integer category key, and Mutex guards Map
// so that votes can be recorded from several goroutines safely.
type CatBallot struct {
	Mutex sync.RWMutex
	Map   map[int]float64
}

// NewCatBallot returns a pointer to a CatBallot whose Map is allocated and
// empty, ready to receive votes.
func NewCatBallot() (cb *CatBallot) {
	return &CatBallot{Map: make(map[int]float64)}
}
// CatBallotBox records the votes cast by trees, one CatBallot per case.
// The embedded CatMap translates category labels to the integer keys stored
// in each ballot (CatToNum) and back again (Back).
type CatBallotBox struct {
	*CatMap
	Box []*CatBallot
}

// NewCatBallotBox builds a ballot box for "size" cases. It holds a fresh
// CatMap and an empty ballot for each case index 0..size-1.
func NewCatBallotBox(size int) *CatBallotBox {
	bb := &CatBallotBox{
		CatMap: NewCatMap(),
		Box:    make([]*CatBallot, size),
	}
	for i := range bb.Box {
		bb.Box[i] = NewCatBallot()
	}
	return bb
}
// Vote registers a vote that case "casei" should be predicted to be the
// category "pred", adding weight to that category's running total.
//
// The label is translated to its integer key with CatToNum. The ballot update
// happens under the case's write lock so concurrent voters do not race on
// the same map.
//
// NOTE(review): CatToNum is called before the ballot lock is taken; this
// assumes CatMap handles its own synchronization — confirm.
func (bb *CatBallotBox) Vote(casei int, pred string, weight float64) {
	predn := bb.CatToNum(pred)
	ballot := bb.Box[casei]
	ballot.Mutex.Lock()
	defer ballot.Mutex.Unlock()
	// A missing key reads as the zero value 0, so a single += both creates
	// and accumulates the entry; no separate existence check is needed.
	ballot.Map[predn] += weight
}
// Tally returns the predicted category for case i when the feature is
// categorical or boolean. The prediction is the category with the largest
// accumulated vote weight (the weighted mode of the votes).
//
// When several categories share that weight, the one with the smallest integer
// key wins, so the result does not depend on map iteration order. "NA" is
// returned when no category has a positive total.
func (bb *CatBallotBox) Tally(i int) (predicted string) {
	ballot := bb.Box[i]
	var (
		best    int
		topVote float64
		tied    []int
	)
	ballot.Mutex.RLock()
	for key, weight := range ballot.Map {
		switch {
		case weight > topVote:
			// New leader: restart the list of tied keys with it alone.
			best, topVote = key, weight
			tied = []int{key}
		case weight == topVote:
			tied = append(tied, key)
		}
	}
	ballot.Mutex.RUnlock()
	if len(tied) > 1 {
		// Go randomizes map iteration order; choosing the smallest tied key
		// keeps predictions reproducible.
		sort.Ints(tied)
		best = tied[0]
	}
	if topVote > 0 {
		return bb.Back[best]
	}
	return "NA"
}
/*
TallyError returns the balanced classification error for a categorical
feature. This is one minus the mean, over classes, of the per-class recall:

	1 - (1/C) * sum_c ( |{x in X_c : Y'(x) = Y(x)}| / |X_c| )

where
	Y   are the true labels (taken from feature),
	Y'  are the labels predicted by Tally,
	X_c is the set of cases whose true label is class c,
	C   is the number of classes with at least one labeled case.

Cases for which the true category is not known are ignored. A class with no
labeled cases has no defined recall, so it is left out of the average instead
of contributing 0/0 = NaN. If no case is labeled at all, the result is NaN.

feature must implement CatFeature, or the type assertion panics. bb must hold
a ballot for every case index in [0, feature.Length()), because Tally indexes
bb.Box directly.
*/
func (bb *CatBallotBox) TallyError(feature Feature) float64 {
	catfeature := feature.(CatFeature)
	ncats := catfeature.NCats()
	correct := make([]int, ncats)
	total := make([]int, ncats)
	for i := 0; i < feature.Length(); i++ {
		// Skip unlabeled cases before doing any per-case work, including the
		// locked Tally of their ballot.
		if feature.IsMissing(i) {
			continue
		}
		value := catfeature.Geti(i)
		total[value]++
		if catfeature.NumToCat(value) == bb.Tally(i) {
			correct[value]++
		}
	}
	var recall float64
	classes := 0
	for c, n := range total {
		if n == 0 {
			// No labeled cases for this class: its recall is undefined.
			continue
		}
		recall += float64(correct[c]) / float64(n)
		classes++
	}
	// With no labeled cases at all, classes is 0. The floating-point division
	// 0/0 then yields NaN, matching the previous result for that case.
	return 1.0 - recall/float64(classes)
}