Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for positioning in BayesNet plotting #1070

Merged
merged 10 commits into from
Jan 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 18 additions & 13 deletions gtsam/discrete/DiscreteBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,12 @@

namespace gtsam {

/** A Bayes net made from discrete conditional distributions. */
class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional>
{
public:

/**
* A Bayes net made from discrete conditional distributions.
* @addtogroup discrete
*/
class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> {
public:
typedef BayesNet<DiscreteConditional> Base;
typedef DiscreteBayesNet This;
typedef DiscreteConditional ConditionalType;
Expand All @@ -49,16 +50,20 @@ namespace gtsam {
DiscreteBayesNet() {}

/** Construct from iterator over conditionals */
template<typename ITERATOR>
DiscreteBayesNet(ITERATOR firstConditional, ITERATOR lastConditional) : Base(firstConditional, lastConditional) {}
template <typename ITERATOR>
DiscreteBayesNet(ITERATOR firstConditional, ITERATOR lastConditional)
: Base(firstConditional, lastConditional) {}

/** Construct from container of factors (shared_ptr or plain objects) */
template<class CONTAINER>
explicit DiscreteBayesNet(const CONTAINER& conditionals) : Base(conditionals) {}

/** Implicit copy/downcast constructor to override explicit template container constructor */
template<class DERIVEDCONDITIONAL>
DiscreteBayesNet(const FactorGraph<DERIVEDCONDITIONAL>& graph) : Base(graph) {}
template <class CONTAINER>
explicit DiscreteBayesNet(const CONTAINER& conditionals)
: Base(conditionals) {}

/** Implicit copy/downcast constructor to override explicit template
* container constructor */
template <class DERIVEDCONDITIONAL>
DiscreteBayesNet(const FactorGraph<DERIVEDCONDITIONAL>& graph)
: Base(graph) {}

/// Destructor
virtual ~DiscreteBayesNet() {}
Expand Down
29 changes: 17 additions & 12 deletions gtsam/discrete/discrete.i
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DiscreteConditional& other, double tol = 1e-9) const;
gtsam::Key firstFrontalKey() const;
size_t nrFrontals() const;
size_t nrParents() const;
void printSignature(
Expand Down Expand Up @@ -156,13 +157,17 @@ class DiscreteBayesNet {
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DiscreteBayesNet& other, double tol = 1e-9) const;
string dot(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
void saveGraph(string s, const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
double operator()(const gtsam::DiscreteValues& values) const;
gtsam::DiscreteValues sample() const;
gtsam::DiscreteValues sample(gtsam::DiscreteValues given) const;

string dot(
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
void saveGraph(
string s,
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
string markdown(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
string markdown(const gtsam::KeyFormatter& keyFormatter,
Expand Down Expand Up @@ -252,14 +257,6 @@ class DiscreteFactorGraph {
void print(string s = "") const;
bool equals(const gtsam::DiscreteFactorGraph& fg, double tol = 1e-9) const;

string dot(
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& dotWriter = gtsam::DotWriter()) const;
void saveGraph(
string s,
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& dotWriter = gtsam::DotWriter()) const;

gtsam::DecisionTreeFactor product() const;
double operator()(const gtsam::DiscreteValues& values) const;
gtsam::DiscreteValues optimize() const;
Expand All @@ -281,6 +278,14 @@ class DiscreteFactorGraph {
std::pair<gtsam::DiscreteBayesTree, gtsam::DiscreteFactorGraph>
eliminatePartialMultifrontal(const gtsam::Ordering& ordering);

string dot(
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
void saveGraph(
string s,
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;

string markdown(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
string markdown(const gtsam::KeyFormatter& keyFormatter,
Expand Down
19 changes: 14 additions & 5 deletions gtsam/discrete/tests/testDiscreteBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,12 +150,21 @@ TEST(DiscreteBayesNet, Dot) {
fragment.add((Either | Tuberculosis, LungCancer) = "F T T T");

string actual = fragment.dot();
cout << actual << endl;
EXPECT(actual ==
"digraph G{\n"
"0->3\n"
"4->6\n"
"3->5\n"
"6->5\n"
"digraph {\n"
" size=\"5,5\";\n"
"\n"
" var0[label=\"0\"];\n"
" var3[label=\"3\"];\n"
" var4[label=\"4\"];\n"
" var5[label=\"5\"];\n"
" var6[label=\"6\"];\n"
"\n"
" var3->var5\n"
" var6->var5\n"
" var4->var6\n"
" var0->var3\n"
"}");
}

Expand Down
44 changes: 28 additions & 16 deletions gtsam/inference/BayesNet-inst.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,41 +10,51 @@
* -------------------------------------------------------------------------- */

/**
* @file BayesNet.h
* @brief Bayes network
* @author Frank Dellaert
* @author Richard Roberts
*/
* @file BayesNet.h
* @brief Bayes network
* @author Frank Dellaert
* @author Richard Roberts
*/

#pragma once

#include <gtsam/inference/FactorGraph-inst.h>
#include <gtsam/inference/BayesNet.h>
#include <gtsam/inference/FactorGraph-inst.h>

#include <boost/range/adaptor/reversed.hpp>
#include <fstream>
#include <string>

namespace gtsam {

/* ************************************************************************* */
template <class CONDITIONAL>
void BayesNet<CONDITIONAL>::print(
const std::string& s, const KeyFormatter& formatter) const {
void BayesNet<CONDITIONAL>::print(const std::string& s,
const KeyFormatter& formatter) const {
Base::print(s, formatter);
}

/* ************************************************************************* */
template <class CONDITIONAL>
void BayesNet<CONDITIONAL>::dot(std::ostream& os,
const KeyFormatter& keyFormatter) const {
os << "digraph G{\n";
const KeyFormatter& keyFormatter,
const DotWriter& writer) const {
writer.digraphPreamble(&os);

// Create nodes for each variable in the graph
for (Key key : this->keys()) {
auto position = writer.variablePos(key);
writer.drawVariable(key, keyFormatter, position, &os);
}
os << "\n";

for (auto conditional : *this) {
// Reverse order as typically Bayes nets stored in reverse topological sort.
for (auto conditional : boost::adaptors::reverse(*this)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe a comment here why reversing the order? Like following a top-down order in printing?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is because Bayes nets are typically specified in reverse topological sort order, that's how they come out of elimination. Will add comments in the next PR to not restart CI for this one thing

auto frontals = conditional->frontals();
const Key me = frontals.front();
auto parents = conditional->parents();
for (const Key& p : parents)
os << keyFormatter(p) << "->" << keyFormatter(me) << "\n";
os << " var" << keyFormatter(p) << "->var" << keyFormatter(me) << "\n";
}

os << "}";
Expand All @@ -53,18 +63,20 @@ void BayesNet<CONDITIONAL>::dot(std::ostream& os,

/* ************************************************************************* */
template <class CONDITIONAL>
std::string BayesNet<CONDITIONAL>::dot(const KeyFormatter& keyFormatter) const {
std::string BayesNet<CONDITIONAL>::dot(const KeyFormatter& keyFormatter,
const DotWriter& writer) const {
std::stringstream ss;
dot(ss, keyFormatter);
dot(ss, keyFormatter, writer);
return ss.str();
}

/* ************************************************************************* */
template <class CONDITIONAL>
void BayesNet<CONDITIONAL>::saveGraph(const std::string& filename,
const KeyFormatter& keyFormatter) const {
const KeyFormatter& keyFormatter,
const DotWriter& writer) const {
std::ofstream of(filename.c_str());
dot(of, keyFormatter);
dot(of, keyFormatter, writer);
of.close();
}

Expand Down
104 changes: 53 additions & 51 deletions gtsam/inference/BayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,77 +10,79 @@
* -------------------------------------------------------------------------- */

/**
* @file BayesNet.h
* @brief Bayes network
* @author Frank Dellaert
* @author Richard Roberts
*/
* @file BayesNet.h
* @brief Bayes network
* @author Frank Dellaert
* @author Richard Roberts
*/

#pragma once

#include <boost/shared_ptr.hpp>

#include <gtsam/inference/FactorGraph.h>

namespace gtsam {
#include <boost/shared_ptr.hpp>
#include <string>

/**
* A BayesNet is a tree of conditionals, stored in elimination order.
*
* todo: how to handle Bayes nets with an optimize function? Currently using global functions.
* \nosubgrouping
*/
template<class CONDITIONAL>
class BayesNet : public FactorGraph<CONDITIONAL> {
namespace gtsam {

private:
/**
* A BayesNet is a tree of conditionals, stored in elimination order.
* @addtogroup inference
*/
template <class CONDITIONAL>
class BayesNet : public FactorGraph<CONDITIONAL> {
private:
typedef FactorGraph<CONDITIONAL> Base;

typedef FactorGraph<CONDITIONAL> Base;
public:
typedef typename boost::shared_ptr<CONDITIONAL>
sharedConditional; ///< A shared pointer to a conditional

public:
typedef typename boost::shared_ptr<CONDITIONAL> sharedConditional; ///< A shared pointer to a conditional
protected:
/// @name Standard Constructors
/// @{

protected:
/// @name Standard Constructors
/// @{
/** Default constructor as an empty BayesNet */
BayesNet() {}

/** Default constructor as an empty BayesNet */
BayesNet() {};
/** Construct from iterator over conditionals */
template <typename ITERATOR>
BayesNet(ITERATOR firstConditional, ITERATOR lastConditional)
: Base(firstConditional, lastConditional) {}

/** Construct from iterator over conditionals */
template<typename ITERATOR>
BayesNet(ITERATOR firstConditional, ITERATOR lastConditional) : Base(firstConditional, lastConditional) {}
/// @}

/// @}
public:
/// @name Testable
/// @{

public:
/// @name Testable
/// @{
/** print out graph */
void print(
const std::string& s = "BayesNet",
const KeyFormatter& formatter = DefaultKeyFormatter) const override;

/** print out graph */
void print(
const std::string& s = "BayesNet",
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
/// @}

/// @}
/// @name Graph Display
/// @{

/// @name Graph Display
/// @{
/// Output to graphviz format, stream version.
void dot(std::ostream& os,
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DotWriter& writer = DotWriter()) const;

/// Output to graphviz format, stream version.
void dot(std::ostream& os, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
/// Output to graphviz format string.
std::string dot(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DotWriter& writer = DotWriter()) const;

/// Output to graphviz format string.
std::string dot(
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
/// output to file with graphviz format.
void saveGraph(const std::string& filename,
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DotWriter& writer = DotWriter()) const;

/// output to file with graphviz format.
void saveGraph(const std::string& filename,
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;

/// @}
};
/// @}
};

}
} // namespace gtsam

#include <gtsam/inference/BayesNet-inst.h>
Loading