Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better DiscreteConditional #1037

Merged
merged 7 commits into from
Jan 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions gtsam/discrete/DecisionTreeFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ namespace gtsam {
/** Default constructor for I/O */
DecisionTreeFactor();

/** Constructor from Indices, Ordering, and AlgebraicDecisionDiagram */
/** Constructor from DiscreteKeys and AlgebraicDecisionTree */
DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials);

/** Constructor from doubles */
Expand Down Expand Up @@ -139,22 +139,22 @@ namespace gtsam {
/**
* Apply binary operator (*this) "op" f
* @param f the second argument for op
* @param op a binary operator that operates on AlgebraicDecisionDiagram potentials
* @param op a binary operator that operates on AlgebraicDecisionTree
*/
DecisionTreeFactor apply(const DecisionTreeFactor& f, ADT::Binary op) const;

/**
* Combine frontal variables using binary operator "op"
* @param nrFrontals nr. of frontal to combine variables in this factor
* @param op a binary operator that operates on AlgebraicDecisionDiagram potentials
* @param op a binary operator that operates on AlgebraicDecisionTree
* @return shared pointer to newly created DecisionTreeFactor
*/
shared_ptr combine(size_t nrFrontals, ADT::Binary op) const;

/**
* Combine frontal variables in an Ordering using binary operator "op"
* @param nrFrontals nr. of frontal to combine variables in this factor
* @param op a binary operator that operates on AlgebraicDecisionDiagram potentials
* @param op a binary operator that operates on AlgebraicDecisionTree
* @return shared pointer to newly created DecisionTreeFactor
*/
shared_ptr combine(const Ordering& keys, ADT::Binary op) const;
Expand Down
98 changes: 79 additions & 19 deletions gtsam/discrete/DiscreteConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include <string>
#include <vector>
#include <utility>
#include <set>

using namespace std;
using std::stringstream;
Expand All @@ -38,38 +39,97 @@ using std::pair;
namespace gtsam {

// Instantiate base class
template class GTSAM_EXPORT Conditional<DecisionTreeFactor, DiscreteConditional> ;
template class GTSAM_EXPORT
Conditional<DecisionTreeFactor, DiscreteConditional>;

/* ******************************************************************************** */
/* ************************************************************************** */
DiscreteConditional::DiscreteConditional(const size_t nrFrontals,
const DecisionTreeFactor& f) :
BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) {
}
const DecisionTreeFactor& f)
: BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) {}

/* ******************************************************************************** */
/* ************************************************************************** */
DiscreteConditional::DiscreteConditional(size_t nrFrontals,
const DiscreteKeys& keys,
const ADT& potentials)
: BaseFactor(keys, potentials), BaseConditional(nrFrontals) {}

/* ************************************************************************** */
DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint,
const DecisionTreeFactor& marginal) :
BaseFactor(
ISDEBUG("DiscreteConditional::COUNT") ? joint : joint / marginal), BaseConditional(
joint.size()-marginal.size()) {
if (ISDEBUG("DiscreteConditional::DiscreteConditional"))
cout << (firstFrontalKey()) << endl; //TODO Print all keys
}
const DecisionTreeFactor& marginal)
: BaseFactor(joint / marginal),
BaseConditional(joint.size() - marginal.size()) {}

/* ******************************************************************************** */
/* ************************************************************************** */
DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint,
const DecisionTreeFactor& marginal, const Ordering& orderedKeys) :
DiscreteConditional(joint, marginal) {
const DecisionTreeFactor& marginal,
const Ordering& orderedKeys)
: DiscreteConditional(joint, marginal) {
keys_.clear();
keys_.insert(keys_.end(), orderedKeys.begin(), orderedKeys.end());
}

/* ******************************************************************************** */
/* ************************************************************************** */
DiscreteConditional::DiscreteConditional(const Signature& signature)
: BaseFactor(signature.discreteKeys(), signature.cpt()),
BaseConditional(1) {}

