Skip to content

Commit

Permalink
Merge pull request #1271 from borglab/feature/nonlinear-incremental
Browse files Browse the repository at this point in the history
  • Loading branch information
varunagrawal authored Aug 22, 2022
2 parents 2a974a4 + ac20cff commit 4c9c106
Show file tree
Hide file tree
Showing 5 changed files with 572 additions and 500 deletions.
27 changes: 26 additions & 1 deletion gtsam/hybrid/GaussianMixture.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,36 @@ void GaussianMixture::print(const std::string &s,
"", [&](Key k) { return formatter(k); },
[&](const GaussianConditional::shared_ptr &gf) -> std::string {
RedirectCout rd;
if (!gf->empty())
if (gf && !gf->empty())
gf->print("", formatter);
else
return {"nullptr"};
return rd.str();
});
}

/* *******************************************************************************/
void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) {
// Functional which loops over all assignments and create a set of
// GaussianConditionals
auto pruner = [&decisionTree](
const Assignment<Key> &choices,
const GaussianConditional::shared_ptr &conditional)
-> GaussianConditional::shared_ptr {
// typecast so we can use this to get probability value
DiscreteValues values(choices);

if (decisionTree(values) == 0.0) {
// empty aka null pointer
boost::shared_ptr<GaussianConditional> null;
return null;
} else {
return conditional;
}
};

auto pruned_conditionals = conditionals_.apply(pruner);
conditionals_.root_ = pruned_conditionals.root_;
}

} // namespace gtsam
12 changes: 11 additions & 1 deletion gtsam/hybrid/GaussianMixture.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

#include <gtsam/discrete/DecisionTree-inl.h>
#include <gtsam/discrete/DecisionTree.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/inference/Conditional.h>
Expand Down Expand Up @@ -121,7 +122,7 @@ class GTSAM_EXPORT GaussianMixture
/// Test equality with base HybridFactor
bool equals(const HybridFactor &lf, double tol = 1e-9) const override;

/* print utility */
/// Print utility
void print(
const std::string &s = "GaussianMixture\n",
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
Expand All @@ -131,6 +132,15 @@ class GTSAM_EXPORT GaussianMixture
/// Getter for the underlying Conditionals DecisionTree
const Conditionals &conditionals();

/**
* @brief Prune the decision tree of Gaussian factors as per the discrete
* `decisionTree`.
*
* @param decisionTree A pruned decision tree of discrete keys where the
* leaves are probabilities.
*/
void prune(const DecisionTreeFactor &decisionTree);

/**
* @brief Merge the Gaussian Factor Graphs in `this` and `sum` while
* maintaining the decision tree structure.
Expand Down
54 changes: 49 additions & 5 deletions gtsam/hybrid/HybridGaussianISAM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ HybridGaussianISAM::HybridGaussianISAM(const HybridBayesTree& bayesTree)
void HybridGaussianISAM::updateInternal(
const HybridGaussianFactorGraph& newFactors,
HybridBayesTree::Cliques* orphans,
const boost::optional<Ordering>& ordering,
const HybridBayesTree::Eliminate& function) {
// Remove the contaminated part of the Bayes tree
BayesNetType bn;
Expand Down Expand Up @@ -78,12 +79,19 @@ void HybridGaussianISAM::updateInternal(

// Get an ordering where the new keys are eliminated last
const VariableIndex index(factors);
const Ordering ordering = Ordering::ColamdConstrainedLast(
index, KeyVector(newKeysDiscreteLast.begin(), newKeysDiscreteLast.end()),
true);
Ordering elimination_ordering;
if (ordering) {
elimination_ordering = *ordering;
} else {
elimination_ordering = Ordering::ColamdConstrainedLast(
index,
KeyVector(newKeysDiscreteLast.begin(), newKeysDiscreteLast.end()),
true);
}

// eliminate all factors (top, added, orphans) into a new Bayes tree
auto bayesTree = factors.eliminateMultifrontal(ordering, function, index);
HybridBayesTree::shared_ptr bayesTree =
factors.eliminateMultifrontal(elimination_ordering, function, index);

// Re-add into Bayes tree data structures
this->roots_.insert(this->roots_.end(), bayesTree->roots().begin(),
Expand All @@ -93,9 +101,45 @@ void HybridGaussianISAM::updateInternal(

/* ************************************************************************* */
void HybridGaussianISAM::update(const HybridGaussianFactorGraph& newFactors,
const boost::optional<Ordering>& ordering,
const HybridBayesTree::Eliminate& function) {
Cliques orphans;
this->updateInternal(newFactors, &orphans, function);
this->updateInternal(newFactors, &orphans, ordering, function);
}

void HybridGaussianISAM::prune(const Key& root, const size_t maxNrLeaves) {
auto decisionTree = boost::dynamic_pointer_cast<DecisionTreeFactor>(
this->clique(root)->conditional()->inner());
DecisionTreeFactor prunedDiscreteFactor = decisionTree->prune(maxNrLeaves);
decisionTree->root_ = prunedDiscreteFactor.root_;

std::vector<gtsam::Key> prunedKeys;
for (auto&& clique : nodes()) {
// The cliques can be repeated for each frontal so we record it in
// prunedKeys and check if we have already pruned a particular clique.
if (std::find(prunedKeys.begin(), prunedKeys.end(), clique.first) !=
prunedKeys.end()) {
continue;
}

// Add all the keys of the current clique to be pruned to prunedKeys
for (auto&& key : clique.second->conditional()->frontals()) {
prunedKeys.push_back(key);
}

// Convert parents() to a KeyVector for comparison
KeyVector parents;
for (auto&& parent : clique.second->conditional()->parents()) {
parents.push_back(parent);
}

if (parents == decisionTree->keys()) {
auto gaussianMixture = boost::dynamic_pointer_cast<GaussianMixture>(
clique.second->conditional()->inner());

gaussianMixture->prune(prunedDiscreteFactor);
}
}
}

} // namespace gtsam
10 changes: 10 additions & 0 deletions gtsam/hybrid/HybridGaussianISAM.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class GTSAM_EXPORT HybridGaussianISAM : public ISAM<HybridBayesTree> {
void updateInternal(
const HybridGaussianFactorGraph& newFactors,
HybridBayesTree::Cliques* orphans,
const boost::optional<Ordering>& ordering = boost::none,
const HybridBayesTree::Eliminate& function =
HybridBayesTree::EliminationTraitsType::DefaultEliminate);

Expand All @@ -59,8 +60,17 @@ class GTSAM_EXPORT HybridGaussianISAM : public ISAM<HybridBayesTree> {
* @param function Elimination function.
*/
void update(const HybridGaussianFactorGraph& newFactors,
const boost::optional<Ordering>& ordering = boost::none,
const HybridBayesTree::Eliminate& function =
HybridBayesTree::EliminationTraitsType::DefaultEliminate);

/**
* @brief
*
* @param root The root key in the discrete conditional decision tree.
* @param maxNumberLeaves
*/
void prune(const Key& root, const size_t maxNumberLeaves);
};

/// traits
Expand Down
Loading

0 comments on commit 4c9c106

Please sign in to comment.