Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hybrid Improvements - II #1294

Merged
merged 7 commits into from
Oct 3, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions gtsam/hybrid/HybridBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
HybridConditional(boost::make_shared<DiscreteConditional>(key, table)));
}

using Base::push_back;

/// Get a specific Gaussian mixture by index `i`.
GaussianMixture::shared_ptr atMixture(size_t i) const;

Expand Down
22 changes: 22 additions & 0 deletions gtsam/hybrid/HybridFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,28 @@ class HybridFactorGraph : public FactorGraph<HybridFactor> {
push_hybrid(p);
}
}

/// Get all the discrete keys in the factor graph.
const KeySet allDiscreteKeys() const {
KeySet discrete_keys;
for (auto& factor : factors_) {
for (const DiscreteKey& k : factor->discreteKeys()) {
discrete_keys.insert(k.first);
}
}
return discrete_keys;
}

/// Get all the continuous keys in the factor graph.
const KeySet allContinuousKeys() const {
KeySet keys;
for (auto& factor : factors_) {
for (const Key& key : factor->continuousKeys()) {
keys.insert(key);
}
}
return keys;
}
};

} // namespace gtsam
24 changes: 1 addition & 23 deletions gtsam/hybrid/HybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -404,31 +404,9 @@ void HybridGaussianFactorGraph::add(DecisionTreeFactor::shared_ptr factor) {
FactorGraph::add(boost::make_shared<HybridDiscreteFactor>(factor));
}

/* ************************************************************************ */
const KeySet HybridGaussianFactorGraph::getDiscreteKeys() const {
KeySet discrete_keys;
for (auto &factor : factors_) {
for (const DiscreteKey &k : factor->discreteKeys()) {
discrete_keys.insert(k.first);
}
}
return discrete_keys;
}

/* ************************************************************************ */
const KeySet HybridGaussianFactorGraph::getContinuousKeys() const {
KeySet keys;
for (auto &factor : factors_) {
for (const Key &key : factor->continuousKeys()) {
keys.insert(key);
}
}
return keys;
}

/* ************************************************************************ */
const Ordering HybridGaussianFactorGraph::getHybridOrdering() const {
KeySet discrete_keys = getDiscreteKeys();
KeySet discrete_keys = allDiscreteKeys();
for (auto &factor : factors_) {
for (const DiscreteKey &k : factor->discreteKeys()) {
discrete_keys.insert(k.first);
Expand Down
6 changes: 0 additions & 6 deletions gtsam/hybrid/HybridGaussianFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,12 +161,6 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
}
}

/// Get all the discrete keys in the factor graph.
const KeySet getDiscreteKeys() const;

/// Get all the continuous keys in the factor graph.
const KeySet getContinuousKeys() const;

/**
* @brief Return a Colamd constrained ordering where the discrete keys are
* eliminated after the continuous keys.
Expand Down
6 changes: 6 additions & 0 deletions gtsam/hybrid/HybridGaussianISAM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,23 +62,29 @@ void HybridGaussianISAM::updateInternal(
for (const sharedClique& orphan : *orphans)
factors += boost::make_shared<BayesTreeOrphanWrapper<Node> >(orphan);

// Get all the discrete keys from the factors
KeySet allDiscrete;
for (auto& factor : factors) {
for (auto& k : factor->discreteKeys()) {
allDiscrete.insert(k.first);
}
}

// Create KeyVector with continuous keys followed by discrete keys.
KeyVector newKeysDiscreteLast;
// Insert continuous keys first.
for (auto& k : newFactorKeys) {
if (!allDiscrete.exists(k)) {
newKeysDiscreteLast.push_back(k);
}
}
// Insert discrete keys at the end
std::copy(allDiscrete.begin(), allDiscrete.end(),
std::back_inserter(newKeysDiscreteLast));

// Get an ordering where the new keys are eliminated last
const VariableIndex index(factors);

Ordering elimination_ordering;
if (ordering) {
elimination_ordering = *ordering;
Expand Down
15 changes: 15 additions & 0 deletions gtsam/hybrid/tests/testHybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,21 @@ TEST(HybridBayesNet, Creation) {
EXPECT(df.equals(expected));
}

/* ****************************************************************************/
// Test adding a bayes net to another one.
TEST(HybridBayesNet, Add) {
HybridBayesNet bayesNet;

bayesNet.add(Asia, "99/1");

DiscreteConditional expected(Asia, "99/1");

HybridBayesNet other;
other.push_back(bayesNet);
EXPECT(bayesNet.equals(other));
}


/* ****************************************************************************/
// Test choosing an assignment of conditionals
TEST(HybridBayesNet, Choose) {
Expand Down