/* ******************************************************************************** */
/* ************************************************************************** */
DiscreteConditional DiscreteConditional::operator*(
const DiscreteConditional& other) const {
// Take union of frontal keys
std::set<Key> newFrontals;
for (auto&& key : this->frontals()) newFrontals.insert(key);
for (auto&& key : other.frontals()) newFrontals.insert(key);

// Check if frontals overlapped
if (nrFrontals() + other.nrFrontals() > newFrontals.size())
throw std::invalid_argument(
"DiscreteConditional::operator* called with overlapping frontal keys.");

// Now, add cardinalities.
DiscreteKeys discreteKeys;
for (auto&& key : frontals())
discreteKeys.emplace_back(key, cardinality(key));
for (auto&& key : other.frontals())
discreteKeys.emplace_back(key, other.cardinality(key));

// Sort
std::sort(discreteKeys.begin(), discreteKeys.end());

// Add parents to set, to make them unique
std::set<DiscreteKey> parents;
for (auto&& key : this->parents())
if (!newFrontals.count(key)) parents.emplace(key, cardinality(key));
for (auto&& key : other.parents())
if (!newFrontals.count(key)) parents.emplace(key, other.cardinality(key));

// Finally, add parents to keys, in order
for (auto&& dk : parents) discreteKeys.push_back(dk);

ADT product = ADT::apply(other, ADT::Ring::mul);
return DiscreteConditional(newFrontals.size(), discreteKeys, product);
}

/* ************************************************************************** */
DiscreteConditional DiscreteConditional::marginal(Key key) const {
if (nrParents() > 0)
throw std::invalid_argument(
"DiscreteConditional::marginal: single argument version only valid for "
"fully specified joint distributions (i.e., no parents).");

// Calculate the keys as the frontal keys without the given key.
DiscreteKeys discreteKeys{{key, cardinality(key)}};

// Calculate sum
ADT adt(*this);
for (auto&& k : frontals())
if (k != key) adt = adt.sum(k, cardinality(k));

// Return new factor
return DiscreteConditional(1, discreteKeys, adt);
}

/* ************************************************************************** */
void DiscreteConditional::print(const string& s,
const KeyFormatter& formatter) const {
cout << s << " P( ";
Expand All @@ -82,7 +142,7 @@ void DiscreteConditional::print(const string& s,
cout << formatter(*it) << " ";
}
}
cout << ")";
cout << "):\n";
ADT::print("");
cout << endl;
}
Expand Down
77 changes: 37 additions & 40 deletions gtsam/discrete/DiscreteConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,21 @@ class GTSAM_EXPORT DiscreteConditional
/// @name Standard Constructors
/// @{

/** default constructor needed for serialization */
/// Default constructor needed for serialization.
DiscreteConditional() {}

/** constructor from factor */
/// Construct from factor, taking the first `nFrontals` keys as frontals.
DiscreteConditional(size_t nFrontals, const DecisionTreeFactor& f);

/**
* Construct from DiscreteKeys and AlgebraicDecisionTree, taking the first
* `nFrontals` keys as frontals, in the order given.
*/
DiscreteConditional(size_t nFrontals, const DiscreteKeys& keys,
const ADT& potentials);

/** Construct from signature */
DiscreteConditional(const Signature& signature);
explicit DiscreteConditional(const Signature& signature);

/**
* Construct from key, parents, and a Signature::Table specifying the
Expand Down Expand Up @@ -86,27 +93,41 @@ class GTSAM_EXPORT DiscreteConditional
DiscreteConditional(const DiscreteKey& key, const std::string& spec)
: DiscreteConditional(Signature(key, {}, spec)) {}

/** construct P(X|Y)=P(X,Y)/P(Y) from P(X,Y) and P(Y) */
/**
* @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y)
* Assumes but *does not check* that f(Y)=sum_X f(X,Y).
*/
DiscreteConditional(const DecisionTreeFactor& joint,
const DecisionTreeFactor& marginal);

