Skip to content

Commit

Permalink
Merge pull request #1353 from borglab/feature/evaluate_wrappers
Browse files Browse the repository at this point in the history
Added convenience constructors and python wrappers
  • Loading branch information
varunagrawal authored Dec 29, 2022
2 parents a849eab + 1eb6fc7 commit 706a8a4
Show file tree
Hide file tree
Showing 12 changed files with 148 additions and 49 deletions.
21 changes: 18 additions & 3 deletions gtsam/hybrid/HybridBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,25 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
/// Add HybridConditional to Bayes Net
using Base::add;

/// Add a Gaussian Mixture to the Bayes Net.
template <typename... T>
void addMixture(T &&...args) {
push_back(HybridConditional(
boost::make_shared<GaussianMixture>(std::forward<T>(args)...)));
}

/// Add a Gaussian conditional to the Bayes Net.
template <typename... T>
void addGaussian(T &&...args) {
push_back(HybridConditional(
boost::make_shared<GaussianConditional>(std::forward<T>(args)...)));
}

/// Add a discrete conditional to the Bayes Net.
void add(const DiscreteKey &key, const std::string &table) {
push_back(
HybridConditional(boost::make_shared<DiscreteConditional>(key, table)));
template <typename... T>
void addDiscrete(T &&...args) {
push_back(HybridConditional(
boost::make_shared<DiscreteConditional>(std::forward<T>(args)...)));
}

using Base::push_back;
Expand Down
16 changes: 15 additions & 1 deletion gtsam/hybrid/hybrid.i
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ class HybridBayesTreeClique {
// double evaluate(const gtsam::HybridValues& values) const;
};

#include <gtsam/hybrid/HybridBayesTree.h>
class HybridBayesTree {
HybridBayesTree();
void print(string s = "HybridBayesTree\n",
Expand All @@ -105,14 +104,29 @@ class HybridBayesTree {
gtsam::DefaultKeyFormatter) const;
};

#include <gtsam/hybrid/HybridBayesTree.h>
class HybridBayesNet {
HybridBayesNet();
void add(const gtsam::HybridConditional& s);
void addMixture(const gtsam::GaussianMixture& s);
void addGaussian(const gtsam::GaussianConditional& s);
void addDiscrete(const gtsam::DiscreteConditional& s);
void addDiscrete(const gtsam::DiscreteKey& key, string spec);
void addDiscrete(const gtsam::DiscreteKey& key,
const gtsam::DiscreteKeys& parents, string spec);
void addDiscrete(const gtsam::DiscreteKey& key,
const std::vector<gtsam::DiscreteKey>& parents, string spec);

bool empty() const;
size_t size() const;
gtsam::KeySet keys() const;
const gtsam::HybridConditional* at(size_t i) const;

double evaluate(const gtsam::HybridValues& x) const;
gtsam::HybridValues optimize() const;
gtsam::HybridValues sample(const gtsam::HybridValues &given) const;
gtsam::HybridValues sample() const;

void print(string s = "HybridBayesNet\n",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
Expand Down
19 changes: 7 additions & 12 deletions gtsam/hybrid/tests/testHybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ static const DiscreteKey Asia(asiaKey, 2);
// Test creation of a pure discrete Bayes net.
TEST(HybridBayesNet, Creation) {
HybridBayesNet bayesNet;
bayesNet.add(Asia, "99/1");
bayesNet.addDiscrete(Asia, "99/1");

DiscreteConditional expected(Asia, "99/1");
CHECK(bayesNet.atDiscrete(0));
Expand All @@ -54,7 +54,7 @@ TEST(HybridBayesNet, Creation) {
// Test adding a Bayes net to another one.
TEST(HybridBayesNet, Add) {
HybridBayesNet bayesNet;
bayesNet.add(Asia, "99/1");
bayesNet.addDiscrete(Asia, "99/1");

HybridBayesNet other;
other.push_back(bayesNet);
Expand All @@ -65,7 +65,7 @@ TEST(HybridBayesNet, Add) {
// Test evaluate for a pure discrete Bayes net P(Asia).
TEST(HybridBayesNet, evaluatePureDiscrete) {
HybridBayesNet bayesNet;
bayesNet.add(Asia, "99/1");
bayesNet.addDiscrete(Asia, "99/1");
HybridValues values;
values.insert(asiaKey, 0);
EXPECT_DOUBLES_EQUAL(0.99, bayesNet.evaluate(values), 1e-9);
Expand All @@ -85,17 +85,12 @@ TEST(HybridBayesNet, evaluateHybrid) {
conditional1 = boost::make_shared<GaussianConditional>(
X(1), Vector1::Constant(2), I_1x1, model1);

// TODO(dellaert): creating and adding mixture is clumsy.
const auto mixture = GaussianMixture::FromConditionals(
{X(1)}, {}, {Asia}, {conditional0, conditional1});

// Create hybrid Bayes net.
HybridBayesNet bayesNet;
bayesNet.push_back(HybridConditional(
boost::make_shared<GaussianConditional>(continuousConditional)));
bayesNet.push_back(
HybridConditional(boost::make_shared<GaussianMixture>(mixture)));
bayesNet.add(Asia, "99/1");
bayesNet.addGaussian(continuousConditional);
bayesNet.addMixture(GaussianMixture::FromConditionals(
{X(1)}, {}, {Asia}, {conditional0, conditional1}));
bayesNet.addDiscrete(Asia, "99/1");

// Create values at which to evaluate.
HybridValues values;
Expand Down
3 changes: 2 additions & 1 deletion gtsam/linear/GaussianBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ namespace gtsam {
return optimize(solution);
}

VectorValues GaussianBayesNet::optimize(VectorValues solution) const {
VectorValues GaussianBayesNet::optimize(const VectorValues& given) const {
VectorValues solution = given;
// (R*x)./sigmas = y by solving x=inv(R)*(y.*sigmas)
// solve each node in reverse topological sort order (parents first)
for (auto cg : boost::adaptors::reverse(*this)) {
Expand Down
2 changes: 1 addition & 1 deletion gtsam/linear/GaussianBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ namespace gtsam {
VectorValues optimize() const;

/// Version of optimize for incomplete BayesNet, given missing variables
VectorValues optimize(const VectorValues given) const;
VectorValues optimize(const VectorValues& given) const;

/**
* Sample using ancestral sampling
Expand Down
11 changes: 8 additions & 3 deletions gtsam/linear/linear.i
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,8 @@ virtual class GaussianConditional : gtsam::JacobianFactor {
bool equals(const gtsam::GaussianConditional& cg, double tol) const;

// Standard Interface
double evaluate(const gtsam::VectorValues& x) const;
double logDensity(const gtsam::VectorValues& x) const;
gtsam::Key firstFrontalKey() const;
gtsam::VectorValues solve(const gtsam::VectorValues& parents) const;
gtsam::JacobianFactor* likelihood(
Expand Down Expand Up @@ -543,17 +545,20 @@ virtual class GaussianBayesNet {
bool equals(const gtsam::GaussianBayesNet& other, double tol) const;
size_t size() const;

// Standard interface
void push_back(gtsam::GaussianConditional* conditional);
void push_back(const gtsam::GaussianBayesNet& bayesNet);
gtsam::GaussianConditional* front() const;
gtsam::GaussianConditional* back() const;

// Standard interface
double evaluate(const gtsam::VectorValues& x) const;
double logDensity(const gtsam::VectorValues& x) const;

gtsam::VectorValues optimize() const;
gtsam::VectorValues optimize(gtsam::VectorValues given) const;
gtsam::VectorValues optimize(const gtsam::VectorValues& given) const;
gtsam::VectorValues optimizeGradientSearch() const;

gtsam::VectorValues sample(gtsam::VectorValues given) const;
gtsam::VectorValues sample(const gtsam::VectorValues& given) const;
gtsam::VectorValues sample() const;
gtsam::VectorValues backSubstitute(const gtsam::VectorValues& gx) const;
gtsam::VectorValues backSubstituteTranspose(const gtsam::VectorValues& gx) const;
Expand Down
3 changes: 1 addition & 2 deletions python/gtsam/preamble/hybrid.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,9 @@
*/
#include <pybind11/stl.h>

// NOTE: Needed since we are including pybind11/stl.h.
#ifdef GTSAM_ALLOCATOR_TBB
PYBIND11_MAKE_OPAQUE(std::vector<gtsam::Key, tbb::tbb_allocator<gtsam::Key>>);
#else
PYBIND11_MAKE_OPAQUE(std::vector<gtsam::Key>);
#endif

PYBIND11_MAKE_OPAQUE(std::vector<gtsam::GaussianFactor::shared_ptr>);
4 changes: 0 additions & 4 deletions python/gtsam/specializations/hybrid.h
Original file line number Diff line number Diff line change
@@ -1,4 +0,0 @@

py::bind_vector<std::vector<gtsam::GaussianFactor::shared_ptr> >(m_, "GaussianFactorVector");

py::implicitly_convertible<py::list, std::vector<gtsam::GaussianFactor::shared_ptr> >();
19 changes: 13 additions & 6 deletions python/gtsam/tests/test_GaussianBayesNet.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ def smallBayesNet():
"""Create a small Bayes Net for testing"""
bayesNet = GaussianBayesNet()
I_1x1 = np.eye(1, dtype=float)
bayesNet.push_back(GaussianConditional(
_x_, [9.0], I_1x1, _y_, I_1x1))
bayesNet.push_back(GaussianConditional(_x_, [9.0], I_1x1, _y_, I_1x1))
bayesNet.push_back(GaussianConditional(_y_, [5.0], I_1x1))
return bayesNet

Expand All @@ -41,13 +40,21 @@ class TestGaussianBayesNet(GtsamTestCase):
def test_matrix(self):
"""Test matrix method"""
R, d = smallBayesNet().matrix() # get matrix and RHS
R1 = np.array([
[1.0, 1.0],
[0.0, 1.0]])
R1 = np.array([[1.0, 1.0], [0.0, 1.0]])
d1 = np.array([9.0, 5.0])
np.testing.assert_equal(R, R1)
np.testing.assert_equal(d, d1)

def test_sample(self):
"""Test sample method"""
bayesNet = smallBayesNet()
sample = bayesNet.sample()
self.assertIsInstance(sample, gtsam.VectorValues)

if __name__ == '__main__':
# standard deviation is 1.0 for both, so we set tolerance to 3*sigma
mean = bayesNet.optimize()
self.gtsamAssertEquals(sample, mean, tol=3.0)


if __name__ == "__main__":
unittest.main()
69 changes: 69 additions & 0 deletions python/gtsam/tests/test_HybridBayesNet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
"""
GTSAM Copyright 2010-2022, Georgia Tech Research Corporation,
Atlanta, Georgia 30332-0415
All Rights Reserved
See LICENSE for the license information
Unit tests for Hybrid Values.
Author: Frank Dellaert
"""
# pylint: disable=invalid-name, no-name-in-module, no-member

import unittest

import numpy as np
from gtsam.symbol_shorthand import A, X
from gtsam.utils.test_case import GtsamTestCase

import gtsam
from gtsam import (DiscreteKeys, GaussianConditional, GaussianMixture,
HybridBayesNet, HybridValues, noiseModel)


class TestHybridBayesNet(GtsamTestCase):
"""Unit tests for HybridValues."""
def test_evaluate(self):
"""Test evaluate for a hybrid Bayes net P(X0|X1) P(X1|Asia) P(Asia)."""
asiaKey = A(0)
Asia = (asiaKey, 2)

# Create the continuous conditional
I_1x1 = np.eye(1)
gc = GaussianConditional.FromMeanAndStddev(X(0), 2 * I_1x1, X(1), [-4],
5.0)

# Create the noise models
model0 = noiseModel.Diagonal.Sigmas([2.0])
model1 = noiseModel.Diagonal.Sigmas([3.0])

# Create the conditionals
conditional0 = GaussianConditional(X(1), [5], I_1x1, model0)
conditional1 = GaussianConditional(X(1), [2], I_1x1, model1)
dkeys = DiscreteKeys()
dkeys.push_back(Asia)
gm = GaussianMixture.FromConditionals([X(1)], [], dkeys,
[conditional0, conditional1]) #

# Create hybrid Bayes net.
bayesNet = HybridBayesNet()
bayesNet.addGaussian(gc)
bayesNet.addMixture(gm)
bayesNet.addDiscrete(Asia, "99/1")

# Create values at which to evaluate.
values = HybridValues()
values.insert(asiaKey, 0)
values.insert(X(0), [-6])
values.insert(X(1), [1])

conditionalProbability = gc.evaluate(values.continuous())
mixtureProbability = conditional0.evaluate(values.continuous())
self.assertAlmostEqual(conditionalProbability * mixtureProbability *
0.99,
bayesNet.evaluate(values),
places=5)


if __name__ == "__main__":
unittest.main()
15 changes: 6 additions & 9 deletions python/gtsam/tests/test_HybridFactorGraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,19 @@
"""
# pylint: disable=invalid-name, no-name-in-module, no-member

from __future__ import print_function

import unittest

import gtsam
import numpy as np
from gtsam.symbol_shorthand import C, X
from gtsam.utils.test_case import GtsamTestCase

import gtsam


class TestHybridGaussianFactorGraph(GtsamTestCase):
"""Unit tests for HybridGaussianFactorGraph."""

def test_create(self):
"""Test contruction of hybrid factor graph."""
"""Test construction of hybrid factor graph."""
noiseModel = gtsam.noiseModel.Unit.Create(3)
dk = gtsam.DiscreteKeys()
dk.push_back((C(0), 2))
Expand All @@ -45,7 +43,6 @@ def test_create(self):
gtsam.Ordering.ColamdConstrainedLastHybridGaussianFactorGraph(
hfg, [C(0)]))

# print("hbn = ", hbn)
self.assertEqual(hbn.size(), 2)

mixture = hbn.at(0).inner()
Expand All @@ -56,7 +53,7 @@ def test_create(self):
self.assertIsInstance(discrete_conditional, gtsam.DiscreteConditional)

def test_optimize(self):
"""Test contruction of hybrid factor graph."""
"""Test construction of hybrid factor graph."""
noiseModel = gtsam.noiseModel.Unit.Create(3)
dk = gtsam.DiscreteKeys()
dk.push_back((C(0), 2))
Expand All @@ -73,16 +70,16 @@ def test_optimize(self):
hfg.add(jf2)
hfg.push_back(gmf)

dtf = gtsam.DecisionTreeFactor([(C(0), 2)],"0 1")
dtf = gtsam.DecisionTreeFactor([(C(0), 2)], "0 1")
hfg.add(dtf)

hbn = hfg.eliminateSequential(
gtsam.Ordering.ColamdConstrainedLastHybridGaussianFactorGraph(
hfg, [C(0)]))

# print("hbn = ", hbn)
hv = hbn.optimize()
self.assertEqual(hv.atDiscrete(C(0)), 1)


if __name__ == "__main__":
unittest.main()
15 changes: 8 additions & 7 deletions python/gtsam/tests/test_HybridValues.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
GTSAM Copyright 2010-2019, Georgia Tech Research Corporation,
GTSAM Copyright 2010-2022, Georgia Tech Research Corporation,
Atlanta, Georgia 30332-0415
All Rights Reserved
Expand All @@ -20,22 +20,23 @@
from gtsam.utils.test_case import GtsamTestCase


class TestHybridGaussianFactorGraph(GtsamTestCase):
class TestHybridValues(GtsamTestCase):
"""Unit tests for HybridValues."""

def test_basic(self):
"""Test contruction and basic methods of hybrid values."""
"""Test construction and basic methods of hybrid values."""

hv1 = gtsam.HybridValues()
hv1.insert(X(0), np.ones((3,1)))
hv1.insert(X(0), np.ones((3, 1)))
hv1.insert(C(0), 2)

hv2 = gtsam.HybridValues()
hv2.insert(C(0), 2)
hv2.insert(X(0), np.ones((3,1)))
hv2.insert(X(0), np.ones((3, 1)))

self.assertEqual(hv1.atDiscrete(C(0)), 2)
self.assertEqual(hv1.at(X(0))[0], np.ones((3,1))[0])
self.assertEqual(hv1.at(X(0))[0], np.ones((3, 1))[0])


if __name__ == "__main__":
unittest.main()

0 comments on commit 706a8a4

Please sign in to comment.