diff --git a/gtsam/discrete/AlgebraicDecisionTree.h b/gtsam/discrete/AlgebraicDecisionTree.h index 17a38f7cf3..637f56066f 100644 --- a/gtsam/discrete/AlgebraicDecisionTree.h +++ b/gtsam/discrete/AlgebraicDecisionTree.h @@ -29,7 +29,12 @@ namespace gtsam { */ template class GTSAM_EXPORT AlgebraicDecisionTree: public DecisionTree { - /// Default method used by `formatter` when printing. + /** + * @brief Default method used by `labelFormatter` or `valueFormatter` when printing. + * + * @param x The value passed to format. + * @return std::string + */ static std::string DefaultFormatter(const L& x) { std::stringstream ss; ss << x; @@ -38,7 +43,7 @@ namespace gtsam { public: - typedef DecisionTree Super; + using Base = DecisionTree; /** The Real ring with addition and multiplication */ struct Ring { @@ -66,33 +71,33 @@ namespace gtsam { }; AlgebraicDecisionTree() : - Super(1.0) { + Base(1.0) { } - AlgebraicDecisionTree(const Super& add) : - Super(add) { + AlgebraicDecisionTree(const Base& add) : + Base(add) { } /** Create a new leaf function splitting on a variable */ AlgebraicDecisionTree(const L& label, double y1, double y2) : - Super(label, y1, y2) { + Base(label, y1, y2) { } /** Create a new leaf function splitting on a variable */ - AlgebraicDecisionTree(const typename Super::LabelC& labelC, double y1, double y2) : - Super(labelC, y1, y2) { + AlgebraicDecisionTree(const typename Base::LabelC& labelC, double y1, double y2) : + Base(labelC, y1, y2) { } /** Create from keys and vector table */ AlgebraicDecisionTree // - (const std::vector& labelCs, const std::vector& ys) { - this->root_ = Super::create(labelCs.begin(), labelCs.end(), ys.begin(), + (const std::vector& labelCs, const std::vector& ys) { + this->root_ = Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end()); } /** Create from keys and string table */ AlgebraicDecisionTree // - (const std::vector& labelCs, const std::string& table) { + (const std::vector& labelCs, const std::string& table) { // Convert string to doubles std::vector ys; std::istringstream iss(table); @@ -100,21 +105,27 @@ namespace gtsam { std::istream_iterator(), std::back_inserter(ys)); // now call recursive Create - this->root_ = Super::create(labelCs.begin(), labelCs.end(), ys.begin(), + this->root_ = Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end()); } /** Create a new function splitting on a variable */ template AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label) : - Super(nullptr) { + Base(nullptr) { this->root_ = compose(begin, end, label); } - /** Convert */ + /** + * Convert labels from type M to type L. + * + * @param other: The AlgebraicDecisionTree with label type M to convert. + * @param map: Map from label type M to label type L. + */ template AlgebraicDecisionTree(const AlgebraicDecisionTree& other, const std::map& map) { + // Functor for label conversion so we can use `convertFrom`. std::function L_of_M = [&map](const M& label) -> L { return map.at(label); }; @@ -143,18 +154,18 @@ namespace gtsam { } /** sum out variable */ - AlgebraicDecisionTree sum(const typename Super::LabelC& labelC) const { + AlgebraicDecisionTree sum(const typename Base::LabelC& labelC) const { return this->combine(labelC, &Ring::add); } /// print method customized to node type `double`. void print(const std::string& s, - const typename Super::LabelFormatter& labelFormatter = + const typename Base::LabelFormatter& labelFormatter = &DefaultFormatter) const { auto valueFormatter = [](const double& v) { return (boost::format("%4.2g") % v).str(); }; - Super::print(s, labelFormatter, valueFormatter); + Base::print(s, labelFormatter, valueFormatter); } /// Equality method customized to node type `double`. @@ -163,7 +174,7 @@ namespace gtsam { auto compare = [tol](double a, double b) { return std::abs(a - b) < tol; }; - return Super::equals(other, compare); + return Base::equals(other, compare); } }; // AlgebraicDecisionTree diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index 259489f069..0c016b6c54 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -82,13 +82,19 @@ namespace gtsam { return compare(this->constant_, other->constant_); } - /** print */ + /** + * @brief Print method. + * + * @param s Prefix string. + * @param labelFormatter Functor to format the node label. + * @param valueFormatter Functor to format the node value. + */ void print(const std::string& s, const LabelFormatter& labelFormatter, const ValueFormatter& valueFormatter) const override { std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl; } - /** to graphviz file */ + /** Write graphviz format to stream `os`. */ void dot(std::ostream& os, const LabelFormatter& labelFormatter, const ValueFormatter& valueFormatter, bool showZero) const override { @@ -154,7 +160,7 @@ namespace gtsam { /** incremental allSame */ size_t allSame_; - typedef boost::shared_ptr ChoicePtr; + using ChoicePtr = boost::shared_ptr; public: @@ -462,6 +468,7 @@ namespace gtsam { template DecisionTree::DecisionTree(const DecisionTree& other, std::function Y_of_X) { + // Define functor for identity mapping of node label. auto L_of_L = [](const L& label) { return label; }; root_ = convertFrom(Y_of_X, L_of_L); } @@ -594,11 +601,11 @@ namespace gtsam { const typename DecisionTree::NodePtr& f, std::function L_of_M, std::function Y_of_X) const { - typedef DecisionTree MX; - typedef typename MX::Leaf MXLeaf; - typedef typename MX::Choice MXChoice; - typedef typename MX::NodePtr MXNodePtr; - typedef DecisionTree LY; + using MX = DecisionTree; + using MXLeaf = typename MX::Leaf; + using MXChoice = typename MX::Choice; + using MXNodePtr = typename MX::NodePtr; + using LY = DecisionTree; // ugliness below because apparently we can't have templated virtual functions // If leaf, apply unary conversion "op" and create a unique leaf diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index ecc3d17dce..6aa97d8ac4 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -39,6 +39,7 @@ namespace gtsam { template class GTSAM_EXPORT DecisionTree { + protected: /// Default method for comparison of two objects of type Y. static bool DefaultCompare(const Y& a, const Y& b) { return a == b; @@ -51,11 +52,11 @@ namespace gtsam { using CompareFunc = std::function; /** Handy typedefs for unary and binary function types */ - typedef std::function Unary; - typedef std::function Binary; + using Unary = std::function; + using Binary = std::function; /** A label annotated with cardinality */ - typedef std::pair LabelC; + using LabelC = std::pair; /** DTs consist of Leaf and Choice nodes, both subclasses of Node */ class Leaf; @@ -64,7 +65,7 @@ namespace gtsam { /** ------------------------ Node base class --------------------------- */ class Node { public: - typedef boost::shared_ptr Ptr; + using Ptr = boost::shared_ptr; #ifdef DT_DEBUG_MEMORY static int nrNodes; @@ -111,9 +112,9 @@ namespace gtsam { public: /** A function is a shared pointer to the root of a DT */ - typedef typename Node::Ptr NodePtr; + using NodePtr = typename Node::Ptr; - /// a DecisionTree just contains the root. TODO(dellaert): make protected. + /// A DecisionTree just contains the root. TODO(dellaert): make protected. NodePtr root_; protected: @@ -122,7 +123,16 @@ namespace gtsam { template NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const; - /// Convert from a DecisionTree. + /** + * @brief Convert from a DecisionTree to DecisionTree. + * + * @tparam M The previous label type. + * @tparam X The previous node type. + * @param f The node pointer to the root of the previous DecisionTree. + * @param L_of_M Functor to convert from label type M to type L. + * @param Y_of_X Functor to convert from node type X to type Y. + * @return NodePtr + */ template NodePtr convertFrom(const typename DecisionTree::NodePtr& f, std::function L_of_M, @@ -159,12 +169,27 @@ namespace gtsam { DecisionTree(const L& label, // const DecisionTree& f0, const DecisionTree& f1); - /** Convert from a different type. */ + /** + * @brief Convert from a different node type. + * + * @tparam X The previous node type. + * @param other The DecisionTree to convert from. + * @param Y_of_X Functor to convert from node type X to type Y. + */ template DecisionTree(const DecisionTree& other, std::function Y_of_X); - /** Convert from a different type, also transate labels via map. */ + /** + * @brief Convert from a different node type X to node type Y, also transate + * labels via map from type M to L. + * + * @tparam M Previous label type. + * @tparam X Previous node type. + * @param other The decision tree to convert. + * @param L_of_M Map from label type M to type L. + * @param Y_of_X Functor to convert from type X to type Y. + */ template DecisionTree(const DecisionTree& other, const std::map& L_of_M, std::function Y_of_X); @@ -173,7 +198,13 @@ namespace gtsam { /// @name Testable /// @{ - /** GTSAM-style print */ + /** + * @brief GTSAM-style print + * + * @param s Prefix string. + * @param labelFormatter Functor to format the node label. + * @param valueFormatter Functor to format the node value. + */ void print(const std::string& s, const LabelFormatter& labelFormatter, const ValueFormatter& valueFormatter) const; @@ -189,7 +220,7 @@ namespace gtsam { virtual ~DecisionTree() { } - /** empty tree? */ + /// Check if tree is empty. bool empty() const { return !root_; } /** equality */ @@ -248,18 +279,21 @@ namespace gtsam { /** free versions of apply */ + /// Apply unary operator `op` to DecisionTree `f`. template DecisionTree apply(const DecisionTree& f, const typename DecisionTree::Unary& op) { return f.apply(op); } + /// Apply unary operator `op` to DecisionTree `f` but with node type. template DecisionTree apply(const DecisionTree& f, const std::function& op) { return f.apply(op); } + /// Apply binary operator `op` to DecisionTree `f`. template DecisionTree apply(const DecisionTree& f, const DecisionTree& g, diff --git a/gtsam/discrete/tests/testDecisionTree.cpp b/gtsam/discrete/tests/testDecisionTree.cpp index cc61a382f9..53f3c43797 100644 --- a/gtsam/discrete/tests/testDecisionTree.cpp +++ b/gtsam/discrete/tests/testDecisionTree.cpp @@ -45,15 +45,6 @@ struct Crazy { double b; }; -// bool equals(const Crazy& other, double tol = 1e-12) const { -// return a == other.a && std::abs(b - other.b) < tol; -// } - -// bool operator==(const Crazy& other) const { -// return this->equals(other); -// } -// }; - struct CrazyDecisionTree : public DecisionTree { /// print to stdout void print(const std::string& s = "") const { @@ -261,8 +252,6 @@ TEST(DT, conversion) return y != 0; }; BDT f2(f1, ordering, bool_of_int); - // f1.print("f1"); - // f2.print("f2"); // create a value Assignment