/** construct P(X|Y)=P(X,Y)/P(Y) from P(X,Y) and P(Y) */
/**
* @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y)
* Assumes but *does not check* that f(Y)=sum_X f(X,Y).
* Makes sure the keys are ordered as given. Does not check orderedKeys.
*/
DiscreteConditional(const DecisionTreeFactor& joint,
const DecisionTreeFactor& marginal,
const Ordering& orderedKeys);

/**
* Combine several conditional into a single one.
* The conditionals must be given in increasing order, meaning that the
* parents of any conditional may not include a conditional coming before it.
* @param firstConditional Iterator to the first conditional to combine, must
* dereference to a shared_ptr<DiscreteConditional>.
* @param lastConditional Iterator to after the last conditional to combine,
* must dereference to a shared_ptr<DiscreteConditional>.
* */
template <typename ITERATOR>
static shared_ptr Combine(ITERATOR firstConditional,
ITERATOR lastConditional);
* @brief Combine two conditionals, yielding a new conditional with the union
* of the frontal keys, ordered by gtsam::Key.
*
* The two conditionals must make a valid Bayes net fragment, i.e.,
* the frontal variables cannot overlap, and must be acyclic:
* Example of correct use:
* P(A,B) = P(A|B) * P(B)
* P(A,B|C) = P(A|B) * P(B|C)
* P(A,B,C) = P(A,B|C) * P(C)
* Example of incorrect use:
* P(A|B) * P(A|C) = ?
* P(A|B) * P(B|A) = ?
* We check for overlapping frontals, but do *not* check for cyclic.
*/
DiscreteConditional operator*(const DiscreteConditional& other) const;

/** Calculate marginal on given key, no parent case. */
DiscreteConditional marginal(Key key) const;

/// @}
/// @name Testable
Expand Down Expand Up @@ -136,11 +157,6 @@ class GTSAM_EXPORT DiscreteConditional
return ADT::operator()(values);
}

/** Convert to a factor */
DecisionTreeFactor::shared_ptr toFactor() const {
return DecisionTreeFactor::shared_ptr(new DecisionTreeFactor(*this));
}

/** Restrict to given parent values, returns DecisionTreeFactor */
DecisionTreeFactor::shared_ptr choose(
const DiscreteValues& parentsValues) const;
Expand Down Expand Up @@ -208,23 +224,4 @@ class GTSAM_EXPORT DiscreteConditional
template <>
struct traits<DiscreteConditional> : public Testable<DiscreteConditional> {};

