diff --git a/pyomo/util/graph.py b/pyomo/util/graph.py index 0f802accc96..80d3db0b909 100644 --- a/pyomo/util/graph.py +++ b/pyomo/util/graph.py @@ -49,9 +49,16 @@ def graph_from_pyomo(m: _BlockData, def plot_pyomo_model(m: _BlockData, include_objective: bool = True, active: bool = True, - plot_title: Optional[str] = None): + plot_title: Optional[str] = None, + bipartite_plot: bool = False, + show_plot: bool = True): graph = graph_from_pyomo(m, include_objective=include_objective, active=active) - pos_dict = nx.drawing.spring_layout(graph, seed=0) + if bipartite_plot: + left_nodes = [c for c in OrderedSet(m.component_data_objects(pe.Constraint, descend_into=True, active=active))] + left_nodes.extend(_CompNode(obj) for obj in ComponentSet(m.component_data_objects(pe.Objective, descend_into=True, active=active))) + pos_dict = nx.drawing.bipartite_layout(graph, nodes=left_nodes) + else: + pos_dict = nx.drawing.spring_layout(graph, seed=0) edge_x = list() edge_y = list() @@ -99,4 +106,5 @@ def plot_pyomo_model(m: _BlockData, fig = go.Figure(data=[edge_trace, node_trace]) if plot_title is not None: fig.update_layout(title=dict(text=plot_title)) - fig.show() + if show_plot: # this option is mostly for unit tests + fig.show() diff --git a/pyomo/util/tests/test_graph.py b/pyomo/util/tests/test_graph.py new file mode 100644 index 00000000000..c9bc7d306cf --- /dev/null +++ b/pyomo/util/tests/test_graph.py @@ -0,0 +1,26 @@ +import pyomo.environ as pe +from pyomo.common import unittest +from pyomo.common.dependencies import attempt_import +nx, nx_available = attempt_import('networkx') +plotly, plotly_available = attempt_import('plotly') + + +@unittest.skipUnless(nx_available, 'plot_pyomo_model requires networkx') +@unittest.skipUnless(plotly_available, 'plot_pyomo_model requires plotly') +class TestPlotPyomoModel(unittest.TestCase): + def test_plot_pyomo_model(self): + """ + Unfortunately, this test only ensures the code runs without errors. + It does not test for correctness. + """ + m = pe.ConcreteModel() + m.x = pe.Var(bounds=(-1, 1)) + m.y = pe.Var() + m.z = pe.Var() + m.obj = pe.Objective(expr=m.y**2 + m.z**2) + m.c1 = pe.Constraint(expr=m.y == 2*m.x + 1) + m.c2 = pe.Constraint(expr=m.z >= m.x) + from pyomo.util.graph import plot_pyomo_model + plot_pyomo_model(m, plot_title='test plot', bipartite_plot=False, show_plot=False) + plot_pyomo_model(m, plot_title='test plot', bipartite_plot=True, show_plot=False) +