Skip to content

Commit

Permalink
Merge pull request #22 from varunagrawal/hybridbayesnet/private-members
Browse files Browse the repository at this point in the history
HybridBayesNet private members
  • Loading branch information
varunagrawal authored Feb 17, 2022
2 parents 4fe4c2f + 9690bcb commit f5bb790
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 71 deletions.
54 changes: 40 additions & 14 deletions gtsam/hybrid/IncrementalHybrid.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,18 @@
#include <algorithm>
#include <unordered_set>

void gtsam::IncrementalHybrid::update(gtsam::GaussianHybridFactorGraph graph,
const gtsam::Ordering &ordering,
boost::optional<size_t> maxNrLeaves) {
namespace gtsam {

/* ************************************************************************* */
void IncrementalHybrid::update(GaussianHybridFactorGraph graph,
const Ordering &ordering,
boost::optional<size_t> maxNrLeaves) {
// if we are not at the first iteration
if (hybridBayesNet_) {
if (!hybridBayesNet_.empty()) {
// We add all relevant conditional mixtures on the last continuous variable
// in the previous `hybridBayesNet` to the graph
std::unordered_set<Key> allVars(ordering.begin(), ordering.end());
for (auto &&conditional : *hybridBayesNet_) {
for (auto &&conditional : hybridBayesNet_) {
for (auto &key : conditional->frontals()) {
if (allVars.find(key) != allVars.end()) {
if (auto gf =
Expand All @@ -44,28 +47,26 @@ void gtsam::IncrementalHybrid::update(gtsam::GaussianHybridFactorGraph graph,
}
}
}
} else {
// Initialize an empty HybridBayesNet
hybridBayesNet_ = boost::make_shared<HybridBayesNet>();
}

// Eliminate partially.
HybridBayesNet::shared_ptr bayesNetFragment;
std::tie(bayesNetFragment, remainingFactorGraph_) =
graph.eliminatePartialSequential(ordering);
auto result = graph.eliminatePartialSequential(ordering);
bayesNetFragment = result.first;
remainingFactorGraph_ = *result.second;

// Add the partial bayes net to the posterior bayes net.
hybridBayesNet_->push_back<HybridBayesNet>(*bayesNetFragment);
hybridBayesNet_.push_back<HybridBayesNet>(*bayesNetFragment);

// Prune
if (maxNrLeaves) {
const auto N = *maxNrLeaves;

const auto lastDensity =
boost::dynamic_pointer_cast<GaussianMixture>(hybridBayesNet_->back());
boost::dynamic_pointer_cast<GaussianMixture>(hybridBayesNet_.back());

auto discreteFactor = boost::dynamic_pointer_cast<DecisionTreeFactor>(
remainingFactorGraph_->discreteGraph().at(0));
remainingFactorGraph_.discreteGraph().at(0));

// Let's assume that the structure of the last discrete density will be the
// same as the last continuous
Expand Down Expand Up @@ -112,7 +113,32 @@ void gtsam::IncrementalHybrid::update(gtsam::GaussianHybridFactorGraph graph,
GaussianMixture::Factors prunedConditionalsTree(lastDensity->discreteKeys(),
prunedConditionals);

hybridBayesNet_->atGaussian(hybridBayesNet_->size() - 1)->factors_ =
hybridBayesNet_.atGaussian(hybridBayesNet_.size() - 1)->factors_ =
prunedConditionalsTree;
}
}

/* ************************************************************************* */
GaussianMixture::shared_ptr IncrementalHybrid::gaussianMixture(
size_t index) const {
return boost::dynamic_pointer_cast<GaussianMixture>(
hybridBayesNet_.at(index));
}

/* ************************************************************************* */
const DiscreteFactorGraph &IncrementalHybrid::remainingDiscreteGraph() const {
return remainingFactorGraph_.discreteGraph();
}

/* ************************************************************************* */
const HybridBayesNet &IncrementalHybrid::hybridBayesNet() const {
return hybridBayesNet_;
}

/* ************************************************************************* */
const GaussianHybridFactorGraph &IncrementalHybrid::remainingFactorGraph()
const {
return remainingFactorGraph_;
}

} // namespace gtsam
49 changes: 33 additions & 16 deletions gtsam/hybrid/IncrementalHybrid.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,38 +17,55 @@
* @date December 2021
*/

#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/GaussianHybridFactorGraph.h>
#include <gtsam/hybrid/HybridBayesNet.h>

namespace gtsam {

class IncrementalHybrid {
private:
HybridBayesNet hybridBayesNet_;
GaussianHybridFactorGraph remainingFactorGraph_;

public:

HybridBayesNet::shared_ptr hybridBayesNet_;
GaussianHybridFactorGraph::shared_ptr remainingFactorGraph_;

/**
* Given new factors, perform an incremental update.
* The relevant densities in the `hybridBayesNet` will be added to the input
* graph (fragment), and then eliminated according to the `ordering` presented.
* The remaining factor graph contains Gaussian mixture factors that are not
* connected to the variables in the ordering, or a single discrete factor on
* all discrete keys, plus all discrete factors in the original graph.
* graph (fragment), and then eliminated according to the `ordering`
* presented. The remaining factor graph contains Gaussian mixture factors
* that are not connected to the variables in the ordering, or a single
* discrete factor on all discrete keys, plus all discrete factors in the
* original graph.
*
* \note If maxComponents is given, we look at the discrete factor resulting
* from this elimination, and prune it and the Gaussian components corresponding
* to the pruned choices.
* from this elimination, and prune it and the Gaussian components
* corresponding to the pruned choices.
*
* @param graph The new factors, should be linear only
* @param ordering The ordering for elimination, only continuous vars are allowed
* @param maxNrLeaves The maximum number of leaves in the new discrete factor, if applicable
* @param ordering The ordering for elimination, only continuous vars are
* allowed
* @param maxNrLeaves The maximum number of leaves in the new discrete factor,
* if applicable
*/
void update(GaussianHybridFactorGraph graph,
const Ordering &ordering,
void update(GaussianHybridFactorGraph graph, const Ordering& ordering,
boost::optional<size_t> maxNrLeaves = boost::none);

};
/// Get the Gaussian Mixture from the Bayes Net posterior at `index`.
GaussianMixture::shared_ptr gaussianMixture(size_t index) const;

/// Return the discrete graph after continuous variables have been eliminated.
const DiscreteFactorGraph& remainingDiscreteGraph() const;

/// Return the Bayes Net posterior.
const HybridBayesNet& hybridBayesNet() const;

/**
* @brief Return the leftover factor graph after the last update with the
* specified ordering.
*
* @return GaussianHybridFactorGraph
*/
const GaussianHybridFactorGraph& remainingFactorGraph() const;
};

}; // namespace gtsam
78 changes: 37 additions & 41 deletions gtsam/hybrid/tests/testIncrementalHybrid.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,20 +60,18 @@ TEST_UNSAFE(DCGaussianElimination, Incremental_inference) {

incrementalHybrid.update(graph1, ordering);

auto hybridBayesNet = incrementalHybrid.hybridBayesNet_;
CHECK(hybridBayesNet);
EXPECT_LONGS_EQUAL(2, hybridBayesNet->size());
EXPECT(hybridBayesNet->at(0)->frontals() == KeyVector{X(1)});
EXPECT(hybridBayesNet->at(0)->parents() == KeyVector({X(2), M(1)}));
EXPECT(hybridBayesNet->at(1)->frontals() == KeyVector{X(2)});
EXPECT(hybridBayesNet->at(1)->parents() == KeyVector({M(1)}));

auto remainingFactorGraph = incrementalHybrid.remainingFactorGraph_;
CHECK(remainingFactorGraph);
EXPECT_LONGS_EQUAL(1, remainingFactorGraph->size());
auto hybridBayesNet = incrementalHybrid.hybridBayesNet();
EXPECT_LONGS_EQUAL(2, hybridBayesNet.size());
EXPECT(hybridBayesNet.at(0)->frontals() == KeyVector{X(1)});
EXPECT(hybridBayesNet.at(0)->parents() == KeyVector({X(2), M(1)}));
EXPECT(hybridBayesNet.at(1)->frontals() == KeyVector{X(2)});
EXPECT(hybridBayesNet.at(1)->parents() == KeyVector({M(1)}));

auto remainingFactorGraph = incrementalHybrid.remainingFactorGraph();
EXPECT_LONGS_EQUAL(1, remainingFactorGraph.size());

auto discreteFactor_m1 = *dynamic_pointer_cast<DecisionTreeFactor>(
remainingFactorGraph->discreteGraph().at(0));
remainingFactorGraph.discreteGraph().at(0));
EXPECT(discreteFactor_m1.keys() == KeyVector({M(1)}));

GaussianHybridFactorGraph graph2;
Expand All @@ -89,20 +87,20 @@ TEST_UNSAFE(DCGaussianElimination, Incremental_inference) {

incrementalHybrid.update(graph2, ordering2);

auto hybridBayesNet2 = incrementalHybrid.hybridBayesNet_;
CHECK(hybridBayesNet2);
EXPECT_LONGS_EQUAL(4, hybridBayesNet2->size());
EXPECT(hybridBayesNet2->at(2)->frontals() == KeyVector{X(2)});
EXPECT(hybridBayesNet2->at(2)->parents() == KeyVector({X(3), M(2), M(1)}));
EXPECT(hybridBayesNet2->at(3)->frontals() == KeyVector{X(3)});
EXPECT(hybridBayesNet2->at(3)->parents() == KeyVector({M(2), M(1)}));
auto hybridBayesNet2 = incrementalHybrid.hybridBayesNet();

EXPECT_LONGS_EQUAL(4, hybridBayesNet2.size());
// hybridBayesNet2.print();
EXPECT(hybridBayesNet2.at(2)->frontals() == KeyVector{X(2)});
EXPECT(hybridBayesNet2.at(2)->parents() == KeyVector({X(3), M(2), M(1)}));
EXPECT(hybridBayesNet2.at(3)->frontals() == KeyVector{X(3)});
EXPECT(hybridBayesNet2.at(3)->parents() == KeyVector({M(2), M(1)}));

auto remainingFactorGraph2 = incrementalHybrid.remainingFactorGraph_;
CHECK(remainingFactorGraph2);
EXPECT_LONGS_EQUAL(1, remainingFactorGraph2->size());
auto remainingFactorGraph2 = incrementalHybrid.remainingFactorGraph();
EXPECT_LONGS_EQUAL(1, remainingFactorGraph2.size());

auto discreteFactor = dynamic_pointer_cast<DecisionTreeFactor>(
remainingFactorGraph2->discreteGraph().at(0));
remainingFactorGraph2.discreteGraph().at(0));
EXPECT(discreteFactor->keys() == KeyVector({M(2), M(1)}));

ordering.clear();
Expand All @@ -117,15 +115,15 @@ TEST_UNSAFE(DCGaussianElimination, Incremental_inference) {
switching.linearizedFactorGraph.eliminatePartialSequential(ordering);

// The densities on X(1) should be the same
EXPECT(assert_equal(*(hybridBayesNet->atGaussian(0)),
EXPECT(assert_equal(*(hybridBayesNet.atGaussian(0)),
*(expectedHybridBayesNet->atGaussian(0))));

// The densities on X(2) should be the same
EXPECT(assert_equal(*(hybridBayesNet2->atGaussian(2)),
EXPECT(assert_equal(*(hybridBayesNet2.atGaussian(2)),
*(expectedHybridBayesNet->atGaussian(1))));

// The densities on X(3) should be the same
EXPECT(assert_equal(*(hybridBayesNet2->atGaussian(3)),
EXPECT(assert_equal(*(hybridBayesNet2.atGaussian(3)),
*(expectedHybridBayesNet->atGaussian(2))));

// we only do the manual continuous elimination for 0,0
Expand All @@ -140,7 +138,7 @@ TEST_UNSAFE(DCGaussianElimination, Incremental_inference) {
dynamic_pointer_cast<DCGaussianMixtureFactor>(graph2.dcGraph().at(0));
gf.add(dcMixture->factors()(m00));
auto x2_mixed =
boost::dynamic_pointer_cast<GaussianMixture>(hybridBayesNet->at(1));
boost::dynamic_pointer_cast<GaussianMixture>(hybridBayesNet.at(1));
gf.add(x2_mixed->factors()(m00));
auto result_gf = gf.eliminateSequential();
return gf.probPrime(result_gf->optimize());
Expand Down Expand Up @@ -226,12 +224,11 @@ TEST(DCGaussianElimination, Approx_inference) {
1 1 0 Leaf 0.611 *
1 1 1 Leaf 1 *
*/
auto remainingFactorGraph = incrementalHybrid.remainingFactorGraph_;
CHECK(remainingFactorGraph);
EXPECT_LONGS_EQUAL(1, remainingFactorGraph->size());
auto remainingFactorGraph = incrementalHybrid.remainingFactorGraph();
EXPECT_LONGS_EQUAL(1, remainingFactorGraph.size());

auto discreteFactor_m1 = *dynamic_pointer_cast<DecisionTreeFactor>(
remainingFactorGraph->discreteGraph().at(0));
remainingFactorGraph.discreteGraph().at(0));
EXPECT(discreteFactor_m1.keys() == KeyVector({M(3), M(2), M(1)}));

// Check number of elements equal to zero
Expand All @@ -246,16 +243,15 @@ TEST(DCGaussianElimination, Approx_inference) {
* factor 2: [x3 | x4 m3 m2 m1 ], 8 components
* factor 3: [x4 | m3 m2 m1 ], 8 components
*/
auto hybridBayesNet = incrementalHybrid.hybridBayesNet_;
auto hybridBayesNet = incrementalHybrid.hybridBayesNet();

CHECK(hybridBayesNet);
EXPECT_LONGS_EQUAL(4, hybridBayesNet->size());
EXPECT_LONGS_EQUAL(2, hybridBayesNet->atGaussian(0)->nrComponents());
EXPECT_LONGS_EQUAL(4, hybridBayesNet->atGaussian(1)->nrComponents());
EXPECT_LONGS_EQUAL(8, hybridBayesNet->atGaussian(2)->nrComponents());
EXPECT_LONGS_EQUAL(5, hybridBayesNet->atGaussian(3)->nrComponents());
EXPECT_LONGS_EQUAL(4, hybridBayesNet.size());
EXPECT_LONGS_EQUAL(2, hybridBayesNet.atGaussian(0)->nrComponents());
EXPECT_LONGS_EQUAL(4, hybridBayesNet.atGaussian(1)->nrComponents());
EXPECT_LONGS_EQUAL(8, hybridBayesNet.atGaussian(2)->nrComponents());
EXPECT_LONGS_EQUAL(5, hybridBayesNet.atGaussian(3)->nrComponents());

auto &lastDensity = *(hybridBayesNet->atGaussian(3));
auto &lastDensity = *(hybridBayesNet.atGaussian(3));
auto &unprunedLastDensity = *(unprunedHybridBayesNet->atGaussian(3));
std::vector<std::pair<DiscreteValues, double>> assignments =
discreteFactor_m1.enumerate();
Expand Down Expand Up @@ -302,7 +298,7 @@ TEST_UNSAFE(DCGaussianElimination, Incremental_approximate) {
size_t maxComponents = 5;
incrementalHybrid.update(graph1, ordering, maxComponents);

auto &actualBayesNet1 = *incrementalHybrid.hybridBayesNet_;
auto actualBayesNet1 = incrementalHybrid.hybridBayesNet();
CHECK_EQUAL(4, actualBayesNet1.size());
EXPECT_LONGS_EQUAL(2, actualBayesNet1.atGaussian(0)->nrComponents());
EXPECT_LONGS_EQUAL(4, actualBayesNet1.atGaussian(1)->nrComponents());
Expand All @@ -319,7 +315,7 @@ TEST_UNSAFE(DCGaussianElimination, Incremental_approximate) {

incrementalHybrid.update(graph2, ordering2, maxComponents);

auto &actualBayesNet = *incrementalHybrid.hybridBayesNet_;
auto actualBayesNet = incrementalHybrid.hybridBayesNet();
CHECK_EQUAL(2, actualBayesNet.size());
EXPECT_LONGS_EQUAL(10, actualBayesNet.atGaussian(0)->nrComponents());
EXPECT_LONGS_EQUAL(5, actualBayesNet.atGaussian(1)->nrComponents());
Expand Down

0 comments on commit f5bb790

Please sign in to comment.