Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch to EliminateDiscrete for max-product #1362

Merged
merged 52 commits into from
Jan 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
efd8eb1
Switch to EliminateDiscrete for max-product
dellaert Dec 31, 2022
dcb07fe
Test eliminate
dellaert Dec 31, 2022
4d3bbf6
HBN::evaluate
dellaert Dec 28, 2022
b772d67
refactoring variables for clarity
dellaert Dec 28, 2022
be8008e
Also print mean if no parents
dellaert Dec 31, 2022
f6f782a
Add static
dellaert Dec 31, 2022
143022c
Tiny Bayes net example
dellaert Dec 31, 2022
df4fb13
fix comment
dellaert Dec 31, 2022
4023e71
continuousSubset
dellaert Dec 31, 2022
c8008cb
tiny FG
dellaert Dec 31, 2022
b463386
Made SumFrontals a method to test
dellaert Dec 31, 2022
92e2a39
Added factor and constant and removed factors method
dellaert Jan 1, 2023
fa76d53
refactored and documented SumFrontals
dellaert Jan 1, 2023
039c9b9
Test SumFrontals
dellaert Jan 1, 2023
526da2c
Add Testable to GraphAndConstant
dellaert Jan 1, 2023
b094953
Fix compile issues after rebase
dellaert Jan 1, 2023
4cb03b3
Fix SumFrontals test
dellaert Jan 1, 2023
7ab4c3e
Change to real test
dellaert Jan 1, 2023
6483130
Print estimated marginals and ratios!
dellaert Jan 1, 2023
dbd9faf
Fix quality testing
dellaert Jan 1, 2023
3d821ec
Now test elimination in c++
dellaert Jan 1, 2023
0095f73
attempt to fix elimination
dellaert Jan 1, 2023
665cb29
Make testcase exactly 5.0 mean
dellaert Jan 1, 2023
2c7b3a2
Refactoring in elimination
dellaert Jan 1, 2023
4d313fa
Comment on constant
dellaert Jan 1, 2023
064f17b
Added two-measurement example
dellaert Jan 2, 2023
312ba5f
Synced two examples
dellaert Jan 2, 2023
7c27061
Added missing methods
dellaert Jan 2, 2023
bd8d2ea
Added error for all versions - should become logDiensity?
dellaert Jan 2, 2023
021ee1a
Deterministic example, much more generic importance sampler
dellaert Jan 2, 2023
fbfc20b
Fixed conversion arguments
dellaert Jan 2, 2023
06aed53
rename
dellaert Jan 2, 2023
f8d75ab
name change of Sum to GaussianFactorGraphTree and SumFrontals to asse…
dellaert Jan 2, 2023
12d02be
Right marginals for tiny1
dellaert Jan 2, 2023
797ac34
Same correct error with factor_z.error()
dellaert Jan 2, 2023
625977e
Example with 2 measurements agrees with importance sampling
dellaert Jan 2, 2023
c3f0469
Add mean to test
dellaert Jan 2, 2023
f726cf6
f(x0, x1, m0; z0, z1) now has constant ratios !
dellaert Jan 2, 2023
66b846f
Merge branch 'hybrid/elimination' into hybrid/test_with_evaluate
varunagrawal Jan 3, 2023
38f3209
fix GaussianConditional print test
varunagrawal Jan 3, 2023
195dddf
clean up HybridGaussianFactorGraph
varunagrawal Jan 3, 2023
47346c5
move GraphAndConstant traits definition to HybridFactor
varunagrawal Jan 3, 2023
ca1c517
remove extra print statements
varunagrawal Jan 3, 2023
7825ffd
fix tests due to change to EliminateDiscrete
varunagrawal Jan 3, 2023
f117da2
remove extra print
varunagrawal Jan 3, 2023
cb885fb
check for nullptr in HybridConditional::equals
varunagrawal Jan 3, 2023
46acba5
serialize inner_, need to test
varunagrawal Jan 3, 2023
41c73fd
comment out failing tests, need to serialize DecisionTree
varunagrawal Jan 3, 2023
e01f7e7
kill unnecessary method
varunagrawal Jan 3, 2023
9e7fcc8
make header functions as inline
varunagrawal Jan 3, 2023
3771d63
simplify HybridConditional equality check
varunagrawal Jan 3, 2023
385ae34
Merge pull request #1363 from borglab/hybrid/test_with_evaluate-2
varunagrawal Jan 3, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 20 additions & 8 deletions gtsam/hybrid/GaussianMixture.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,28 +51,28 @@ GaussianMixture::GaussianMixture(
Conditionals(discreteParents, conditionalsList)) {}

