Skip to content

Commit

Permalink
Merge pull request #25 from varunagrawal/fix/remove-duplicate-conditi…
Browse files Browse the repository at this point in the history
…onals
  • Loading branch information
varunagrawal authored Feb 17, 2022
2 parents f5bb790 + 571f1b4 commit fb2a9e7
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 7 deletions.
13 changes: 13 additions & 0 deletions gtsam/hybrid/IncrementalHybrid.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,27 @@ void IncrementalHybrid::update(GaussianHybridFactorGraph graph,
// in the previous `hybridBayesNet` to the graph
std::unordered_set<Key> allVars(ordering.begin(), ordering.end());
for (auto &&conditional : hybridBayesNet_) {
// Flag indicating if a conditional will be updated due to factors in
// `graph`
bool marked_for_update = false;

for (auto &key : conditional->frontals()) {
if (allVars.find(key) != allVars.end()) {
if (auto gf =
boost::dynamic_pointer_cast<GaussianMixture>(conditional)) {
graph.push_back(gf);
marked_for_update = true;
} else if (auto df = boost::dynamic_pointer_cast<DiscreteConditional>(
conditional)) {
graph.push_back(df);
marked_for_update = true;
}

// If a conditional is due to be updated, we remove if from the
// previous bayes net.
if (marked_for_update) {
auto it = find(hybridBayesNet_.begin(), hybridBayesNet_.end(), conditional);
hybridBayesNet_.erase(it);
}
break;
}
Expand Down
14 changes: 7 additions & 7 deletions gtsam/hybrid/tests/testIncrementalHybrid.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,12 @@ TEST_UNSAFE(DCGaussianElimination, Incremental_inference) {

auto hybridBayesNet2 = incrementalHybrid.hybridBayesNet();

EXPECT_LONGS_EQUAL(4, hybridBayesNet2.size());
EXPECT_LONGS_EQUAL(3, hybridBayesNet2.size());
// hybridBayesNet2.print();
EXPECT(hybridBayesNet2.at(2)->frontals() == KeyVector{X(2)});
EXPECT(hybridBayesNet2.at(2)->parents() == KeyVector({X(3), M(2), M(1)}));
EXPECT(hybridBayesNet2.at(3)->frontals() == KeyVector{X(3)});
EXPECT(hybridBayesNet2.at(3)->parents() == KeyVector({M(2), M(1)}));
EXPECT(hybridBayesNet2.at(1)->frontals() == KeyVector{X(2)});
EXPECT(hybridBayesNet2.at(1)->parents() == KeyVector({X(3), M(2), M(1)}));
EXPECT(hybridBayesNet2.at(2)->frontals() == KeyVector{X(3)});
EXPECT(hybridBayesNet2.at(2)->parents() == KeyVector({M(2), M(1)}));

auto remainingFactorGraph2 = incrementalHybrid.remainingFactorGraph();
EXPECT_LONGS_EQUAL(1, remainingFactorGraph2.size());
Expand All @@ -119,11 +119,11 @@ TEST_UNSAFE(DCGaussianElimination, Incremental_inference) {
*(expectedHybridBayesNet->atGaussian(0))));

// The densities on X(2) should be the same
EXPECT(assert_equal(*(hybridBayesNet2.atGaussian(2)),
EXPECT(assert_equal(*(hybridBayesNet2.atGaussian(1)),
*(expectedHybridBayesNet->atGaussian(1))));

// The densities on X(3) should be the same
EXPECT(assert_equal(*(hybridBayesNet2.atGaussian(3)),
EXPECT(assert_equal(*(hybridBayesNet2.atGaussian(2)),
*(expectedHybridBayesNet->atGaussian(2))));

// we only do the manual continuous elimination for 0,0
Expand Down

0 comments on commit fb2a9e7

Please sign in to comment.