Skip to content

Commit

Permalink
Merge pull request #1368 from borglab/hybrid/serialization
Browse files Browse the repository at this point in the history
Fixes #1366
  • Loading branch information
varunagrawal authored Jan 4, 2023
2 parents 5cdff9e + 4cb910c commit b62f397
Show file tree
Hide file tree
Showing 19 changed files with 440 additions and 46 deletions.
31 changes: 31 additions & 0 deletions gtsam/discrete/DecisionTree-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ namespace gtsam {
*/
size_t nrAssignments_;

/// Default constructor for serialization.
Leaf() {}

/// Constructor from constant
Leaf(const Y& constant, size_t nrAssignments = 1)
: constant_(constant), nrAssignments_(nrAssignments) {}
Expand Down Expand Up @@ -154,6 +157,18 @@ namespace gtsam {
}

bool isLeaf() const override { return true; }

private:
using Base = DecisionTree<L, Y>::Node;

/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
ar& BOOST_SERIALIZATION_NVP(constant_);
ar& BOOST_SERIALIZATION_NVP(nrAssignments_);
}
}; // Leaf

/****************************************************************************/
Expand All @@ -177,6 +192,9 @@ namespace gtsam {
using ChoicePtr = boost::shared_ptr<const Choice>;

public:
/// Default constructor for serialization.
Choice() {}

~Choice() override {
#ifdef DT_DEBUG_MEMORY
std::std::cout << Node::nrNodes << " destructing (Choice) " << this->id()
Expand Down Expand Up @@ -428,6 +446,19 @@ namespace gtsam {
r->push_back(branch->choose(label, index));
return Unique(r);
}

private:
using Base = DecisionTree<L, Y>::Node;

/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
ar& BOOST_SERIALIZATION_NVP(label_);
ar& BOOST_SERIALIZATION_NVP(branches_);
ar& BOOST_SERIALIZATION_NVP(allSame_);
}
}; // Choice

/****************************************************************************/
Expand Down
19 changes: 19 additions & 0 deletions gtsam/discrete/DecisionTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@

#pragma once

#include <gtsam/base/Testable.h>
#include <gtsam/base/types.h>
#include <gtsam/discrete/Assignment.h>

#include <boost/serialization/nvp.hpp>
#include <boost/shared_ptr.hpp>
#include <functional>
#include <iostream>
Expand Down Expand Up @@ -113,6 +115,12 @@ namespace gtsam {
virtual Ptr apply_g_op_fC(const Choice&, const Binary&) const = 0;
virtual Ptr choose(const L& label, size_t index) const = 0;
virtual bool isLeaf() const = 0;

private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {}
};
/** ------------------------ Node base class --------------------------- */

Expand Down Expand Up @@ -364,8 +372,19 @@ namespace gtsam {
compose(Iterator begin, Iterator end, const L& label) const;

/// @}

private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar& BOOST_SERIALIZATION_NVP(root_);
}
}; // DecisionTree

template <class L, class Y>
struct traits<DecisionTree<L, Y>> : public Testable<DecisionTree<L, Y>> {};

/** free versions of apply */

/// Apply unary operator `op` to DecisionTree `f`.
Expand Down
10 changes: 10 additions & 0 deletions gtsam/discrete/DecisionTreeFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,16 @@ namespace gtsam {
const Names& names = {}) const override;

/// @}

private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(ADT);
ar& BOOST_SERIALIZATION_NVP(cardinalities_);
}
};

// traits
Expand Down
9 changes: 9 additions & 0 deletions gtsam/discrete/DiscreteConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,15 @@ class GTSAM_EXPORT DiscreteConditional
/// Internal version of choose
DiscreteConditional::ADT choose(const DiscreteValues& given,
bool forceComplete) const;

private:
/** Serialization function */
friend class boost::serialization::access;
template <class Archive>
void serialize(Archive& ar, const unsigned int /*version*/) {
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor);
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional);
}
};
// DiscreteConditional

Expand Down
7 changes: 3 additions & 4 deletions gtsam/discrete/tests/testDecisionTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,12 @@
// #define DT_DEBUG_MEMORY
// #define GTSAM_DT_NO_PRUNING
#define DISABLE_DOT
#include <gtsam/discrete/DecisionTree-inl.h>

#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/Testable.h>
#include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/discrete/DecisionTree-inl.h>
#include <gtsam/discrete/Signature.h>

