Skip to content

Commit

Permalink
fixed bug of parsing and em
Browse files Browse the repository at this point in the history
  • Loading branch information
eelxpeng committed Sep 18, 2018
1 parent 5eec490 commit f30ddcd
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 16 deletions.
2 changes: 1 addition & 1 deletion pyltm/io/bifparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def parse(self, filename):
self.net = None
for inst in tree.children:
self.read_tree(inst)
print(str(self.net))

return self.net

def read_tree(self, t):
Expand Down
31 changes: 17 additions & 14 deletions pyltm/learner/em_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,18 @@ def initializeSufficientStatistics(self, batch_size):
ctp = NaturalCliqueTreePropagation(self._model)
ctp.initializePotentials()
tree = ctp.cliqueTree()
sufficientStatistics = [None]*len(tree.nodes)
batchSufficientStatistics = [None]*len(tree.nodes)
for i in range(len(tree.nodes)):
if isinstance(tree.nodes[i], DiscreteClique):
sufficientStatistics[i] = DiscreteCliqueSufficientStatistics(tree.nodes[i], batch_size)
batchSufficientStatistics[i] = DiscreteCliqueSufficientStatistics(tree.nodes[i], batch_size)
elif isinstance(tree.nodes[i], MixedClique):
sufficientStatistics[i] = MixedCliqueSufficientStatistics(tree.nodes[i], batch_size)
batchSufficientStatistics[i] = MixedCliqueSufficientStatistics(tree.nodes[i], batch_size)
cliques = tree.cliques
sufficientStatistics = [None]*len(cliques)
batchSufficientStatistics = [None]*len(cliques)
for i in range(len(cliques)):
if isinstance(cliques[i], DiscreteClique):
sufficientStatistics[i] = DiscreteCliqueSufficientStatistics(cliques[i], batch_size)
batchSufficientStatistics[i] = DiscreteCliqueSufficientStatistics(cliques[i], batch_size)
elif isinstance(cliques[i], MixedClique):
sufficientStatistics[i] = MixedCliqueSufficientStatistics(cliques[i], batch_size)
batchSufficientStatistics[i] = MixedCliqueSufficientStatistics(cliques[i], batch_size)
else:
raise ValueError("invalid clique type")
raise Exception("unknown type of clique")
return sufficientStatistics, batchSufficientStatistics

def reset(self):
Expand All @@ -51,6 +52,9 @@ def stepwise_e_step(self, data, varNames):
varNames: list of string
'''
ctp = NaturalCliqueTreePropagation(self._model)
tree = ctp.cliqueTree()
cliques = tree.cliques

# set up evidence
datacase = ContinuousDatacase.create(varNames)
datacase.synchronize(self._model)
Expand All @@ -61,15 +65,14 @@ def stepwise_e_step(self, data, varNames):
ctp.use(evidence)
ctp.propagate()

for j in range(len(ctp.cliqueTree().nodes)):
self.batchSufficientStatistics[j].add(ctp.cliqueTree().nodes[j].potential)
for j in range(len(cliques)):
self.batchSufficientStatistics[j].add(cliques[j].potential)

# construct variable to statisticMap
variableStatisticMap = dict()
tree = ctp.cliqueTree()
for node in self._model.nodes:
clique = tree.getClique(node.variable)
index = tree.nodes.index(clique)
index = cliques.index(clique)
variableStatisticMap[node.variable] = (self.sufficientStatistics[index], self.batchSufficientStatistics[index])
return variableStatisticMap

Expand Down
2 changes: 1 addition & 1 deletion pyltm/model/potential/cptpotential.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def divide(self, other):
if isinstance(other, float) or isinstance(other, int):
self._parameter.prob[:] = self._parameter.prob / other
else:
self._parameter.prob[:] = self._parameter.prob / other._parameter.prob
self._parameter.prob[:] = self._parameter.prob / np.maximum(other._parameter.prob, 1e-10)

@property
def parameter(self):
Expand Down

0 comments on commit f30ddcd

Please sign in to comment.