Skip to content

Commit

Permalink
Merge pull request #17 from varunagrawal/feature/approximate_discrete
Browse files Browse the repository at this point in the history
  • Loading branch information
varunagrawal authored Feb 9, 2022
2 parents c169ba7 + 34f9ab9 commit 25cacfe
Show file tree
Hide file tree
Showing 20 changed files with 468 additions and 130 deletions.
5 changes: 1 addition & 4 deletions gtsam/discrete/DecisionTree-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -266,10 +266,7 @@ namespace gtsam {
}

os << "\"" << this->id() << "\" -> \"" << branch->id() << "\"";
if (B == 2) {
if (i == 0) os << " [style=dashed]";
if (i > 1) os << " [style=bold]";
}
if (B == 2 && i == 0) os << " [style=dashed]";
os << std::endl;
branch->dot(os, labelFormatter, valueFormatter, showZero);
}
Expand Down
36 changes: 34 additions & 2 deletions gtsam/discrete/tests/testDiscreteBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/DiscreteMarginals.h>
#include <gtsam/base/debug.h>
#include <gtsam/inference/Symbol.h>
#include <gtsam/base/Testable.h>
#include <gtsam/base/Vector.h>

Expand All @@ -37,6 +38,8 @@ using namespace boost::assign;

using namespace std;
using namespace gtsam;
using symbol_shorthand::X;
using symbol_shorthand::M;

static const DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2),
LungCancer(6, 2), Bronchitis(7, 2), Either(5, 2), XRay(2, 2), Dyspnea(1, 2);
Expand Down Expand Up @@ -102,7 +105,7 @@ TEST(DiscreteBayesNet, Asia) {
// Create solver and eliminate
Ordering ordering;
ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7);
DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering);
auto chordal = fg.eliminateSequential(ordering);
DiscreteConditional expected2(Bronchitis % "11/9");
EXPECT(assert_equal(expected2, *chordal->back()));

Expand All @@ -111,7 +114,7 @@ TEST(DiscreteBayesNet, Asia) {
fg.add(Dyspnea, "0 1");

// solve again, now with evidence
DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential(ordering);
auto chordal2 = fg.eliminateSequential(ordering);
EXPECT(assert_equal(expected2, *chordal->back()));

// now sample from it
Expand Down Expand Up @@ -167,6 +170,35 @@ TEST(DiscreteBayesNet, Dot) {
"}");
}

/* ************************************************************************* */
TEST(DiscreteBayesTree, Switching) {
size_t nrStates = 3;
size_t K = 5;

DiscreteBayesNet bayesNet;

// Add "motion models".
for (size_t k = 1; k < K; k++) {
DiscreteKey key(X(k), nrStates), key_plus(X(k + 1), nrStates),
mode(M(k), 2);
bayesNet.add(DiscreteConditional(key_plus, {key, mode},
"1/1/1 1/2/1 3/2/3 1/1/1 1/2/1 3/2/3"));
}

// Add "mode chain"
for (size_t k = 1; k < K - 1; k++) {
DiscreteKey mode(M(k), 2), mode_plus(M(k + 1), 2);
bayesNet.add(DiscreteConditional(mode_plus, {mode}, "1/2 3/2"));
}

// eliminate: because D>C, discrete keys get eliminated last:
Ordering ordering;
for (size_t k = 1; k <= K; k++) ordering += X(k);
for (size_t k = 1; k < K; k++) ordering += M(k);
auto chordal = DiscreteFactorGraph(bayesNet).eliminateSequential(ordering);
GTSAM_PRINT(*chordal);
}

/* ************************************************************************* */
// Check markdown representation looks as expected.
TEST(DiscreteBayesNet, markdown) {
Expand Down
16 changes: 11 additions & 5 deletions gtsam/inference/AbstractConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,26 @@

/**
* @file AbstractConditional.cpp
* @brief Concrete base class for conditional densities
* @brief Abstract base class for conditional densities
* @author Fan Jiang
*/


#include <gtsam/inference/AbstractConditional.h>

namespace gtsam {
/* ************************************************************************* */
void AbstractConditional::print(const std::string &s,
const KeyFormatter &formatter) const {
throw std::runtime_error("AbstractConditional::print not implemented!");
}

/* ************************************************************************* */
bool AbstractConditional::equals(const AbstractConditional &c,
double tol) const {
throw std::invalid_argument("You are calling the base AbstractConditional's"
" equality, which is illegal.");
throw std::invalid_argument(
"You are calling the base AbstractConditional's"
" equality, which is illegal.");
return nrFrontals_ == c.nrFrontals_;
}

}
} // namespace gtsam
13 changes: 6 additions & 7 deletions gtsam/inference/AbstractConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,17 @@

/**
* @file AbstractConditional.h
* @brief Concrete base class for conditional densities
* @brief Abstract base class for conditional densities
* @author Fan Jiang
*/

// \callgraph
#pragma once

#include <boost/range.hpp>

#include <gtsam/inference/Key.h>

#include <boost/range.hpp>

