Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rename DiscretePrior -> DiscreteDistribution #1039

Merged
merged 1 commit into from
Jan 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions gtsam/discrete/DiscreteBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
#pragma once

#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/DiscretePrior.h>
#include <gtsam/discrete/DiscreteDistribution.h>
#include <gtsam/inference/BayesNet.h>
#include <gtsam/inference/FactorGraph.h>

Expand Down Expand Up @@ -79,9 +79,9 @@ namespace gtsam {
// Add inherited versions of add.
using Base::add;

/** Add a DiscretePrior using a table or a string */
/** Add a DiscreteDistribution using a table or a string */
void add(const DiscreteKey& key, const std::string& spec) {
emplace_shared<DiscretePrior>(key, spec);
emplace_shared<DiscreteDistribution>(key, spec);
}

/** Add a DiscreteCondtional */
Expand Down
2 changes: 1 addition & 1 deletion gtsam/discrete/DiscreteConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class GTSAM_EXPORT DiscreteConditional
const std::string& spec)
: DiscreteConditional(Signature(key, parents, spec)) {}

/// No-parent specialization; can also use DiscretePrior.
/// No-parent specialization; can also use DiscreteDistribution.
DiscreteConditional(const DiscreteKey& key, const std::string& spec)
: DiscreteConditional(Signature(key, {}, spec)) {}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,23 @@
* -------------------------------------------------------------------------- */

/**
* @file DiscretePrior.cpp
* @file DiscreteDistribution.cpp
* @date December 2021
* @author Frank Dellaert
*/

#include <gtsam/discrete/DiscretePrior.h>
#include <gtsam/discrete/DiscreteDistribution.h>

#include <vector>

