From 001e66a49ae2cbd0b49a7c2ed0b73eae8ab07379 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Thu, 22 Aug 2024 16:27:51 +0100 Subject: [PATCH] feat: Bring in the pure-python renderer from guppy (#1462) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Closes #1407. Ports @mark-koch's rendering code from guppy, since it got deleted from there in the last hugr builder update. Adds a `Hugr.render_dot` and a `Hugr.store_dot` method. I'm using `syrupy`—a snapshot pytest extension—for testing it. I'm not sure if we want to be this strict with the generated output though. --- .pre-commit-config.yaml | 10 +- hugr-py/pyproject.toml | 1 + hugr-py/src/hugr/hugr.py | 40 + hugr-py/src/hugr/render.py | 279 ++++ .../tests/__snapshots__/test_hugr_build.ambr | 1471 +++++++++++++++++ hugr-py/tests/conftest.py | 21 +- hugr-py/tests/test_hugr_build.py | 41 +- justfile | 4 + poetry.lock | 33 +- pyproject.toml | 1 + 10 files changed, 1875 insertions(+), 26 deletions(-) create mode 100644 hugr-py/src/hugr/render.py create mode 100644 hugr-py/tests/__snapshots__/test_hugr_build.ambr diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ea8003724..15a0a18e3 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,16 +14,18 @@ repos: exclude: | (?x)^( specification/schema/.*| - .*.snap| - .*.snap.new| + .*\.snap| + .*\.snap\.new| + .*\.ambr| .release-please-manifest.json )$ - id: trailing-whitespace exclude: | (?x)^( specification/schema/.*| - .*.snap| - .*.snap.new + .*\.snap| + .*\.ambr| + .*\.snap\.new )$ - id: fix-byte-order-marker - id: mixed-line-ending diff --git a/hugr-py/pyproject.toml b/hugr-py/pyproject.toml index 95eccd0e8..1773c78fe 100644 --- a/hugr-py/pyproject.toml +++ b/hugr-py/pyproject.toml @@ -28,6 +28,7 @@ python = ">=3.10" pydantic = ">=2.7,<2.9" pydantic-extra-types = "^2.9.0" semver = "^3.0.2" +graphviz = { version = "^0.20.3" } [tool.poetry.group.docs] optional = true diff --git a/hugr-py/src/hugr/hugr.py b/hugr-py/src/hugr/hugr.py index 7ebf9a6a8..7631c5e90 100644 --- a/hugr-py/src/hugr/hugr.py +++ b/hugr-py/src/hugr/hugr.py @@ -35,6 +35,8 @@ from .exceptions import ParentBeforeChild if TYPE_CHECKING: + import graphviz as gv # type: ignore[import-untyped] + from hugr import ext from hugr.val import Value @@ -143,6 +145,14 @@ def nodes(self) -> Iterable[tuple[Node, NodeData]]: """Iterator over nodes of the hugr and their data.""" return self.items() + def links(self) -> Iterator[tuple[OutPort, InPort]]: + """Iterator over all the links in the HUGR. + + Returns: + Iterator of pairs of outgoing port and the incoming ports. + """ + return ((src.port, tgt.port) for src, tgt in self._links.items()) + def children(self, node: ToNode | None = None) -> list[Node]: """The child nodes of a given `node`. @@ -683,3 +693,33 @@ def load_json(cls, json_str: str) -> Hugr: json_dict = json.loads(json_str) serial = SerialHugr.load_json(json_dict) return cls.from_serial(serial) + + def render_dot(self, palette: str | None = None) -> gv.Digraph: + """Render the HUGR to a graphviz Digraph. + + Args: + palette: The palette to use for rendering. See :obj:`PALETTE` for the + included options. + + Returns: + The graphviz Digraph. + """ + from .render import DotRenderer + + return DotRenderer(palette).render(self) + + def store_dot( + self, filename: str, format: str = "svg", palette: str | None = None + ) -> None: + """Render the HUGR to a graphviz dot file. + + Args: + filename: The file to render to. + format: The format used for rendering ('pdf', 'png', etc.). + Defaults to SVG. + palette: The palette to use for rendering. See :obj:`PALETTE` for the + included options. + """ + from .render import DotRenderer + + DotRenderer(palette).store(self, filename=filename, format=format) diff --git a/hugr-py/src/hugr/render.py b/hugr-py/src/hugr/render.py new file mode 100644 index 000000000..2c1aea758 --- /dev/null +++ b/hugr-py/src/hugr/render.py @@ -0,0 +1,279 @@ +"""Visualise HUGR using graphviz.""" + +from collections.abc import Iterable +from dataclasses import dataclass + +import graphviz as gv # type: ignore[import-untyped] +from graphviz import Digraph +from typing_extensions import assert_never + +from hugr.hugr import Hugr +from hugr.tys import CFKind, ConstKind, FunctionKind, Kind, OrderKind, ValueKind + +from .node_port import InPort, Node, OutPort + + +@dataclass(frozen=True) +class Palette: + """A set of colours used for rendering.""" + + background: str + node: str + edge: str + dark: str + const: str + discard: str + node_border: str + port_border: str + + +PALETTE: dict[str, Palette] = { + "default": Palette( + background="white", + node="#ACCBF9", + edge="#1CADE4", + dark="black", + const="#77CEEF", + discard="#ff8888", + node_border="white", + port_border="#1CADE4", + ), + "nb": Palette( + background="white", + node="#7952B3", + edge="#FFC107", + dark="#343A40", + const="#7c55b4", + discard="#ff8888", + node_border="#9d80c7", + port_border="#ffd966", + ), + "zx": Palette( + background="white", + node="#629DD1", + edge="#297FD5", + dark="#112D4E", + const="#a1eea1", + discard="#ff8888", + node_border="#D8F8D8", + port_border="#E8A5A5", + ), +} + + +class DotRenderer: + """Render a HUGR to a graphviz dot file. + + Args: + palette: The palette to use for rendering. See :obj:`PALETTE` for the + included options. + """ + + palette: Palette + + def __init__(self, palette: Palette | str | None = None) -> None: + if palette is None: + palette = "default" + if isinstance(palette, str): + palette = PALETTE[palette] + self.palette = palette + + def render(self, hugr: Hugr) -> Digraph: + """Render a HUGR to a graphviz dot object.""" + graph_attr = { + "rankdir": "", + "ranksep": "0.1", + "nodesep": "0.15", + "margin": "0", + "bgcolor": self.palette.background, + } + if not (name := hugr[hugr.root].metadata.get("name", None)): + name = "" + + graph = gv.Digraph(name, strict=False) + graph.attr(**graph_attr) + + self._viz_node(hugr.root, hugr, graph) + + for src_port, tgt_port in hugr.links(): + kind = hugr.port_kind(src_port) + self._viz_link(src_port, tgt_port, kind, graph) + + return graph + + def store(self, hugr: Hugr, filename: str, format: str = "svg") -> None: + """Render a HUGR and save it to a file. + + Args: + hugr: The HUGR to render. + filename: Filename for saving the rendered graph. + format: The format used for rendering ('pdf', 'png', etc.). + Defaults to SVG. + """ + gv_graph = self.render(hugr) + gv_graph.render(filename, format=format) + + _FONTFACE = "monospace" + + _HTML_LABEL_TEMPLATE = """ + + {inputs_row} + + + + {outputs_row} +
+ + +
{node_label}{node_data}
+
+ """ + + _HTML_PORTS_ROW_TEMPLATE = """ + + + + + {port_cells} + +
+ + + """ + + _HTML_PORT_TEMPLATE = ( + '' + '{port}' + "" + ) + + _INPUT_PREFIX = "in." + _OUTPUT_PREFIX = "out." + + def _format_html_label(self, **kwargs: str) -> str: + _HTML_LABEL_DEFAULTS = { + "label_color": self.palette.dark, + "node_back_color": self.palette.node, + "inputs_row": "", + "outputs_row": "", + "border_colour": self.palette.port_border, + "border_width": "1", + "fontface": self._FONTFACE, + "fontsize": 11.0, + } + return self._HTML_LABEL_TEMPLATE.format(**{**_HTML_LABEL_DEFAULTS, **kwargs}) + + def _html_ports(self, ports: Iterable[str], id_prefix: str) -> str: + return self._HTML_PORTS_ROW_TEMPLATE.format( + port_cells="".join( + self._HTML_PORT_TEMPLATE.format( + port=port, + # differentiate input and output node identifiers + # with a prefix + port_id=id_prefix + port, + back_colour=self.palette.background, + font_colour=self.palette.dark, + border_width="1", + border_colour=self.palette.port_border, + fontface=self._FONTFACE, + ) + for port in ports + ) + ) + + def _in_port_name(self, p: InPort) -> str: + return f"{p.node.idx}:{self._INPUT_PREFIX}{p.offset}" + + def _out_port_name(self, p: OutPort) -> str: + return f"{p.node.idx}:{self._OUTPUT_PREFIX}{p.offset}" + + def _in_order_name(self, n: Node) -> str: + return f"{n.idx}:{self._INPUT_PREFIX}None" + + def _out_order_name(self, n: Node) -> str: + return f"{n.idx}:{self._OUTPUT_PREFIX}None" + + def _viz_node(self, node: Node, hugr: Hugr, graph: Digraph) -> None: + """Render a (possibly nested) node to a graphviz graph.""" + meta = hugr[node].metadata + if len(meta) > 0: + data = "

" + "
".join( + f"{key}: {value}" for key, value in meta.items() + ) + else: + data = "" + + in_ports = [str(i) for i in range(hugr.num_in_ports(node))] + out_ports = [str(i) for i in range(hugr.num_out_ports(node))] + inputs_row = ( + self._html_ports(in_ports, self._INPUT_PREFIX) if len(in_ports) > 0 else "" + ) + outputs_row = ( + self._html_ports(out_ports, self._OUTPUT_PREFIX) + if len(out_ports) > 0 + else "" + ) + + op = hugr[node].op + + if hugr.children(node): + with graph.subgraph(name=f"cluster{node.idx}") as sub: + for child in hugr.children(node): + self._viz_node(child, hugr, sub) + html_label = self._format_html_label( + node_back_color=self.palette.edge, + node_label=str(op), + node_data=data, + border_colour=self.palette.port_border, + inputs_row=inputs_row, + outputs_row=outputs_row, + ) + sub.node(f"{node.idx}", shape="plain", label=f"<{html_label}>") + sub.attr(label="", margin="10", color=self.palette.edge) + else: + html_label = self._format_html_label( + node_back_color=self.palette.node, + node_label=str(op), + node_data=data, + inputs_row=inputs_row, + outputs_row=outputs_row, + border_colour=self.palette.background, + ) + graph.node(f"{node.idx}", label=f"<{html_label}>", shape="plain") + + def _viz_link( + self, src_port: OutPort, tgt_port: InPort, kind: Kind, graph: Digraph + ) -> None: + edge_attr = { + "penwidth": "1.5", + "arrowhead": "none", + "arrowsize": "1.0", + "fontname": self._FONTFACE, + "fontsize": "9", + "fontcolor": "black", + } + + label = "" + match kind: + case ValueKind(ty): + label = str(ty) + color = self.palette.edge + case OrderKind(): + color = self.palette.dark + case ConstKind() | FunctionKind(): + color = self.palette.const + case CFKind(): + color = self.palette.dark + case _: + assert_never(kind) + + graph.edge( + self._out_port_name(src_port), + self._in_port_name(tgt_port), + label=label, + color=color, + **edge_attr, + ) diff --git a/hugr-py/tests/__snapshots__/test_hugr_build.ambr b/hugr-py/tests/__snapshots__/test_hugr_build.ambr new file mode 100644 index 000000000..dcd50978e --- /dev/null +++ b/hugr-py/tests/__snapshots__/test_hugr_build.ambr @@ -0,0 +1,1471 @@ +# serializer version: 1 +# name: test_add_op + ''' + digraph { + bgcolor=white margin=0 nodesep=0.15 rankdir="" ranksep=0.1 + subgraph cluster0 { + 1 [label=< + + + + + + + + + + +
+ + +
Input(types=[Bool])
+
+ + + + +
0
+
+ > shape=plain] + 2 [label=< + + + + + + + + + + +
+ + + + +
0
+
+ + +
Output()
+
+ > shape=plain] + 3 [label=< + + + + + + + + + + + + + + +
+ + + + +
0
+
+ + +
_NotOp()
+
+ + + + +
0
+
+ > shape=plain] + 0 [label=< + + + + + + + + + + +
+ + +
DFG(inputs=[Bool])
+
+ + + + +
0
+
+ > shape=plain] + color="#1CADE4" label="" margin=10 + } + 1:"out.0" -> 3:"in.0" [label=Bool arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] + 3:"out.0" -> 2:"in.0" [label=Bool arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] + } + + ''' +# --- +# name: test_build_inter_graph + ''' + digraph { + bgcolor=white margin=0 nodesep=0.15 rankdir="" ranksep=0.1 + subgraph cluster0 { + 1 [label=< + + + + + + + + + + +
+ + +
Input(types=[Bool, Bool])
+
+ + + + +
01
+
+ > shape=plain] + 2 [label=< + + + + + + + + + + +
+ + + + +
01
+
+ + +
Output()
+
+ > shape=plain] + subgraph cluster3 { + 4 [label=< + + + + + + +
+ + +
Input(types=[])
+
+ > shape=plain] + 5 [label=< + + + + + + + + + + +
+ + + + +
0
+
+ + +
Output()
+
+ > shape=plain] + 6 [label=< + + + + + + + + + + + + + + +
+ + + + +
0
+
+ + +
_NotOp()
+
+ + + + +
0
+
+ > shape=plain] + 3 [label=< + + + + + + + + + + +
+ + +
DFG(inputs=[])
+
+ + + + +
0
+
+ > shape=plain] + color="#1CADE4" label="" margin=10 + } + 0 [label=< + + + + + + + + + + +
+ + +
DFG(inputs=[Bool, Bool])
+
+ + + + +
01
+
+ > shape=plain] + color="#1CADE4" label="" margin=10 + } + 1:"out.-1" -> 3:"in.-1" [label="" arrowhead=none arrowsize=1.0 color=black fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] + 1:"out.0" -> 6:"in.0" [label=Bool arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] + 6:"out.0" -> 5:"in.0" [label=Bool arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] + 3:"out.0" -> 2:"in.0" [label=Bool arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] + 1:"out.1" -> 2:"in.1" [label=Bool arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] + } + + ''' +# --- +# name: test_build_nested + ''' + digraph { + bgcolor=white margin=0 nodesep=0.15 rankdir="" ranksep=0.1 + subgraph cluster0 { + 1 [label=< + + + + + + + + + + +
+ + +
Input(types=[Bool])
+
+ + + + +
0
+
+ > shape=plain] + 2 [label=< + + + + + + + + + + +
+ + + + +
0
+
+ + +
Output()
+
+ > shape=plain] + subgraph cluster3 { + 4 [label=< + + + + + + + + + + +
+ + +
Input(types=[Bool])
+
+ + + + +
0
+
+ > shape=plain] + 5 [label=< + + + + + + + + + + +
+ + + + +
0
+
+ + +
Output()
+
+ > shape=plain] + 6 [label=< + + + + + + + + + + + + + + +
+ + + + +
0
+
+ + +
_NotOp()
+
+ + + + +
0
+
+ > shape=plain] + 3 [label=< + + + + + + + + + + + + + + +
+ + + + +
0
+
+ + +
DFG(inputs=[Bool])
+
+ + + + +
0
+
+ > shape=plain] + color="#1CADE4" label="" margin=10 + } + 0 [label=< + + + + + + + + + + +
+ + +
DFG(inputs=[Bool])
+
+ + + + +
0
+
+ > shape=plain] + color="#1CADE4" label="" margin=10 + } + 1:"out.0" -> 3:"in.0" [label=Bool arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] + 4:"out.0" -> 6:"in.0" [label=Bool arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] + 6:"out.0" -> 5:"in.0" [label=Bool arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] + 3:"out.0" -> 2:"in.0" [label=Bool arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] + } + + ''' +# --- +# name: test_insert_nested + ''' + digraph { + bgcolor=white margin=0 nodesep=0.15 rankdir="" ranksep=0.1 + subgraph cluster0 { + 1 [label=< + + + + + + + + + + +
+ + +
Input(types=[Bool])
+
+ + + + +
0
+
+ > shape=plain] + 2 [label=< + + + + + + + + + + +
+ + + + +
0
+
+ + +
Output()
+
+ > shape=plain] + subgraph cluster3 { + 4 [label=< + + + + + + + + + + +
+ + +
Input(types=[Bool])
+
+ + + + +
0
+
+ > shape=plain] + 5 [label=< + + + + + + + + + + +
+ + + + +
0
+
+ + +
Output()
+
+ > shape=plain] + 6 [label=< + + + + + + + + + + + + + + +
+ + + + +
0
+
+ + +
_NotOp()
+
+ + + + +
0
+
+ > shape=plain] + 3 [label=< + + + + + + + + + + + + + + +
+ + + + +
0
+
+ + +
DFG(inputs=[Bool])
+
+ + + + +
0
+
+ > shape=plain] + color="#1CADE4" label="" margin=10 + } + 0 [label=< + + + + + + + + + + +
+ + +
DFG(inputs=[Bool])
+
+ + + + +
0
+
+ > shape=plain] + color="#1CADE4" label="" margin=10 + } + 4:"out.0" -> 6:"in.0" [label=Bool arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] + 6:"out.0" -> 5:"in.0" [label=Bool arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] + 1:"out.0" -> 3:"in.0" [label=Bool arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] + 3:"out.0" -> 2:"in.0" [label=Bool arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] + } + + ''' +# --- +# name: test_metadata + ''' + digraph simple_id { + bgcolor=white margin=0 nodesep=0.15 rankdir="" ranksep=0.1 + subgraph cluster0 { + 1 [label=< + + + + + + + + + + +
+ + +
Input(types=[Bool])
+
+ + + + +
0
+
+ > shape=plain] + 2 [label=< + + + + + + + + + + +
+ + + + +
0
+
+ + +
Output()
+
+ > shape=plain] + 3 [label=< + + + + + + + + + + + + + + +
+ + + + +
0
+
+ + +
_NotOp()

name: not
+
+ + + + +
0
+
+ > shape=plain] + 0 [label=< + + + + + + + + + + +
+ + +
DFG(inputs=[Bool])

name: simple_id
+
+ + + + +
0
+
+ > shape=plain] + color="#1CADE4" label="" margin=10 + } + 1:"out.0" -> 3:"in.0" [label=Bool arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] + 3:"out.0" -> 2:"in.0" [label=Bool arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] + } + + ''' +# --- +# name: test_multi_out + ''' + digraph { + bgcolor=white margin=0 nodesep=0.15 rankdir="" ranksep=0.1 + subgraph cluster0 { + 1 [label=< + + + + + + + + + + +
+ + +
Input(types=[ExtType(type_def=TypeDef(name='int', description='integral value of a given bit width', params=[BoundedNatParam(upper_bound=7)], bound=ExplicitBound(bound=)), args=[BoundedNatArg(n=5)]), ExtType(type_def=TypeDef(name='int', description='integral value of a given bit width', params=[BoundedNatParam(upper_bound=7)], bound=ExplicitBound(bound=)), args=[BoundedNatArg(n=5)])])
+
+ + + + +
01
+
+ > shape=plain] + 2 [label=< + + + + + + + + + + +
+ + + + +
01
+
+ + +
Output()
+
+ > shape=plain] + 3 [label=< + + + + + + + + + + + + + + +
+ + + + +
01
+
+ + +
_DivModDef(width=5)
+
+ + + + +
01
+
+ > shape=plain] + 0 [label=< + + + + + + + + + + +
+ + +
DFG(inputs=[ExtType(type_def=TypeDef(name='int', description='integral value of a given bit width', params=[BoundedNatParam(upper_bound=7)], bound=ExplicitBound(bound=)), args=[BoundedNatArg(n=5)]), ExtType(type_def=TypeDef(name='int', description='integral value of a given bit width', params=[BoundedNatParam(upper_bound=7)], bound=ExplicitBound(bound=)), args=[BoundedNatArg(n=5)])])
+
+ + + + +
01
+
+ > shape=plain] + color="#1CADE4" label="" margin=10 + } + 1:"out.0" -> 3:"in.0" [label="ExtType(type_def=TypeDef(name='int', description='integral value of a given bit width', params=[BoundedNatParam(upper_bound=7)], bound=ExplicitBound(bound=)), args=[BoundedNatArg(n=5)])" arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] + 1:"out.1" -> 3:"in.1" [label="ExtType(type_def=TypeDef(name='int', description='integral value of a given bit width', params=[BoundedNatParam(upper_bound=7)], bound=ExplicitBound(bound=)), args=[BoundedNatArg(n=5)])" arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] + 3:"out.0" -> 2:"in.0" [label="ExtType(type_def=TypeDef(name='int', description='integral value of a given bit width', params=[BoundedNatParam(upper_bound=7)], bound=ExplicitBound(bound=)), args=[BoundedNatArg(n=5)])" arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] + 3:"out.1" -> 2:"in.1" [label="ExtType(type_def=TypeDef(name='int', description='integral value of a given bit width', params=[BoundedNatParam(upper_bound=7)], bound=ExplicitBound(bound=)), args=[BoundedNatArg(n=5)])" arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] + } + + ''' +# --- +# name: test_multiport + ''' + digraph { + bgcolor=white margin=0 nodesep=0.15 rankdir="" ranksep=0.1 + subgraph cluster0 { + 1 [label=< + + + + + + + + + + +
+ + +
Input(types=[Bool])
+
+ + + + +
0
+
+ > shape=plain] + 2 [label=< + + + + + + + + + + +
+ + + + +
01
+
+ + +
Output()
+
+ > shape=plain] + 0 [label=< + + + + + + + + + + +
+ + +
DFG(inputs=[Bool])
+
+ + + + +
01
+
+ > shape=plain] + color="#1CADE4" label="" margin=10 + } + 1:"out.0" -> 2:"in.0" [label=Bool arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] + 1:"out.0" -> 2:"in.1" [label=Bool arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] + } + + ''' +# --- +# name: test_recursive_function + ''' + digraph { + bgcolor=white margin=0 nodesep=0.15 rankdir="" ranksep=0.1 + subgraph cluster0 { + subgraph cluster1 { + 2 [label=< + + + + + + + + + + +
+ + +
Input(types=[Qubit])
+
+ + + + +
0
+
+ > shape=plain] + 3 [label=< + + + + + + + + + + +
+ + + + +
0
+
+ + +
Output()
+
+ > shape=plain] + 4 [label=< + + + + + + + + + + + + + + +
+ + + + +
01
+
+ + +
Call(signature=PolyFuncType(params=[], body=FunctionType([Qubit], [Qubit])), instantiation=FunctionType([Qubit], [Qubit]), type_args=[])
+
+ + + + +
0
+
+ > shape=plain] + 1 [label=< + + + + + + + + + + +
+ + +
FuncDefn(name='recurse', inputs=[Qubit], params=[])
+
+ + + + +
0
+
+ > shape=plain] + color="#1CADE4" label="" margin=10 + } + 0 [label=< + + + + + + +
+ + +
Module()
+
+ > shape=plain] + color="#1CADE4" label="" margin=10 + } + 1:"out.0" -> 4:"in.1" [label="" arrowhead=none arrowsize=1.0 color="#77CEEF" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] + 2:"out.0" -> 4:"in.0" [label=Qubit arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] + 4:"out.0" -> 3:"in.0" [label=Qubit arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] + } + + ''' +# --- +# name: test_simple_id + ''' + digraph { + bgcolor=white margin=0 nodesep=0.15 rankdir="" ranksep=0.1 + subgraph cluster0 { + 1 [label=< + + + + + + + + + + +
+ + +
Input(types=[Qubit, Qubit])
+
+ + + + +
01
+
+ > shape=plain] + 2 [label=< + + + + + + + + + + +
+ + + + +
01
+
+ + +
Output()
+
+ > shape=plain] + 0 [label=< + + + + + + + + + + +
+ + +
DFG(inputs=[Qubit, Qubit])
+
+ + + + +
01
+
+ > shape=plain] + color="#1CADE4" label="" margin=10 + } + 1:"out.0" -> 2:"in.0" [label=Qubit arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] + 1:"out.1" -> 2:"in.1" [label=Qubit arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] + } + + ''' +# --- +# name: test_tuple + ''' + digraph { + bgcolor=white margin=0 nodesep=0.15 rankdir="" ranksep=0.1 + subgraph cluster0 { + 1 [label=< + + + + + + + + + + +
+ + +
Input(types=[Bool, Qubit])
+
+ + + + +
01
+
+ > shape=plain] + 2 [label=< + + + + + + + + + + +
+ + + + +
01
+
+ + +
Output()
+
+ > shape=plain] + 3 [label=< + + + + + + + + + + + + + + +
+ + + + +
01
+
+ + +
MakeTuple([Bool, Qubit])
+
+ + + + +
0
+
+ > shape=plain] + 4 [label=< + + + + + + + + + + + + + + +
+ + + + +
0
+
+ + +
UnpackTuple()
+
+ + + + +
01
+
+ > shape=plain] + 0 [label=< + + + + + + + + + + +
+ + +
DFG(inputs=[Bool, Qubit])
+
+ + + + +
01
+
+ > shape=plain] + color="#1CADE4" label="" margin=10 + } + 1:"out.0" -> 3:"in.0" [label=Bool arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] + 1:"out.1" -> 3:"in.1" [label=Qubit arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] + 3:"out.0" -> 4:"in.0" [label="Tuple(Bool, Qubit)" arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] + 4:"out.0" -> 2:"in.0" [label=Bool arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] + 4:"out.1" -> 2:"in.1" [label=Qubit arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] + } + + ''' +# --- diff --git a/hugr-py/tests/conftest.py b/hugr-py/tests/conftest.py index af74710d3..daf1f72a2 100644 --- a/hugr-py/tests/conftest.py +++ b/hugr-py/tests/conftest.py @@ -17,6 +17,8 @@ from hugr.std.float import FLOAT_T if TYPE_CHECKING: + from syrupy.assertion import SnapshotAssertion + from hugr.ops import ComWire QUANTUM_EXT = ext.Extension("pytest.quantum,", ext.Version(0, 1, 0)) @@ -130,7 +132,20 @@ def mermaid(h: Hugr): _run_hugr_cmd(h.to_serial().to_json(), cmd) -def validate(h: Hugr | ext.Package, roundtrip: bool = True): +def validate( + h: Hugr | ext.Package, + *, + roundtrip: bool = True, + snap: SnapshotAssertion | None = None, +): + """Validate a HUGR or package. + + args: + h: The HUGR or package to validate. + roundtrip: Whether to roundtrip the HUGR through the CLI. + snapshot: A hugr render snapshot. If not None, it will be compared against the + rendered HUGR. Pass `--snapshot-update` to pytest to update the snapshot file. + """ cmd = [*_base_command(), "validate", "-"] serial = h.to_json() _run_hugr_cmd(serial, cmd) @@ -146,6 +161,10 @@ def validate(h: Hugr | ext.Package, roundtrip: bool = True): roundtrip_json = json.loads(h2.to_serial().to_json()) assert roundtrip_json == starting_json + if snap is not None: + dot = h.render_dot() + assert snap == dot.source + def _run_hugr_cmd(serial: str, cmd: list[str]): try: diff --git a/hugr-py/tests/test_hugr_build.py b/hugr-py/tests/test_hugr_build.py index cac950127..b9f2d997e 100644 --- a/hugr-py/tests/test_hugr_build.py +++ b/hugr-py/tests/test_hugr_build.py @@ -65,11 +65,12 @@ def simple_id() -> Dfg: return h -def test_simple_id(): - validate(simple_id().hugr) +def test_simple_id(snapshot): + hugr = simple_id().hugr + validate(hugr, snap=snapshot) -def test_metadata(): +def test_metadata(snapshot): h = Dfg(tys.Bool) h.metadata["name"] = "simple_id" @@ -77,10 +78,10 @@ def test_metadata(): b = h.add_op(Not, b, metadata={"name": "not"}) h.set_outputs(b) - validate(h.hugr) + validate(h.hugr, snap=snapshot) -def test_multiport(): +def test_multiport(snapshot): h = Dfg(tys.Bool) (a,) = h.inputs() h.set_outputs(a, a) @@ -100,19 +101,19 @@ def test_multiport(): ] assert list(h.hugr.linked_ports(ou_n.inp(0))) == [in_n.out(0)] - validate(h.hugr) + validate(h.hugr, snap=snapshot) -def test_add_op(): +def test_add_op(snapshot): h = Dfg(tys.Bool) (a,) = h.inputs() nt = h.add_op(Not, a) h.set_outputs(nt) - validate(h.hugr) + validate(h.hugr, snap=snapshot) -def test_tuple(): +def test_tuple(snapshot): row = [tys.Bool, tys.Qubit] h = Dfg(*row) a, b = h.inputs() @@ -120,7 +121,7 @@ def test_tuple(): a, b = h.add(ops.UnpackTuple()(t)) h.set_outputs(a, b) - validate(h.hugr) + validate(h.hugr, snap=snapshot) h1 = Dfg(*row) a, b = h1.inputs() @@ -131,12 +132,12 @@ def test_tuple(): assert h.hugr.to_serial() == h1.hugr.to_serial() -def test_multi_out(): +def test_multi_out(snapshot): h = Dfg(INT_T, INT_T) a, b = h.inputs() a, b = h.add(DivMod(a, b)) h.set_outputs(a, b) - validate(h.hugr) + validate(h.hugr, snap=snapshot) def test_insert(): @@ -152,7 +153,7 @@ def test_insert(): assert mapping == {new_h.root: Node(4)} -def test_insert_nested(): +def test_insert_nested(snapshot): h1 = Dfg(tys.Bool) (a1,) = h1.inputs() nt = h1.add(Not(a1)) @@ -163,10 +164,10 @@ def test_insert_nested(): nested = h.insert_nested(h1, a) h.set_outputs(nested) assert len(h.hugr.children(nested)) == 3 - validate(h.hugr) + validate(h.hugr, snap=snapshot) -def test_build_nested(): +def test_build_nested(snapshot): h = Dfg(tys.Bool) (a,) = h.inputs() @@ -178,10 +179,10 @@ def test_build_nested(): assert len(h.hugr.children(nested)) == 3 h.set_outputs(nested) - validate(h.hugr) + validate(h.hugr, snap=snapshot) -def test_build_inter_graph(): +def test_build_inter_graph(snapshot): h = Dfg(tys.Bool, tys.Bool) (a, b) = h.inputs() with h.add_nested() as nested: @@ -190,7 +191,7 @@ def test_build_inter_graph(): h.set_outputs(nested, b) - validate(h.hugr) + validate(h.hugr, snap=snapshot) assert _SubPort(h.input_node.out(-1)) in h.hugr._links assert h.hugr.num_outgoing(h.input_node) == 2 # doesn't count state order @@ -275,7 +276,7 @@ def test_mono_function(direct_call: bool) -> None: validate(mod.hugr) -def test_recursive_function() -> None: +def test_recursive_function(snapshot) -> None: mod = Module() f_recursive = mod.define_function("recurse", [tys.Qubit]) @@ -283,7 +284,7 @@ def test_recursive_function() -> None: call = f_recursive.call(f_recursive, f_recursive.input_node[0]) f_recursive.set_outputs(call) - validate(mod.hugr) + validate(mod.hugr, snap=snapshot) def test_invalid_recursive_function() -> None: diff --git a/justfile b/justfile index 2f8beaec6..91564d2b4 100644 --- a/justfile +++ b/justfile @@ -55,6 +55,10 @@ update-schema: poetry update poetry run python scripts/generate_schema.py specification/schema/ +# Update snapshots used in the pytest tests. +update-pytest-snapshots: + poetry run pytest --snapshot-update + # Generate serialized declarations for the standard extensions and prelude. gen-extensions: cargo run -p hugr-cli gen-extensions -o specification/std_extensions diff --git a/poetry.lock b/poetry.lock index 01c758543..ae9b5fa16 100644 --- a/poetry.lock +++ b/poetry.lock @@ -161,6 +161,22 @@ docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1 testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-asyncio (>=0.21)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)", "virtualenv (>=20.26.2)"] typing = ["typing-extensions (>=4.8)"] +[[package]] +name = "graphviz" +version = "0.20.3" +description = "Simple Python interface for Graphviz" +optional = false +python-versions = ">=3.8" +files = [ + {file = "graphviz-0.20.3-py3-none-any.whl", hash = "sha256:81f848f2904515d8cd359cc611faba817598d2feaac4027b266aa3eda7b3dde5"}, + {file = "graphviz-0.20.3.zip", hash = "sha256:09d6bc81e6a9fa392e7ba52135a9d49f1ed62526f96499325930e87ca1b5925d"}, +] + +[package.extras] +dev = ["flake8", "pep8-naming", "tox (>=3)", "twine", "wheel"] +docs = ["sphinx (>=5,<7)", "sphinx-autodoc-typehints", "sphinx-rtd-theme"] +test = ["coverage", "pytest (>=7,<8.1)", "pytest-cov", "pytest-mock (>=3)"] + [[package]] name = "hugr" version = "0.7.0" @@ -171,6 +187,7 @@ files = [] develop = true [package.dependencies] +graphviz = "^0.20.3" pydantic = ">=2.7,<2.9" pydantic-extra-types = "^2.9.0" semver = "^3.0.2" @@ -618,6 +635,20 @@ files = [ {file = "semver-3.0.2.tar.gz", hash = "sha256:6253adb39c70f6e51afed2fa7152bcd414c411286088fb4b9effb133885ab4cc"}, ] +[[package]] +name = "syrupy" +version = "4.6.4" +description = "Pytest Snapshot Test Utility" +optional = false +python-versions = ">=3.8.1" +files = [ + {file = "syrupy-4.6.4-py3-none-any.whl", hash = "sha256:5a0e47b187d32b58555b0de6d25bc7bb875e7d60c7a41bd2721f5d44975dcf85"}, + {file = "syrupy-4.6.4.tar.gz", hash = "sha256:a6facc6a45f1cff598adacb030d9573ed62863521755abd5c5d6d665f848d6cc"}, +] + +[package.dependencies] +pytest = ">=7.0.0,<9.0.0" + [[package]] name = "toml" version = "0.10.2" @@ -674,4 +705,4 @@ test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "93e7b398cf89bd222858475275e6df051fb1ddb7b0f192bb2d8dc3c07f2c9268" +content-hash = "be6d8e498af1ca1a8034d11ae10b8eb0905b25359cfba1ec1ea916aafef973c7" diff --git a/pyproject.toml b/pyproject.toml index a8d7bbd1c..585ec98bb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,7 @@ pytest-cov = "^5.0.0" mypy = "^1.9.0" ruff = "^0.6.1" toml = "^0.10.0" +syrupy = "^4.6.4" [tool.poetry.group.hugr.dependencies] hugr = { path = "hugr-py", develop = true }