namespace gtsam {

class GTSAM_EXPORT AbstractConditional {
Expand Down Expand Up @@ -58,8 +58,8 @@ class GTSAM_EXPORT AbstractConditional {
/// @{

/** print with optional formatter */
virtual void print(const std::string &s = "Conditional",
const KeyFormatter &formatter = DefaultKeyFormatter) const = 0;
virtual void print(const std::string &s = "AbstractConditional",
const KeyFormatter &formatter = DefaultKeyFormatter) const;

/** check equality */
bool equals(const AbstractConditional &c, double tol = 1e-9) const;
Expand All @@ -78,7 +78,6 @@ class GTSAM_EXPORT AbstractConditional {
/** return the number of parents */
virtual size_t nrParents() const = 0;


/** return a view of the frontal keys */
virtual Frontals frontals() const = 0;

Expand All @@ -90,4 +89,4 @@ class GTSAM_EXPORT AbstractConditional {
template <>
struct traits<AbstractConditional> : public Testable<AbstractConditional> {};

}
} // namespace gtsam
55 changes: 40 additions & 15 deletions gtsam/linear/GaussianBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
using namespace std;
using namespace gtsam;

// In Wrappers we have no access to this so have a default ready
static std::mt19937_64 kRandomNumberGenerator(42);

namespace gtsam {

// Instantiate base class
Expand All @@ -37,28 +40,50 @@ namespace gtsam {
return Base::equals(bn, tol);
}

/* ************************************************************************* */
VectorValues GaussianBayesNet::optimize() const
{
VectorValues soln; // no missing variables -> just create an empty vector
return optimize(soln);
/* ************************************************************************ */
VectorValues GaussianBayesNet::optimize() const {
VectorValues solution; // no missing variables -> create an empty vector
return optimize(solution);
}

/* ************************************************************************* */
VectorValues GaussianBayesNet::optimize(
const VectorValues& solutionForMissing) const {
VectorValues soln(solutionForMissing); // possibly empty
VectorValues GaussianBayesNet::optimize(VectorValues solution) const {
// (R*x)./sigmas = y by solving x=inv(R)*(y.*sigmas)
/** solve each node in turn in topological sort order (parents first)*/
for (auto cg: boost::adaptors::reverse(*this)) {
// solve each node in reverse topological sort order (parents first)
for (auto cg : boost::adaptors::reverse(*this)) {
// i^th part of R*x=y, x=inv(R)*y
// (Rii*xi + R_i*x(i+1:))./si = yi <-> xi = inv(Rii)*(yi.*si - R_i*x(i+1:))
soln.insert(cg->solve(soln));
// (Rii*xi + R_i*x(i+1:))./si = yi =>
// xi = inv(Rii)*(yi.*si - R_i*x(i+1:))
solution.insert(cg->solve(solution));
}
return soln;
return solution;
}

/* ************************************************************************* */
/* ************************************************************************ */
VectorValues GaussianBayesNet::sample(std::mt19937_64* rng) const {
VectorValues result; // no missing variables -> create an empty vector
return sample(result, rng);
}

VectorValues GaussianBayesNet::sample(VectorValues result,
std::mt19937_64* rng) const {
// sample each node in reverse topological sort order (parents first)
for (auto cg : boost::adaptors::reverse(*this)) {
const VectorValues sampled = cg->sample(result, rng);
result.insert(sampled);
}
return result;
}

/* ************************************************************************ */
VectorValues GaussianBayesNet::sample() const {
return sample(&kRandomNumberGenerator);
}

VectorValues GaussianBayesNet::sample(VectorValues given) const {
return sample(given, &kRandomNumberGenerator);
}

/* ************************************************************************ */
VectorValues GaussianBayesNet::optimizeGradientSearch() const
{
gttic(GaussianBayesTree_optimizeGradientSearch);
Expand Down
30 changes: 27 additions & 3 deletions gtsam/linear/GaussianBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,35 @@ namespace gtsam {
/// @name Standard Interface
/// @{

/// Solve the GaussianBayesNet, i.e. return \f$ x = R^{-1}*d \f$, by back-substitution
/// Solve the GaussianBayesNet, i.e. return \f$ x = R^{-1}*d \f$, by
/// back-substitution
VectorValues optimize() const;

/// Version of optimize for incomplete BayesNet, needs solution for missing variables
VectorValues optimize(const VectorValues& solutionForMissing) const;
/// Version of optimize for incomplete BayesNet, given missing variables
VectorValues optimize(const VectorValues given) const;

/**
* Sample using ancestral sampling
* Example:
* std::mt19937_64 rng(42);
* auto sample = gbn.sample(&rng);
*/
VectorValues sample(std::mt19937_64* rng) const;

/**
* Sample from an incomplete BayesNet, given missing variables
* Example:
* std::mt19937_64 rng(42);
* VectorValues given = ...;
* auto sample = gbn.sample(given, &rng);
*/
VectorValues sample(VectorValues given, std::mt19937_64* rng) const;

/// Sample using ancestral sampling, use default rng
VectorValues sample() const;

/// Sample from an incomplete BayesNet, use default rng
VectorValues sample(VectorValues given) const;

/**
* Return ordering corresponding to a topological sort.
Expand Down
Loading

0 comments on commit 25cacfe

Please sign in to comment.