-
Notifications
You must be signed in to change notification settings - Fork 26
/
test.py
113 lines (90 loc) · 3.63 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import unittest
import numpy as np
from sumproduct import Variable, Factor, FactorGraph, Mu
class SimpleGraph(unittest.TestCase):
    """Inference tests on the graph pictured in the readme.

    Learn more about this graphical model from exercise 8.10
    and figure 8.54 in Christopher Bishop's Pattern Recognition
    and Machine Learning.
    """

    # (variable name, expected marginal) pairs shared by the sum-product
    # and brute-force tests. A tuple of pairs (rather than a dict) keeps
    # the order in which the nodes are checked fixed.
    EXPECTED_MARGINALS = (
        ('x1', [0.536, 0.464]),
        ('x2', [0.48, 0.36, 0.16]),
        ('x3', [0.2, 0.8]),
        ('x4', [0.5, 0.5]),
    )

    def createSimpleGraph(self):
        """Build and return a new FactorGraph with x1..x4 and their factors.

        Four variables (sizes 2, 3, 2, 2) are joined by the factors
        f12, f234, f3 and f4. Nodes are appended in the order that
        matches the axis order of each factor's potential array.
        """
        # create the orphaned nodes
        node_names = ['x1', 'x2', 'x3', 'x4']
        dims = [2, 3, 2, 2]
        # pad array just so that reference like x[1] later is easier to read
        x = [None] + [
            Variable(name, dim) for name, dim in zip(node_names, dims)
        ]

        f3 = Factor('f3', np.array([0.2, 0.8]))
        f4 = Factor('f4', np.array([0.5, 0.5]))

        # first index is x3, second index is x4, third index is x2
        # looking at it like: arr[0][0][0]
        f234 = Factor('f234', np.array([
            [
                [0.3, 0.5, 0.2], [0.1, 0.1, 0.8]
            ], [
                [0.9, 0.05, 0.05], [0.2, 0.7, 0.1]
            ]
        ]))

        # first index is x2
        f12 = Factor('f12', np.array([[0.8, 0.2], [0.2, 0.8], [0.5, 0.5]]))

        # attach nodes to graph in right order (connections matching
        # factor's potential's dimensions order)
        g = FactorGraph(x[3], silent=True)
        g.append('x3', f234)
        g.append('f234', x[4])
        g.append('f234', x[2])
        g.append('x2', f12)
        g.append('f12', x[1])
        g.append('x3', f3)
        g.append('x4', f4)
        return g

    def setUp(self):
        """Give every test its own freshly built graph in self.g."""
        self.g = self.createSimpleGraph()

    def _assertMarginals(self, read_marginal):
        """Compare each variable's marginal with EXPECTED_MARGINALS.

        read_marginal is called with a variable node taken from
        self.g.nodes and must return the marginal to check. Values are
        compared with np.allclose (default tolerances); the failure
        message names the offending node and shows both arrays.
        """
        for name, expected in self.EXPECTED_MARGINALS:
            actual = read_marginal(self.g.nodes[name])
            self.assertTrue(
                np.allclose(actual, np.array(expected)),
                msg='{0} marginal is {1!r}, expected {2!r}'.format(
                    name, actual, expected))

    def testTwoIndependentInstances(self):
        """A FactorGraph created with no arguments has no nodes, even
        though populated graphs were built before it."""
        g1 = self.createSimpleGraph()
        g2 = FactorGraph()
        # assertGreater/assertEqual report the actual lengths on failure
        self.assertGreater(len(g1.nodes), 0)
        self.assertEqual(len(g2.nodes), 0)

    def testCustomErrorFunction(self):
        """Smoke test: compute_marginals runs with a custom error_fun.

        There is no assertion; the test passes when the call completes
        without raising.
        """
        def func(m1, m2):
            # total absolute difference, summed over every key of m1
            return sum(sum(np.absolute(m1[k] - m2[k])) for k in m1.keys())

        self.g.compute_marginals(error_fun=func)

    def testSumProductInference(self):
        """marginal() of each variable after compute_marginals() matches
        the expected values."""
        self.g.compute_marginals()
        self._assertMarginals(lambda node: node.marginal())

    def testBruteForceInference(self):
        """bfmarginal of each variable after brute_force() matches the
        expected values."""
        self.g.brute_force()
        self._assertMarginals(lambda node: node.bfmarginal)
class InboxToMarginal(unittest.TestCase):
    """Marginal of a single variable computed from hand-delivered messages."""

    def setUp(self):
        """Create the 'bit' variable and deliver two messages to it.

        The first message carries [0.5, 0.5] and the second [1.0, 0.0];
        the resulting node is stored in self.node.
        """
        bit = Variable('bit', 2)
        # Build both messages up front, then deliver them in order.
        messages = [
            Mu(None, np.array(values))
            for values in ([0.5, 0.5], [1.0, 0.0])
        ]
        for message in messages:
            # NOTE(review): the first argument (2) is presumably a step
            # number -- confirm against Variable.deliver.
            bit.deliver(2, message)
        self.node = bit

    def testFewHarshProbabilities(self):
        """The node's marginal is close to [1.0, 0.0]."""
        marginal = self.node.marginal()
        expected = np.array([1.0, 0.0])
        self.assertTrue(np.allclose(marginal, expected))
def main():
    """Run this module's test cases through unittest's command-line runner."""
    unittest.main()
# Run the tests only when this file is executed as a script, not when
# it is imported.
if __name__ == '__main__':
    main()