/* *******************************************************************************/
GaussianMixture::Sum GaussianMixture::add(
const GaussianMixture::Sum &sum) const {
using Y = GaussianMixtureFactor::GraphAndConstant;
GaussianFactorGraphTree GaussianMixture::add(
const GaussianFactorGraphTree &sum) const {
using Y = GraphAndConstant;
auto add = [](const Y &graph1, const Y &graph2) {
auto result = graph1.graph;
result.push_back(graph2.graph);
return Y(result, graph1.constant + graph2.constant);
};
const Sum tree = asGaussianFactorGraphTree();
const auto tree = asGaussianFactorGraphTree();
return sum.empty() ? tree : sum.apply(tree, add);
}

/* *******************************************************************************/
GaussianMixture::Sum GaussianMixture::asGaussianFactorGraphTree() const {
GaussianFactorGraphTree GaussianMixture::asGaussianFactorGraphTree() const {
auto lambda = [](const GaussianConditional::shared_ptr &conditional) {
GaussianFactorGraph result;
result.push_back(conditional);
if (conditional) {
return GaussianMixtureFactor::GraphAndConstant(
return GraphAndConstant(
result, conditional->logNormalizationConstant());
} else {
return GaussianMixtureFactor::GraphAndConstant(result, 0.0);
return GraphAndConstant(result, 0.0);
}
};
return {conditionals_, lambda};
Expand Down Expand Up @@ -103,7 +103,19 @@ GaussianConditional::shared_ptr GaussianMixture::operator()(
/* *******************************************************************************/
bool GaussianMixture::equals(const HybridFactor &lf, double tol) const {
const This *e = dynamic_cast<const This *>(&lf);
return e != nullptr && BaseFactor::equals(*e, tol);
if (e == nullptr) return false;

// This will return false if either conditionals_ is empty or e->conditionals_
// is empty, but not if both are empty or both are not empty:
if (conditionals_.empty() ^ e->conditionals_.empty()) return false;

// Check the base and the factors:
return BaseFactor::equals(*e, tol) &&
conditionals_.equals(e->conditionals_,
[tol](const GaussianConditional::shared_ptr &f1,
const GaussianConditional::shared_ptr &f2) {
return f1->equals(*(f2), tol);
});
}

/* *******************************************************************************/
Expand Down
19 changes: 13 additions & 6 deletions gtsam/hybrid/GaussianMixture.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,6 @@ class GTSAM_EXPORT GaussianMixture
using BaseFactor = HybridFactor;
using BaseConditional = Conditional<HybridFactor, GaussianMixture>;

/// Alias for DecisionTree of GaussianFactorGraphs
using Sum = DecisionTree<Key, GaussianMixtureFactor::GraphAndConstant>;

/// typedef for Decision Tree of Gaussian Conditionals
using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>;

Expand All @@ -71,7 +68,7 @@ class GTSAM_EXPORT GaussianMixture
/**
* @brief Convert a DecisionTree of factors into a DT of Gaussian FGs.
*/
Sum asGaussianFactorGraphTree() const;
GaussianFactorGraphTree asGaussianFactorGraphTree() const;

/**
* @brief Helper function to get the pruner functor.
Expand Down Expand Up @@ -172,6 +169,16 @@ class GTSAM_EXPORT GaussianMixture
*/
double error(const HybridValues &values) const override;

// /// Calculate probability density for given values `x`.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we keeping these for the future? My intuition says yes, but this would probably move to a base class.

// double evaluate(const HybridValues &values) const;

// /// Evaluate probability density, sugar.
// double operator()(const HybridValues &values) const { return
// evaluate(values); }

// /// Calculate log-density for given values `x`.
// double logDensity(const HybridValues &values) const;

/**
* @brief Prune the decision tree of Gaussian factors as per the discrete
* `decisionTree`.
Expand All @@ -186,9 +193,9 @@ class GTSAM_EXPORT GaussianMixture
* maintaining the decision tree structure.
*
* @param sum Decision Tree of Gaussian Factor Graphs
* @return Sum
* @return GaussianFactorGraphTree
*/
Sum add(const Sum &sum) const;
GaussianFactorGraphTree add(const GaussianFactorGraphTree &sum) const;
/// @}
};

Expand Down
24 changes: 14 additions & 10 deletions gtsam/hybrid/GaussianMixtureFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,32 +81,36 @@ void GaussianMixtureFactor::print(const std::string &s,
}

/* *******************************************************************************/
const GaussianMixtureFactor::Mixture GaussianMixtureFactor::factors() const {
return Mixture(factors_, [](const FactorAndConstant &factor_z) {
return factor_z.factor;
});
GaussianFactor::shared_ptr GaussianMixtureFactor::factor(
const DiscreteValues &assignment) const {
return factors_(assignment).factor;
}

/* *******************************************************************************/
GaussianMixtureFactor::Sum GaussianMixtureFactor::add(
const GaussianMixtureFactor::Sum &sum) const {
using Y = GaussianMixtureFactor::GraphAndConstant;
double GaussianMixtureFactor::constant(const DiscreteValues &assignment) const {
return factors_(assignment).constant;
}

/* *******************************************************************************/
GaussianFactorGraphTree GaussianMixtureFactor::add(
const GaussianFactorGraphTree &sum) const {
using Y = GraphAndConstant;
auto add = [](const Y &graph1, const Y &graph2) {
auto result = graph1.graph;
result.push_back(graph2.graph);
return Y(result, graph1.constant + graph2.constant);
};
const Sum tree = asGaussianFactorGraphTree();
const auto tree = asGaussianFactorGraphTree();
return sum.empty() ? tree : sum.apply(tree, add);
}

/* *******************************************************************************/
GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree()
GaussianFactorGraphTree GaussianMixtureFactor::asGaussianFactorGraphTree()
const {
auto wrap = [](const FactorAndConstant &factor_z) {
GaussianFactorGraph result;
result.push_back(factor_z.factor);
return GaussianMixtureFactor::GraphAndConstant(result, factor_z.constant);
return GraphAndConstant(result, factor_z.constant);
};
return {factors_, wrap};
}
Expand Down
34 changes: 12 additions & 22 deletions gtsam/hybrid/GaussianMixtureFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
// Note: constant is log of normalization constant for probabilities.
// Errors is the negative log-likelihood,
// hence we subtract the constant here.
if (!factor) return 0.0; // If nullptr, return 0.0 error
return factor->error(values) - constant;
}

Expand All @@ -71,22 +72,6 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
}
};

/// Gaussian factor graph and log of normalizing constant.
struct GraphAndConstant {
GaussianFactorGraph graph;
double constant;

GraphAndConstant(const GaussianFactorGraph &graph, double constant)
: graph(graph), constant(constant) {}

// Check pointer equality.
bool operator==(const GraphAndConstant &other) const {
return graph == other.graph && constant == other.constant;
}
};

using Sum = DecisionTree<Key, GraphAndConstant>;

/// typedef for Decision Tree of Gaussian factors and log-constant.
using Factors = DecisionTree<Key, FactorAndConstant>;
using Mixture = DecisionTree<Key, sharedFactor>;
Expand All @@ -99,9 +84,9 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
* @brief Helper function to return factors and functional to create a
* DecisionTree of Gaussian Factor Graphs.
*
* @return Sum (DecisionTree<Key, GaussianFactorGraph>)
* @return GaussianFactorGraphTree
*/
Sum asGaussianFactorGraphTree() const;
GaussianFactorGraphTree asGaussianFactorGraphTree() const;

public:
/// @name Constructors
Expand Down Expand Up @@ -151,12 +136,16 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
void print(
const std::string &s = "GaussianMixtureFactor\n",
const KeyFormatter &formatter = DefaultKeyFormatter) const override;

/// @}
/// @name Standard API
/// @{

/// Getter for the underlying Gaussian Factor Decision Tree.
const Mixture factors() const;
/// Get factor at a given discrete assignment.
sharedFactor factor(const DiscreteValues &assignment) const;

/// Get constant at a given discrete assignment.
double constant(const DiscreteValues &assignment) const;

/**
* @brief Combine the Gaussian Factor Graphs in `sum` and `this` while
Expand All @@ -166,7 +155,7 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
* variables.
* @return Sum
*/
Sum add(const Sum &sum) const;
GaussianFactorGraphTree add(const GaussianFactorGraphTree &sum) const;

/**
* @brief Compute error of the GaussianMixtureFactor as a tree.
Expand All @@ -184,7 +173,8 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
double error(const HybridValues &values) const override;

/// Add MixtureFactor to a Sum, syntactic sugar.
friend Sum &operator+=(Sum &sum, const GaussianMixtureFactor &factor) {
friend GaussianFactorGraphTree &operator+=(
GaussianFactorGraphTree &sum, const GaussianMixtureFactor &factor) {
sum = factor.add(sum);
return sum;
}
Expand Down
14 changes: 14 additions & 0 deletions gtsam/hybrid/HybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,17 @@ static std::mt19937_64 kRandomNumberGenerator(42);

namespace gtsam {

/* ************************************************************************* */
void HybridBayesNet::print(const std::string &s,
const KeyFormatter &formatter) const {
Base::print(s, formatter);
}

/* ************************************************************************* */
bool HybridBayesNet::equals(const This &bn, double tol) const {
return Base::equals(bn, tol);
}

/* ************************************************************************* */
DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
AlgebraicDecisionTree<Key> decisionTree;
Expand Down Expand Up @@ -271,12 +282,15 @@ double HybridBayesNet::evaluate(const HybridValues &values) const {

// Iterate over each conditional.
for (auto &&conditional : *this) {
// TODO: should be delegated to derived classes.
if (auto gm = conditional->asMixture()) {
const auto component = (*gm)(discreteValues);
logDensity += component->logDensity(continuousValues);

} else if (auto gc = conditional->asGaussian()) {
// If continuous only, evaluate the probability and multiply.
logDensity += gc->logDensity(continuousValues);

} else if (auto dc = conditional->asDiscrete()) {
// Conditional is discrete-only, so return its probability.
probability *= dc->operator()(discreteValues);
Expand Down
14 changes: 5 additions & 9 deletions gtsam/hybrid/HybridBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,18 +50,14 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
/// @name Testable
/// @{

/** Check equality */
bool equals(const This &bn, double tol = 1e-9) const {
return Base::equals(bn, tol);
}

/// print graph
/// GTSAM-style printing
void print(
const std::string &s = "",
const KeyFormatter &formatter = DefaultKeyFormatter) const override {
Base::print(s, formatter);
}
const KeyFormatter &formatter = DefaultKeyFormatter) const override;

/// GTSAM-style equals
bool equals(const This& fg, double tol = 1e-9) const;

/// @}
/// @name Standard Interface
/// @{
Expand Down
34 changes: 33 additions & 1 deletion gtsam/hybrid/HybridConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/inference/Conditional-inst.h>
#include <gtsam/inference/Key.h>

Expand Down Expand Up @@ -102,7 +103,38 @@ void HybridConditional::print(const std::string &s,
/* ************************************************************************ */
bool HybridConditional::equals(const HybridFactor &other, double tol) const {
const This *e = dynamic_cast<const This *>(&other);
return e != nullptr && BaseFactor::equals(*e, tol);
if (e == nullptr) return false;
varunagrawal marked this conversation as resolved.
Show resolved Hide resolved
if (auto gm = asMixture()) {
auto other = e->asMixture();
return other != nullptr && gm->equals(*other, tol);
}
if (auto gc = asGaussian()) {
auto other = e->asGaussian();
return other != nullptr && gc->equals(*other, tol);
}
if (auto dc = asDiscrete()) {
auto other = e->asDiscrete();
return other != nullptr && dc->equals(*other, tol);
}
return inner_->equals(*(e->inner_), tol);

return inner_ ? (e->inner_ ? inner_->equals(*(e->inner_), tol) : false)
: !(e->inner_);
}

/* ************************************************************************ */
double HybridConditional::error(const HybridValues &values) const {
if (auto gm = asMixture()) {
return gm->error(values);
}
if (auto gc = asGaussian()) {
return gc->error(values.continuous());
}
if (auto dc = asDiscrete()) {
return -log((*dc)(values.discrete()));
}
throw std::runtime_error(
"HybridConditional::error: conditional type not handled");
}

} // namespace gtsam
11 changes: 2 additions & 9 deletions gtsam/hybrid/HybridConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,15 +176,7 @@ class GTSAM_EXPORT HybridConditional
boost::shared_ptr<Factor> inner() const { return inner_; }

/// Return the error of the underlying conditional.
/// Currently only implemented for Gaussian mixture.
double error(const HybridValues& values) const override {
if (auto gm = asMixture()) {
return gm->error(values);
} else {
throw std::runtime_error(
"HybridConditional::error: only implemented for Gaussian mixture");
}
}
double error(const HybridValues& values) const override;

/// @}

Expand All @@ -195,6 +187,7 @@ class GTSAM_EXPORT HybridConditional
void serialize(Archive& ar, const unsigned int /*version*/) {
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor);
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional);
ar& BOOST_SERIALIZATION_NVP(inner_);
}

}; // HybridConditional
Expand Down
Loading