namespace gtsam {

void DiscretePrior::print(const std::string& s,
const KeyFormatter& formatter) const {
void DiscreteDistribution::print(const std::string& s,
const KeyFormatter& formatter) const {
Base::print(s, formatter);
}

double DiscretePrior::operator()(size_t value) const {
double DiscreteDistribution::operator()(size_t value) const {
if (nrFrontals() != 1)
throw std::invalid_argument(
"Single value operator can only be invoked on single-variable "
Expand All @@ -34,10 +36,10 @@ double DiscretePrior::operator()(size_t value) const {
return Base::operator()(values);
}

std::vector<double> DiscretePrior::pmf() const {
std::vector<double> DiscreteDistribution::pmf() const {
if (nrFrontals() != 1)
throw std::invalid_argument(
"DiscretePrior::pmf only defined for single-variable priors");
"DiscreteDistribution::pmf only defined for single-variable priors");
const size_t nrValues = cardinalities_.at(keys_[0]);
std::vector<double> array;
array.reserve(nrValues);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
* -------------------------------------------------------------------------- */

/**
* @file DiscretePrior.h
* @file DiscreteDistribution.h
* @date December 2021
* @author Frank Dellaert
*/
Expand All @@ -20,50 +20,52 @@
#include <gtsam/discrete/DiscreteConditional.h>

#include <string>
#include <vector>

namespace gtsam {

/**
* A prior probability on a set of discrete variables.
* Derives from DiscreteConditional
*/
class GTSAM_EXPORT DiscretePrior : public DiscreteConditional {
class GTSAM_EXPORT DiscreteDistribution : public DiscreteConditional {
public:
using Base = DiscreteConditional;

/// @name Standard Constructors
/// @{

/// Default constructor needed for serialization.
DiscretePrior() {}
DiscreteDistribution() {}

/// Constructor from factor.
DiscretePrior(const DecisionTreeFactor& f) : Base(f.size(), f) {}
explicit DiscreteDistribution(const DecisionTreeFactor& f)
: Base(f.size(), f) {}

/**
* Construct from a Signature.
*
* Example: DiscretePrior P(D % "3/2");
* Example: DiscreteDistribution P(D % "3/2");
*/
DiscretePrior(const Signature& s) : Base(s) {}
explicit DiscreteDistribution(const Signature& s) : Base(s) {}

/**
* Construct from key and a vector of floats specifying the probability mass
* function (PMF).
*
* Example: DiscretePrior P(D, {0.4, 0.6});
* Example: DiscreteDistribution P(D, {0.4, 0.6});
*/
DiscretePrior(const DiscreteKey& key, const std::vector<double>& spec)
: DiscretePrior(Signature(key, {}, Signature::Table{spec})) {}
DiscreteDistribution(const DiscreteKey& key, const std::vector<double>& spec)
: DiscreteDistribution(Signature(key, {}, Signature::Table{spec})) {}

/**
* Construct from key and a string specifying the probability mass function
* (PMF).
*
* Example: DiscretePrior P(D, "9/1 2/8 3/7 1/9");
* Example: DiscreteDistribution P(D, "9/1 2/8 3/7 1/9");
*/
DiscretePrior(const DiscreteKey& key, const std::string& spec)
: DiscretePrior(Signature(key, {}, spec)) {}
DiscreteDistribution(const DiscreteKey& key, const std::string& spec)
: DiscreteDistribution(Signature(key, {}, spec)) {}

/// @}
/// @name Testable
Expand Down Expand Up @@ -102,10 +104,10 @@ class GTSAM_EXPORT DiscretePrior : public DiscreteConditional {

/// @}
};
// DiscretePrior
// DiscreteDistribution

// traits
template <>
struct traits<DiscretePrior> : public Testable<DiscretePrior> {};
struct traits<DiscreteDistribution> : public Testable<DiscreteDistribution> {};

} // namespace gtsam
12 changes: 6 additions & 6 deletions gtsam/discrete/discrete.i
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,12 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
std::map<gtsam::Key, std::vector<std::string>> names) const;
};

#include <gtsam/discrete/DiscretePrior.h>
virtual class DiscretePrior : gtsam::DiscreteConditional {
DiscretePrior();
DiscretePrior(const gtsam::DecisionTreeFactor& f);
DiscretePrior(const gtsam::DiscreteKey& key, string spec);
DiscretePrior(const gtsam::DiscreteKey& key, std::vector<double> spec);
#include <gtsam/discrete/DiscreteDistribution.h>
virtual class DiscreteDistribution : gtsam::DiscreteConditional {
DiscreteDistribution();
DiscreteDistribution(const gtsam::DecisionTreeFactor& f);
DiscreteDistribution(const gtsam::DiscreteKey& key, string spec);
DiscreteDistribution(const gtsam::DiscreteKey& key, std::vector<double> spec);
void print(string s = "Discrete Prior\n",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
Expand Down
6 changes: 3 additions & 3 deletions gtsam/discrete/tests/testDecisionTreeFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/Testable.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscretePrior.h>
#include <gtsam/discrete/DiscreteDistribution.h>
#include <gtsam/discrete/Signature.h>

#include <boost/assign/std/map.hpp>
Expand Down Expand Up @@ -56,8 +56,8 @@ TEST( DecisionTreeFactor, constructors)
TEST(DecisionTreeFactor, multiplication) {
DiscreteKey v0(0, 2), v1(1, 2), v2(2, 2);

// Multiply with a DiscretePrior, i.e., Bayes Law!
DiscretePrior prior(v1 % "1/3");
// Multiply with a DiscreteDistribution, i.e., Bayes Law!
DiscreteDistribution prior(v1 % "1/3");
DecisionTreeFactor f1(v0 & v1, "1 2 3 4");
DecisionTreeFactor expected(v0 & v1, "0.25 1.5 0.75 3");
CHECK(assert_equal(expected, static_cast<DecisionTreeFactor>(prior) * f1));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,42 +11,41 @@

/*
* @file testDiscretePrior.cpp
* @brief unit tests for DiscretePrior
* @brief unit tests for DiscreteDistribution
* @author Frank dellaert
* @date December 2021
*/

#include <CppUnitLite/TestHarness.h>
#include <gtsam/discrete/DiscretePrior.h>
#include <gtsam/discrete/DiscreteDistribution.h>
#include <gtsam/discrete/Signature.h>

using namespace std;
using namespace gtsam;

static const DiscreteKey X(0, 2);

/* ************************************************************************* */
TEST(DiscretePrior, constructors) {
TEST(DiscreteDistribution, constructors) {
DecisionTreeFactor f(X, "0.4 0.6");
DiscretePrior expected(f);
DiscreteDistribution expected(f);

DiscretePrior actual(X % "2/3");
DiscreteDistribution actual(X % "2/3");
EXPECT_LONGS_EQUAL(1, actual.nrFrontals());
EXPECT_LONGS_EQUAL(0, actual.nrParents());
EXPECT(assert_equal(expected, actual, 1e-9));

const vector<double> pmf{0.4, 0.6};
DiscretePrior actual2(X, pmf);
const std::vector<double> pmf{0.4, 0.6};
DiscreteDistribution actual2(X, pmf);
EXPECT_LONGS_EQUAL(1, actual2.nrFrontals());
EXPECT_LONGS_EQUAL(0, actual2.nrParents());
EXPECT(assert_equal(expected, actual2, 1e-9));
}

/* ************************************************************************* */
TEST(DiscretePrior, Multiply) {
TEST(DiscreteDistribution, Multiply) {
DiscreteKey A(0, 2), B(1, 2);
DiscreteConditional conditional(A | B = "1/2 2/1");
DiscretePrior prior(B, "1/2");
DiscreteDistribution prior(B, "1/2");
DiscreteConditional actual = prior * conditional; // P(A|B) * P(B)

EXPECT_LONGS_EQUAL(2, actual.nrFrontals()); // = P(A,B)
Expand All @@ -56,22 +55,22 @@ TEST(DiscretePrior, Multiply) {
}

/* ************************************************************************* */
TEST(DiscretePrior, operator) {
DiscretePrior prior(X % "2/3");
TEST(DiscreteDistribution, operator) {
DiscreteDistribution prior(X % "2/3");
EXPECT_DOUBLES_EQUAL(prior(0), 0.4, 1e-9);
EXPECT_DOUBLES_EQUAL(prior(1), 0.6, 1e-9);
}

/* ************************************************************************* */
TEST(DiscretePrior, pmf) {
DiscretePrior prior(X % "2/3");
vector<double> expected {0.4, 0.6};
EXPECT(prior.pmf() == expected);
TEST(DiscreteDistribution, pmf) {
DiscreteDistribution prior(X % "2/3");
std::vector<double> expected{0.4, 0.6};
EXPECT(prior.pmf() == expected);
}

/* ************************************************************************* */
TEST(DiscretePrior, sample) {
DiscretePrior prior(X % "2/3");
TEST(DiscreteDistribution, sample) {
DiscreteDistribution prior(X % "2/3");
prior.sample();
}

Expand Down
6 changes: 3 additions & 3 deletions python/gtsam/tests/test_DecisionTreeFactor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import unittest

from gtsam import DecisionTreeFactor, DiscreteValues, DiscretePrior, Ordering
from gtsam import DecisionTreeFactor, DiscreteValues, DiscreteDistribution, Ordering
from gtsam.utils.test_case import GtsamTestCase


Expand All @@ -36,8 +36,8 @@ def test_multiplication(self):
v1 = (1, 2)
v2 = (2, 2)

# Multiply with a DiscretePrior, i.e., Bayes Law!
prior = DiscretePrior(v1, [1, 3])
# Multiply with a DiscreteDistribution, i.e., Bayes Law!
prior = DiscreteDistribution(v1, [1, 3])
f1 = DecisionTreeFactor([v0, v1], "1 2 3 4")
expected = DecisionTreeFactor([v0, v1], "0.25 1.5 0.75 3")
self.gtsamAssertEquals(DecisionTreeFactor(prior) * f1, expected)
Expand Down
4 changes: 2 additions & 2 deletions python/gtsam/tests/test_DiscreteBayesNet.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import unittest

from gtsam import (DiscreteBayesNet, DiscreteConditional, DiscreteFactorGraph,
DiscreteKeys, DiscretePrior, DiscreteValues, Ordering)
DiscreteKeys, DiscreteDistribution, DiscreteValues, Ordering)
from gtsam.utils.test_case import GtsamTestCase


Expand Down Expand Up @@ -74,7 +74,7 @@ def test_Asia(self):
for j in range(8):
ordering.push_back(j)
chordal = fg.eliminateSequential(ordering)
expected2 = DiscretePrior(Bronchitis, "11/9")
expected2 = DiscreteDistribution(Bronchitis, "11/9")
self.gtsamAssertEquals(chordal.at(7), expected2)

# solve
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import unittest

import numpy as np
from gtsam import DecisionTreeFactor, DiscreteKeys, DiscretePrior
from gtsam import DecisionTreeFactor, DiscreteKeys, DiscreteDistribution
from gtsam.utils.test_case import GtsamTestCase

X = 0, 2
Expand All @@ -28,33 +28,33 @@ def test_constructor(self):
keys = DiscreteKeys()
keys.push_back(X)
f = DecisionTreeFactor(keys, "0.4 0.6")
expected = DiscretePrior(f)
actual = DiscretePrior(X, "2/3")
expected = DiscreteDistribution(f)

actual = DiscreteDistribution(X, "2/3")
self.gtsamAssertEquals(actual, expected)
actual2 = DiscretePrior(X, [0.4, 0.6])

actual2 = DiscreteDistribution(X, [0.4, 0.6])
self.gtsamAssertEquals(actual2, expected)

def test_operator(self):
prior = DiscretePrior(X, "2/3")
prior = DiscreteDistribution(X, "2/3")
self.assertAlmostEqual(prior(0), 0.4)
self.assertAlmostEqual(prior(1), 0.6)

def test_pmf(self):
prior = DiscretePrior(X, "2/3")
prior = DiscreteDistribution(X, "2/3")
expected = np.array([0.4, 0.6])
np.testing.assert_allclose(expected, prior.pmf())

def test_sample(self):
prior = DiscretePrior(X, "2/3")
prior = DiscreteDistribution(X, "2/3")
actual = prior.sample()
self.assertIsInstance(actual, int)

def test_markdown(self):
"""Test the _repr_markdown_ method."""

prior = DiscretePrior(X, "2/3")
prior = DiscreteDistribution(X, "2/3")
expected = " *P(0):*\n\n" \
"|0|value|\n" \
"|:-:|:-:|\n" \
Expand Down