#include <CppUnitLite/TestHarness.h>

using namespace std;
using namespace gtsam;

Expand Down
1 change: 1 addition & 0 deletions gtsam/discrete/tests/testDecisionTreeFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/Testable.h>
#include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteDistribution.h>
#include <gtsam/discrete/Signature.h>
Expand Down
6 changes: 3 additions & 3 deletions gtsam/discrete/tests/testDiscreteConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,14 @@
* @date Feb 14, 2011
*/

#include <boost/make_shared.hpp>

#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/inference/Symbol.h>

#include <boost/make_shared.hpp>

using namespace std;
using namespace gtsam;

Expand Down Expand Up @@ -209,7 +210,6 @@ TEST(DiscreteConditional, marginals2) {
DiscreteConditional conditional(A | B = "2/2 3/1");
DiscreteConditional prior(B % "1/2");
DiscreteConditional pAB = prior * conditional;
GTSAM_PRINT(pAB);
// P(A=0) = P(A=0|B=0)P(B=0) + P(A=0|B=1)P(B=1) = 2*1 + 3*2 = 8
// P(A=1) = P(A=1|B=0)P(B=0) + P(A=1|B=1)P(B=1) = 2*1 + 1*2 = 4
DiscreteConditional actualA = pAB.marginal(A.first);
Expand Down
105 changes: 105 additions & 0 deletions gtsam/discrete/tests/testSerializationDiscrete.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */

/*
* testSerializtionDiscrete.cpp
*
* @date January 2023
* @author Varun Agrawal
*/

#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteDistribution.h>
#include <gtsam/inference/Symbol.h>

using namespace std;
using namespace gtsam;

using Tree = gtsam::DecisionTree<string, int>;

BOOST_CLASS_EXPORT_GUID(Tree, "gtsam_DecisionTreeStringInt")
BOOST_CLASS_EXPORT_GUID(Tree::Leaf, "gtsam_DecisionTreeStringInt_Leaf")
BOOST_CLASS_EXPORT_GUID(Tree::Choice, "gtsam_DecisionTreeStringInt_Choice")

BOOST_CLASS_EXPORT_GUID(DecisionTreeFactor, "gtsam_DecisionTreeFactor");

using ADT = AlgebraicDecisionTree<Key>;
BOOST_CLASS_EXPORT_GUID(ADT, "gtsam_AlgebraicDecisionTree");
BOOST_CLASS_EXPORT_GUID(ADT::Leaf, "gtsam_AlgebraicDecisionTree_Leaf")
BOOST_CLASS_EXPORT_GUID(ADT::Choice, "gtsam_AlgebraicDecisionTree_Choice")

/* ****************************************************************************/
// Test DecisionTree serialization.
TEST(DiscreteSerialization, DecisionTree) {
Tree tree({{"A", 2}}, std::vector<int>{1, 2});

using namespace serializationTestHelpers;

// Object roundtrip
Tree outputObj = create<Tree>();
roundtrip<Tree>(tree, outputObj);
EXPECT(tree.equals(outputObj));

// XML roundtrip
Tree outputXml = create<Tree>();
roundtripXML<Tree>(tree, outputXml);
EXPECT(tree.equals(outputXml));

// Binary roundtrip
Tree outputBinary = create<Tree>();
roundtripBinary<Tree>(tree, outputBinary);
EXPECT(tree.equals(outputBinary));
}

/* ************************************************************************* */
// Check serialization for AlgebraicDecisionTree and the DecisionTreeFactor
TEST(DiscreteSerialization, DecisionTreeFactor) {
using namespace serializationTestHelpers;

DiscreteKey A(1, 2), B(2, 2), C(3, 2);

DecisionTreeFactor::ADT tree(A & B & C, "1 5 3 7 2 6 4 8");
EXPECT(equalsObj<DecisionTreeFactor::ADT>(tree));
EXPECT(equalsXML<DecisionTreeFactor::ADT>(tree));
EXPECT(equalsBinary<DecisionTreeFactor::ADT>(tree));

DecisionTreeFactor f(A & B & C, "1 5 3 7 2 6 4 8");
EXPECT(equalsObj<DecisionTreeFactor>(f));
EXPECT(equalsXML<DecisionTreeFactor>(f));
EXPECT(equalsBinary<DecisionTreeFactor>(f));
}

/* ************************************************************************* */
// Check serialization for DiscreteConditional & DiscreteDistribution
TEST(DiscreteSerialization, DiscreteConditional) {
using namespace serializationTestHelpers;

DiscreteKey A(Symbol('x', 1), 3);
DiscreteConditional conditional(A % "1/2/2");

EXPECT(equalsObj<DiscreteConditional>(conditional));
EXPECT(equalsXML<DiscreteConditional>(conditional));
EXPECT(equalsBinary<DiscreteConditional>(conditional));

DiscreteDistribution P(A % "3/2/1");
EXPECT(equalsObj<DiscreteDistribution>(P));
EXPECT(equalsXML<DiscreteDistribution>(P));
EXPECT(equalsBinary<DiscreteDistribution>(P));
}

/* ************************************************************************* */
int main() {
TestResult tr;
return TestRegistry::runAllTests(tr);
}
/* ************************************************************************* */
10 changes: 10 additions & 0 deletions gtsam/hybrid/GaussianMixture.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,16 @@ class GTSAM_EXPORT GaussianMixture
*/
GaussianFactorGraphTree add(const GaussianFactorGraphTree &sum) const;
/// @}

private:
/** Serialization function */
friend class boost::serialization::access;
template <class Archive>
void serialize(Archive &ar, const unsigned int /*version*/) {
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor);
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional);
ar &BOOST_SERIALIZATION_NVP(conditionals_);
}
};