/* ************************************************************************* */
template <typename ITERATOR>
DiscreteConditional::shared_ptr DiscreteConditional::Combine(
ITERATOR firstConditional, ITERATOR lastConditional) {
// TODO: check for being a clique

// multiply all the potentials of the given conditionals
size_t nrFrontals = 0;
DecisionTreeFactor product;
for (ITERATOR it = firstConditional; it != lastConditional;
++it, ++nrFrontals) {
DiscreteConditional::shared_ptr c = *it;
DecisionTreeFactor::shared_ptr factor = c->toFactor();
product = (*factor) * product;
}
// and then create a new multi-frontal conditional
return boost::make_shared<DiscreteConditional>(nrFrontals, product);
}

} // namespace gtsam
14 changes: 7 additions & 7 deletions gtsam/discrete/DiscretePrior.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,17 +48,17 @@ class GTSAM_EXPORT DiscretePrior : public DiscreteConditional {
DiscretePrior(const Signature& s) : Base(s) {}

/**
* Construct from key and a Signature::Table specifying the
* conditional probability table (CPT).
* Construct from key and a vector of floats specifying the probability mass
* function (PMF).
*
* Example: DiscretePrior P(D, table);
* Example: DiscretePrior P(D, {0.4, 0.6});
*/
DiscretePrior(const DiscreteKey& key, const Signature::Table& table)
: Base(Signature(key, {}, table)) {}
DiscretePrior(const DiscreteKey& key, const std::vector<double>& spec)
: DiscretePrior(Signature(key, {}, Signature::Table{spec})) {}

/**
* Construct from key and a string specifying the conditional
* probability table (CPT).
* Construct from key and a string specifying the probability mass function
* (PMF).
*
* Example: DiscretePrior P(D, "9/1 2/8 3/7 1/9");
*/
Expand Down
16 changes: 15 additions & 1 deletion gtsam/discrete/discrete.i
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,15 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor {
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const;

double operator()(const gtsam::DiscreteValues& values) const;
gtsam::DecisionTreeFactor operator*(const gtsam::DecisionTreeFactor& f) const;
size_t cardinality(gtsam::Key j) const;
gtsam::DecisionTreeFactor operator/(const gtsam::DecisionTreeFactor& f) const;
gtsam::DecisionTreeFactor* sum(size_t nrFrontals) const;
gtsam::DecisionTreeFactor* sum(const gtsam::Ordering& keys) const;
gtsam::DecisionTreeFactor* max(size_t nrFrontals) const;

string dot(
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
bool showZero = true) const;
Expand Down Expand Up @@ -86,14 +95,18 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
DiscreteConditional(const gtsam::DecisionTreeFactor& joint,
const gtsam::DecisionTreeFactor& marginal,
const gtsam::Ordering& orderedKeys);
gtsam::DiscreteConditional operator*(
const gtsam::DiscreteConditional& other) const;
DiscreteConditional marginal(gtsam::Key key) const;
void print(string s = "Discrete Conditional\n",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DiscreteConditional& other, double tol = 1e-9) const;
size_t nrFrontals() const;
size_t nrParents() const;
void printSignature(
string s = "Discrete Conditional: ",
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
gtsam::DecisionTreeFactor* toFactor() const;
gtsam::DecisionTreeFactor* choose(
const gtsam::DiscreteValues& parentsValues) const;
gtsam::DecisionTreeFactor* likelihood(
Expand All @@ -120,6 +133,7 @@ virtual class DiscretePrior : gtsam::DiscreteConditional {
DiscretePrior();
DiscretePrior(const gtsam::DecisionTreeFactor& f);
DiscretePrior(const gtsam::DiscreteKey& key, string spec);
DiscretePrior(const gtsam::DiscreteKey& key, std::vector<double> spec);
void print(string s = "Discrete Prior\n",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
Expand Down
26 changes: 16 additions & 10 deletions gtsam/discrete/tests/testDecisionTreeFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@
* @author Duy-Nguyen Ta
*/

#include <gtsam/discrete/Signature.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/base/Testable.h>
#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/Testable.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscretePrior.h>
#include <gtsam/discrete/Signature.h>

#include <boost/assign/std/map.hpp>
using namespace boost::assign;

Expand Down Expand Up @@ -51,17 +53,21 @@ TEST( DecisionTreeFactor, constructors)
}

/* ************************************************************************* */
TEST_UNSAFE( DecisionTreeFactor, multiplication)
{
DiscreteKey v0(0,2), v1(1,2), v2(2,2);
TEST(DecisionTreeFactor, multiplication) {
DiscreteKey v0(0, 2), v1(1, 2), v2(2, 2);

// Multiply with a DiscretePrior, i.e., Bayes Law!
DiscretePrior prior(v1 % "1/3");
DecisionTreeFactor f1(v0 & v1, "1 2 3 4");
DecisionTreeFactor f2(v1 & v2, "5 6 7 8");

DecisionTreeFactor expected(v0 & v1 & v2, "5 6 14 16 15 18 28 32");
DecisionTreeFactor expected(v0 & v1, "0.25 1.5 0.75 3");
CHECK(assert_equal(expected, static_cast<DecisionTreeFactor>(prior) * f1));
CHECK(assert_equal(expected, f1 * prior));

// Multiply two factors
DecisionTreeFactor f2(v1 & v2, "5 6 7 8");
DecisionTreeFactor actual = f1 * f2;
CHECK(assert_equal(expected, actual));
DecisionTreeFactor expected2(v0 & v1 & v2, "5 6 14 16 15 18 28 32");
CHECK(assert_equal(expected2, actual));
}

/* ************************************************************************* */
Expand Down
Loading