Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved documentation, wrapper, and tests #1574

Merged
merged 15 commits into from
Jul 16, 2023
64 changes: 56 additions & 8 deletions gtsam/discrete/AlgebraicDecisionTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@
namespace gtsam {

/**
* Algebraic Decision Trees fix the range to double
* Just has some nice constructors and some syntactic sugar
* TODO: consider eliminating this class altogether?
* An algebraic decision tree fixes the range of a DecisionTree to double.
* Just has some nice constructors and some syntactic sugar.
* TODO(dellaert): consider eliminating this class altogether?
*
* @ingroup discrete
*/
Expand Down Expand Up @@ -80,20 +80,62 @@ namespace gtsam {
AlgebraicDecisionTree(const L& label, double y1, double y2)
: Base(label, y1, y2) {}

/** Create a new leaf function splitting on a variable */
/**
* @brief Create a new leaf function splitting on a variable
*
* @param labelC: The label with cardinality 2
* @param y1: The value for the first key
* @param y2: The value for the second key
*
* Example:
* @code{.cpp}
* std::pair<string, size_t> A {"a", 2};
* AlgebraicDecisionTree<string> a(A, 0.6, 0.4);
* @endcode
*/
AlgebraicDecisionTree(const typename Base::LabelC& labelC, double y1,
double y2)
: Base(labelC, y1, y2) {}

/** Create from keys and vector table */
/**
* @brief Create from keys with cardinalities and a vector table
*
* @param labelCs: The keys, with cardinalities, given as pairs
* @param ys: The vector table
*
* Example with three keys, A, B, and C, with cardinalities 2, 3, and 2,
* respectively, and a vector table of size 12:
* @code{.cpp}
* DiscreteKey A(0, 2), B(1, 3), C(2, 2);
* const vector<double> cpt{
* 1.0 / 3, 2.0 / 3, 3.0 / 7, 4.0 / 7, 5.0 / 11, 6.0 / 11, //
* 1.0 / 9, 8.0 / 9, 3.0 / 6, 3.0 / 6, 5.0 / 10, 5.0 / 10};
* AlgebraicDecisionTree<Key> expected(A & B & C, cpt);
* @endcode
* The table is given in the following order:
* A=0, B=0, C=0
* A=0, B=0, C=1
* ...
* A=1, B=1, C=1
* Hence, the first line in the table is for A==0, and the second for A==1.
* In each line, the first two entries are for B==0, the next two for B==1,
* and the last two for B==2. Each pair is for a C value of 0 and 1.
*/
AlgebraicDecisionTree //
(const std::vector<typename Base::LabelC>& labelCs,
const std::vector<double>& ys) {
const std::vector<double>& ys) {
this->root_ =
Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
}

/** Create from keys and string table */
/**
* @brief Create from keys and string table
*
* @param labelCs: The keys, with cardinalities, given as pairs
* @param table: The string table, given as a string of doubles.
*
* @note Table needs to be in same order as the vector table in the other constructor.
*/
AlgebraicDecisionTree //
(const std::vector<typename Base::LabelC>& labelCs,
const std::string& table) {
Expand All @@ -108,7 +150,13 @@ namespace gtsam {
Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
}

/** Create a new function splitting on a variable */
/**
* @brief Create a range of decision trees, splitting on a single variable.
*
* @param begin: Iterator to beginning of a range of decision trees
* @param end: Iterator to end of a range of decision trees
* @param label: The label to split on
*/
template <typename Iterator>
AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label)
: Base(nullptr) {
Expand Down
2 changes: 1 addition & 1 deletion gtsam/discrete/DecisionTree-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -622,7 +622,7 @@ namespace gtsam {
// B=1
// A=0: 3
// A=1: 4
// Note, through the magic of "compose", create([A B],[1 2 3 4]) will produce
// Note, through the magic of "compose", create([A B],[1 3 2 4]) will produce
// exactly the same tree as above: the highest label is always the root.
// However, it will be *way* faster if labels are given highest to lowest.
template<typename L, typename Y>
Expand Down
31 changes: 26 additions & 5 deletions gtsam/discrete/DecisionTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,23 @@
namespace gtsam {

/**
* Decision Tree
* L = label for variables
* Y = function range (any algebra), e.g., bool, int, double
* @brief a decision tree is a function from assignments to values.
* @tparam L label for variables
* @tparam Y function range (any algebra), e.g., bool, int, double
*
* After creating a decision tree on some variables, the tree can be evaluated
* on an assignment to those variables. Example:
*
* @code{.cpp}
* // Create a decision stump one one variable 'a' with values 10 and 20.
* DecisionTree<char, int> tree('a', 10, 20);
*
* // Evaluate the tree on an assignment to the variable.
* int value0 = tree({{'a', 0}}); // value0 = 10
* int value1 = tree({{'a', 1}}); // value1 = 20
* @endcode
*
* More examples can be found in testDecisionTree.cpp
*
* @ingroup discrete
*/
Expand Down Expand Up @@ -132,7 +146,8 @@ namespace gtsam {
NodePtr root_;

protected:
/** Internal recursive function to create from keys, cardinalities,
/**
* Internal recursive function to create from keys, cardinalities,
* and Y values
*/
template<typename It, typename ValueIt>
Expand Down Expand Up @@ -163,7 +178,13 @@ namespace gtsam {
/** Create a constant */
explicit DecisionTree(const Y& y);

/// Create tree with 2 assignments `y1`, `y2`, splitting on variable `label`
/**
* @brief Create tree with 2 assignments `y1`, `y2`, splitting on variable `label`
*
* @param label The variable to split on.
* @param y1 The value for the first assignment.
* @param y2 The value for the second assignment.
*/
DecisionTree(const L& label, const Y& y1, const Y& y2);

/** Allow Label+Cardinality for convenience */
Expand Down
41 changes: 38 additions & 3 deletions gtsam/discrete/DecisionTreeFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,46 @@ namespace gtsam {
/** Constructor from DiscreteKeys and AlgebraicDecisionTree */
DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials);

/** Constructor from doubles */
/**
* @brief Constructor from doubles
*
* @param keys The discrete keys.
* @param table The table of values.
*
* @throw std::invalid_argument if the size of `table` does not match the
* number of assignments.
*
* Example:
* @code{.cpp}
* DiscreteKey X(0,2), Y(1,3);
* const std::vector<double> table {2, 5, 3, 6, 4, 7};
* DecisionTreeFactor f1({X, Y}, table);
* @endcode
*
* The values in the table should be laid out so that the first key varies
* the slowest, and the last key the fastest.
*/
DecisionTreeFactor(const DiscreteKeys& keys,
const std::vector<double>& table);
const std::vector<double>& table);

/** Constructor from string */
/**
* @brief Constructor from string
*
* @param keys The discrete keys.
* @param table The table of values.
*
* @throw std::invalid_argument if the size of `table` does not match the
* number of assignments.
*
* Example:
* @code{.cpp}
* DiscreteKey X(0,2), Y(1,3);
* DecisionTreeFactor factor({X, Y}, "2 5 3 6 4 7");
* @endcode
*
* The values in the table should be laid out so that the first key varies
* the slowest, and the last key the fastest.
*/
DecisionTreeFactor(const DiscreteKeys& keys, const std::string& table);

/// Single-key specialization
Expand Down
5 changes: 5 additions & 0 deletions gtsam/discrete/DiscreteBayesTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@ class GTSAM_EXPORT DiscreteBayesTreeClique

//** evaluate conditional probability of subtree for given DiscreteValues */
double evaluate(const DiscreteValues& values) const;

//** (Preferred) sugar for the above for given DiscreteValues */
double operator()(const DiscreteValues& values) const {
return evaluate(values);
}
};

/* ************************************************************************* */
Expand Down
39 changes: 25 additions & 14 deletions gtsam/discrete/DiscreteFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,30 @@ class DiscreteJunctionTree;

/**
* @brief Main elimination function for DiscreteFactorGraph.
*
* @param factors
* @param keys
* @return GTSAM_EXPORT
*
* @param factors The factor graph to eliminate.
* @param frontalKeys An ordering for which variables to eliminate.
* @return A pair of the resulting conditional and the separator factor.
* @ingroup discrete
*/
GTSAM_EXPORT std::pair<boost::shared_ptr<DiscreteConditional>, DecisionTreeFactor::shared_ptr>
EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& keys);
GTSAM_EXPORT
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr>
EliminateDiscrete(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys);

/**
* @brief Alternate elimination function for that creates non-normalized lookup tables.
*
* @param factors The factor graph to eliminate.
* @param frontalKeys An ordering for which variables to eliminate.
* @return A pair of the resulting lookup table and the separator factor.
* @ingroup discrete
*/
GTSAM_EXPORT
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr>
EliminateForMPE(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys);

/* ************************************************************************* */
template<> struct EliminationTraits<DiscreteFactorGraph>
{
typedef DiscreteFactor FactorType; ///< Type of factors in factor graph
Expand All @@ -61,12 +75,14 @@ template<> struct EliminationTraits<DiscreteFactorGraph>
typedef DiscreteEliminationTree EliminationTreeType; ///< Type of elimination tree
typedef DiscreteBayesTree BayesTreeType; ///< Type of Bayes tree
typedef DiscreteJunctionTree JunctionTreeType; ///< Type of Junction tree

/// The default dense elimination function
static std::pair<boost::shared_ptr<ConditionalType>,
boost::shared_ptr<FactorType> >
DefaultEliminate(const FactorGraphType& factors, const Ordering& keys) {
return EliminateDiscrete(factors, keys);
}

/// The default ordering generation function
static Ordering DefaultOrderingFunc(
const FactorGraphType& graph,
Expand All @@ -75,7 +91,6 @@ template<> struct EliminationTraits<DiscreteFactorGraph>
}
};

/* ************************************************************************* */
/**
* A Discrete Factor Graph is a factor graph where all factors are Discrete, i.e.
* Factor == DiscreteFactor
Expand Down Expand Up @@ -109,8 +124,8 @@ class GTSAM_EXPORT DiscreteFactorGraph

/** Implicit copy/downcast constructor to override explicit template container
* constructor */
template <class DERIVEDFACTOR>
DiscreteFactorGraph(const FactorGraph<DERIVEDFACTOR>& graph) : Base(graph) {}
template <class DERIVED_FACTOR>
DiscreteFactorGraph(const FactorGraph<DERIVED_FACTOR>& graph) : Base(graph) {}

/// Destructor
virtual ~DiscreteFactorGraph() {}
Expand Down Expand Up @@ -231,10 +246,6 @@ class GTSAM_EXPORT DiscreteFactorGraph
/// @}
}; // \ DiscreteFactorGraph

std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
EliminateForMPE(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys);

/// traits
template <>
struct traits<DiscreteFactorGraph> : public Testable<DiscreteFactorGraph> {};
Expand Down
2 changes: 2 additions & 0 deletions gtsam/discrete/DiscreteJunctionTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,6 @@ namespace gtsam {
DiscreteJunctionTree(const DiscreteEliminationTree& eliminationTree);
};

/// typedef for wrapper:
using DiscreteCluster = DiscreteJunctionTree::Cluster;
}
5 changes: 5 additions & 0 deletions gtsam/discrete/DiscreteValues.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,11 @@ class GTSAM_EXPORT DiscreteValues : public Assignment<Key> {
/// @}
};

/// Free version of CartesianProduct.
inline std::vector<DiscreteValues> cartesianProduct(const DiscreteKeys& keys) {
return DiscreteValues::CartesianProduct(keys);
}

/// Free version of markdown.
std::string markdown(const DiscreteValues& values,
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
Expand Down
Loading