/// Return the DiscreteKey vector as a set.
Expand Down
18 changes: 18 additions & 0 deletions gtsam/hybrid/GaussianMixtureFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,15 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
bool operator==(const FactorAndConstant &other) const {
return factor == other.factor && constant == other.constant;
}

private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
ar &BOOST_SERIALIZATION_NVP(factor);
ar &BOOST_SERIALIZATION_NVP(constant);
}
};

/// typedef for Decision Tree of Gaussian factors and log-constant.
Expand Down Expand Up @@ -179,6 +188,15 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
return sum;
}
/// @}

private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
ar &BOOST_SERIALIZATION_NVP(factors_);
}
};

// traits
Expand Down
1 change: 0 additions & 1 deletion gtsam/hybrid/HybridConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ bool HybridConditional::equals(const HybridFactor &other, double tol) const {
auto other = e->asDiscrete();
return other != nullptr && dc->equals(*other, tol);
}
return inner_->equals(*(e->inner_), tol);

return inner_ ? (e->inner_ ? inner_->equals(*(e->inner_), tol) : false)
: !(e->inner_);
Expand Down
13 changes: 13 additions & 0 deletions gtsam/hybrid/HybridConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,19 @@ class GTSAM_EXPORT HybridConditional
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor);
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional);
ar& BOOST_SERIALIZATION_NVP(inner_);

// register the various casts based on the type of inner_
// https://www.boost.org/doc/libs/1_80_0/libs/serialization/doc/serialization.html#runtimecasting
if (isDiscrete()) {
boost::serialization::void_cast_register<DiscreteConditional, Factor>(
static_cast<DiscreteConditional*>(NULL), static_cast<Factor*>(NULL));
} else if (isContinuous()) {
boost::serialization::void_cast_register<GaussianConditional, Factor>(
static_cast<GaussianConditional*>(NULL), static_cast<Factor*>(NULL));
} else {
boost::serialization::void_cast_register<GaussianMixture, Factor>(
static_cast<GaussianMixture*>(NULL), static_cast<Factor*>(NULL));
}
}

}; // HybridConditional
Expand Down
6 changes: 4 additions & 2 deletions gtsam/hybrid/HybridDiscreteFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,10 @@ HybridDiscreteFactor::HybridDiscreteFactor(DecisionTreeFactor &&dtf)
/* ************************************************************************ */
bool HybridDiscreteFactor::equals(const HybridFactor &lf, double tol) const {
const This *e = dynamic_cast<const This *>(&lf);
// TODO(Varun) How to compare inner_ when they are abstract types?
return e != nullptr && Base::equals(*e, tol);
if (e == nullptr) return false;
if (!Base::equals(*e, tol)) return false;
return inner_ ? (e->inner_ ? inner_->equals(*(e->inner_), tol) : false)
: !(e->inner_);
}

/* ************************************************************************ */
Expand Down
Loading

0 comments on commit b62f397

Please sign in to comment.