From 001e66a49ae2cbd0b49a7c2ed0b73eae8ab07379 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Agust=C3=ADn=20Borgna?=
<121866228+aborgna-q@users.noreply.github.com>
Date: Thu, 22 Aug 2024 16:27:51 +0100
Subject: [PATCH] feat: Bring in the pure-python renderer from guppy (#1462)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Closes #1407.
Ports @mark-koch's rendering code from guppy, since it got deleted from
there in the last hugr builder update.
Adds a `Hugr.render_dot` and a `Hugr.store_dot` method.
I'm using `syrupy`—a snapshot pytest extension—for testing it. I'm not
sure if we want to be this strict with the generated output though.
---
.pre-commit-config.yaml | 10 +-
hugr-py/pyproject.toml | 1 +
hugr-py/src/hugr/hugr.py | 40 +
hugr-py/src/hugr/render.py | 279 ++++
.../tests/__snapshots__/test_hugr_build.ambr | 1471 +++++++++++++++++
hugr-py/tests/conftest.py | 21 +-
hugr-py/tests/test_hugr_build.py | 41 +-
justfile | 4 +
poetry.lock | 33 +-
pyproject.toml | 1 +
10 files changed, 1875 insertions(+), 26 deletions(-)
create mode 100644 hugr-py/src/hugr/render.py
create mode 100644 hugr-py/tests/__snapshots__/test_hugr_build.ambr
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index ea8003724..15a0a18e3 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -14,16 +14,18 @@ repos:
exclude: |
(?x)^(
specification/schema/.*|
- .*.snap|
- .*.snap.new|
+ .*\.snap|
+ .*\.snap\.new|
+ .*\.ambr|
.release-please-manifest.json
)$
- id: trailing-whitespace
exclude: |
(?x)^(
specification/schema/.*|
- .*.snap|
- .*.snap.new
+ .*\.snap|
+ .*\.ambr|
+ .*\.snap\.new
)$
- id: fix-byte-order-marker
- id: mixed-line-ending
diff --git a/hugr-py/pyproject.toml b/hugr-py/pyproject.toml
index 95eccd0e8..1773c78fe 100644
--- a/hugr-py/pyproject.toml
+++ b/hugr-py/pyproject.toml
@@ -28,6 +28,7 @@ python = ">=3.10"
pydantic = ">=2.7,<2.9"
pydantic-extra-types = "^2.9.0"
semver = "^3.0.2"
+graphviz = { version = "^0.20.3" }
[tool.poetry.group.docs]
optional = true
diff --git a/hugr-py/src/hugr/hugr.py b/hugr-py/src/hugr/hugr.py
index 7ebf9a6a8..7631c5e90 100644
--- a/hugr-py/src/hugr/hugr.py
+++ b/hugr-py/src/hugr/hugr.py
@@ -35,6 +35,8 @@
from .exceptions import ParentBeforeChild
if TYPE_CHECKING:
+ import graphviz as gv # type: ignore[import-untyped]
+
from hugr import ext
from hugr.val import Value
@@ -143,6 +145,14 @@ def nodes(self) -> Iterable[tuple[Node, NodeData]]:
"""Iterator over nodes of the hugr and their data."""
return self.items()
+ def links(self) -> Iterator[tuple[OutPort, InPort]]:
+ """Iterator over all the links in the HUGR.
+
+ Returns:
+ Iterator of pairs of outgoing port and the incoming ports.
+ """
+ return ((src.port, tgt.port) for src, tgt in self._links.items())
+
def children(self, node: ToNode | None = None) -> list[Node]:
"""The child nodes of a given `node`.
@@ -683,3 +693,33 @@ def load_json(cls, json_str: str) -> Hugr:
json_dict = json.loads(json_str)
serial = SerialHugr.load_json(json_dict)
return cls.from_serial(serial)
+
+ def render_dot(self, palette: str | None = None) -> gv.Digraph:
+ """Render the HUGR to a graphviz Digraph.
+
+ Args:
+ palette: The palette to use for rendering. See :obj:`PALETTE` for the
+ included options.
+
+ Returns:
+ The graphviz Digraph.
+ """
+ from .render import DotRenderer
+
+ return DotRenderer(palette).render(self)
+
+ def store_dot(
+ self, filename: str, format: str = "svg", palette: str | None = None
+ ) -> None:
+ """Render the HUGR to a graphviz dot file.
+
+ Args:
+ filename: The file to render to.
+ format: The format used for rendering ('pdf', 'png', etc.).
+ Defaults to SVG.
+ palette: The palette to use for rendering. See :obj:`PALETTE` for the
+ included options.
+ """
+ from .render import DotRenderer
+
+ DotRenderer(palette).store(self, filename=filename, format=format)
diff --git a/hugr-py/src/hugr/render.py b/hugr-py/src/hugr/render.py
new file mode 100644
index 000000000..2c1aea758
--- /dev/null
+++ b/hugr-py/src/hugr/render.py
@@ -0,0 +1,279 @@
+"""Visualise HUGR using graphviz."""
+
+from collections.abc import Iterable
+from dataclasses import dataclass
+
+import graphviz as gv # type: ignore[import-untyped]
+from graphviz import Digraph
+from typing_extensions import assert_never
+
+from hugr.hugr import Hugr
+from hugr.tys import CFKind, ConstKind, FunctionKind, Kind, OrderKind, ValueKind
+
+from .node_port import InPort, Node, OutPort
+
+
+@dataclass(frozen=True)
+class Palette:
+ """A set of colours used for rendering."""
+
+ background: str
+ node: str
+ edge: str
+ dark: str
+ const: str
+ discard: str
+ node_border: str
+ port_border: str
+
+
+PALETTE: dict[str, Palette] = {
+ "default": Palette(
+ background="white",
+ node="#ACCBF9",
+ edge="#1CADE4",
+ dark="black",
+ const="#77CEEF",
+ discard="#ff8888",
+ node_border="white",
+ port_border="#1CADE4",
+ ),
+ "nb": Palette(
+ background="white",
+ node="#7952B3",
+ edge="#FFC107",
+ dark="#343A40",
+ const="#7c55b4",
+ discard="#ff8888",
+ node_border="#9d80c7",
+ port_border="#ffd966",
+ ),
+ "zx": Palette(
+ background="white",
+ node="#629DD1",
+ edge="#297FD5",
+ dark="#112D4E",
+ const="#a1eea1",
+ discard="#ff8888",
+ node_border="#D8F8D8",
+ port_border="#E8A5A5",
+ ),
+}
+
+
+class DotRenderer:
+ """Render a HUGR to a graphviz dot file.
+
+ Args:
+ palette: The palette to use for rendering. See :obj:`PALETTE` for the
+ included options.
+ """
+
+ palette: Palette
+
+ def __init__(self, palette: Palette | str | None = None) -> None:
+ if palette is None:
+ palette = "default"
+ if isinstance(palette, str):
+ palette = PALETTE[palette]
+ self.palette = palette
+
+ def render(self, hugr: Hugr) -> Digraph:
+ """Render a HUGR to a graphviz dot object."""
+ graph_attr = {
+ "rankdir": "",
+ "ranksep": "0.1",
+ "nodesep": "0.15",
+ "margin": "0",
+ "bgcolor": self.palette.background,
+ }
+ if not (name := hugr[hugr.root].metadata.get("name", None)):
+ name = ""
+
+ graph = gv.Digraph(name, strict=False)
+ graph.attr(**graph_attr)
+
+ self._viz_node(hugr.root, hugr, graph)
+
+ for src_port, tgt_port in hugr.links():
+ kind = hugr.port_kind(src_port)
+ self._viz_link(src_port, tgt_port, kind, graph)
+
+ return graph
+
+ def store(self, hugr: Hugr, filename: str, format: str = "svg") -> None:
+ """Render a HUGR and save it to a file.
+
+ Args:
+ hugr: The HUGR to render.
+ filename: Filename for saving the rendered graph.
+ format: The format used for rendering ('pdf', 'png', etc.).
+ Defaults to SVG.
+ """
+ gv_graph = self.render(hugr)
+ gv_graph.render(filename, format=format)
+
+ _FONTFACE = "monospace"
+
+ _HTML_LABEL_TEMPLATE = """
+
+ {inputs_row}
+
+
+
+ {node_label}{node_data} |
+
+ |
+
+ {outputs_row}
+
+ """
+
+ _HTML_PORTS_ROW_TEMPLATE = """
+
+
+
+ |
+
+ """
+
+ _HTML_PORT_TEMPLATE = (
+ ''
+ '{port}'
+ " | "
+ )
+
+ _INPUT_PREFIX = "in."
+ _OUTPUT_PREFIX = "out."
+
+ def _format_html_label(self, **kwargs: str) -> str:
+ _HTML_LABEL_DEFAULTS = {
+ "label_color": self.palette.dark,
+ "node_back_color": self.palette.node,
+ "inputs_row": "",
+ "outputs_row": "",
+ "border_colour": self.palette.port_border,
+ "border_width": "1",
+ "fontface": self._FONTFACE,
+ "fontsize": 11.0,
+ }
+ return self._HTML_LABEL_TEMPLATE.format(**{**_HTML_LABEL_DEFAULTS, **kwargs})
+
+ def _html_ports(self, ports: Iterable[str], id_prefix: str) -> str:
+ return self._HTML_PORTS_ROW_TEMPLATE.format(
+ port_cells="".join(
+ self._HTML_PORT_TEMPLATE.format(
+ port=port,
+ # differentiate input and output node identifiers
+ # with a prefix
+ port_id=id_prefix + port,
+ back_colour=self.palette.background,
+ font_colour=self.palette.dark,
+ border_width="1",
+ border_colour=self.palette.port_border,
+ fontface=self._FONTFACE,
+ )
+ for port in ports
+ )
+ )
+
+ def _in_port_name(self, p: InPort) -> str:
+ return f"{p.node.idx}:{self._INPUT_PREFIX}{p.offset}"
+
+ def _out_port_name(self, p: OutPort) -> str:
+ return f"{p.node.idx}:{self._OUTPUT_PREFIX}{p.offset}"
+
+ def _in_order_name(self, n: Node) -> str:
+ return f"{n.idx}:{self._INPUT_PREFIX}None"
+
+ def _out_order_name(self, n: Node) -> str:
+ return f"{n.idx}:{self._OUTPUT_PREFIX}None"
+
+ def _viz_node(self, node: Node, hugr: Hugr, graph: Digraph) -> None:
+ """Render a (possibly nested) node to a graphviz graph."""
+ meta = hugr[node].metadata
+ if len(meta) > 0:
+ data = "
" + "
".join(
+ f"{key}: {value}" for key, value in meta.items()
+ )
+ else:
+ data = ""
+
+ in_ports = [str(i) for i in range(hugr.num_in_ports(node))]
+ out_ports = [str(i) for i in range(hugr.num_out_ports(node))]
+ inputs_row = (
+ self._html_ports(in_ports, self._INPUT_PREFIX) if len(in_ports) > 0 else ""
+ )
+ outputs_row = (
+ self._html_ports(out_ports, self._OUTPUT_PREFIX)
+ if len(out_ports) > 0
+ else ""
+ )
+
+ op = hugr[node].op
+
+ if hugr.children(node):
+ with graph.subgraph(name=f"cluster{node.idx}") as sub:
+ for child in hugr.children(node):
+ self._viz_node(child, hugr, sub)
+ html_label = self._format_html_label(
+ node_back_color=self.palette.edge,
+ node_label=str(op),
+ node_data=data,
+ border_colour=self.palette.port_border,
+ inputs_row=inputs_row,
+ outputs_row=outputs_row,
+ )
+ sub.node(f"{node.idx}", shape="plain", label=f"<{html_label}>")
+ sub.attr(label="", margin="10", color=self.palette.edge)
+ else:
+ html_label = self._format_html_label(
+ node_back_color=self.palette.node,
+ node_label=str(op),
+ node_data=data,
+ inputs_row=inputs_row,
+ outputs_row=outputs_row,
+ border_colour=self.palette.background,
+ )
+ graph.node(f"{node.idx}", label=f"<{html_label}>", shape="plain")
+
+ def _viz_link(
+ self, src_port: OutPort, tgt_port: InPort, kind: Kind, graph: Digraph
+ ) -> None:
+ edge_attr = {
+ "penwidth": "1.5",
+ "arrowhead": "none",
+ "arrowsize": "1.0",
+ "fontname": self._FONTFACE,
+ "fontsize": "9",
+ "fontcolor": "black",
+ }
+
+ label = ""
+ match kind:
+ case ValueKind(ty):
+ label = str(ty)
+ color = self.palette.edge
+ case OrderKind():
+ color = self.palette.dark
+ case ConstKind() | FunctionKind():
+ color = self.palette.const
+ case CFKind():
+ color = self.palette.dark
+ case _:
+ assert_never(kind)
+
+ graph.edge(
+ self._out_port_name(src_port),
+ self._in_port_name(tgt_port),
+ label=label,
+ color=color,
+ **edge_attr,
+ )
diff --git a/hugr-py/tests/__snapshots__/test_hugr_build.ambr b/hugr-py/tests/__snapshots__/test_hugr_build.ambr
new file mode 100644
index 000000000..dcd50978e
--- /dev/null
+++ b/hugr-py/tests/__snapshots__/test_hugr_build.ambr
@@ -0,0 +1,1471 @@
+# serializer version: 1
+# name: test_add_op
+ '''
+ digraph {
+ bgcolor=white margin=0 nodesep=0.15 rankdir="" ranksep=0.1
+ subgraph cluster0 {
+ 1 [label=<
+
+
+
+
+
+ |
+
+
+
+
+
+ |
+
+
+
+ > shape=plain]
+ 2 [label=<
+
+
+
+
+
+ |
+
+
+
+
+
+ |
+
+
+
+ > shape=plain]
+ 3 [label=<
+
+
+
+
+
+ |
+
+
+
+
+
+ |
+
+
+
+
+
+ |
+
+
+
+ > shape=plain]
+ 0 [label=<
+
+
+
+
+
+ |
+
+
+
+
+
+ |
+
+
+
+ > shape=plain]
+ color="#1CADE4" label="" margin=10
+ }
+ 1:"out.0" -> 3:"in.0" [label=Bool arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5]
+ 3:"out.0" -> 2:"in.0" [label=Bool arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5]
+ }
+
+ '''
+# ---
+# name: test_build_inter_graph
+ '''
+ digraph {
+ bgcolor=white margin=0 nodesep=0.15 rankdir="" ranksep=0.1
+ subgraph cluster0 {
+ 1 [label=<
+
+
+
+
+
+ Input(types=[Bool, Bool]) |
+
+ |
+
+
+
+
+
+ |
+
+
+
+ > shape=plain]
+ 2 [label=<
+
+
+
+
+
+ |
+
+
+
+
+
+ |
+
+
+
+ > shape=plain]
+ subgraph cluster3 {
+ 4 [label=<
+
+ > shape=plain]
+ 5 [label=<
+
+
+
+
+
+ |
+
+
+
+
+
+ |
+
+
+
+ > shape=plain]
+ 6 [label=<
+
+
+
+
+
+ |
+
+
+
+
+
+ |
+
+
+
+
+
+ |
+
+
+
+ > shape=plain]
+ 3 [label=<
+
+
+
+
+
+ |
+
+
+
+
+
+ |
+
+
+
+ > shape=plain]
+ color="#1CADE4" label="" margin=10
+ }
+ 0 [label=<
+
+
+
+
+
+ DFG(inputs=[Bool, Bool]) |
+
+ |
+
+
+
+
+
+ |
+
+
+
+ > shape=plain]
+ color="#1CADE4" label="" margin=10
+ }
+ 1:"out.-1" -> 3:"in.-1" [label="" arrowhead=none arrowsize=1.0 color=black fontcolor=black fontname=monospace fontsize=9 penwidth=1.5]
+ 1:"out.0" -> 6:"in.0" [label=Bool arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5]
+ 6:"out.0" -> 5:"in.0" [label=Bool arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5]
+ 3:"out.0" -> 2:"in.0" [label=Bool arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5]
+ 1:"out.1" -> 2:"in.1" [label=Bool arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5]
+ }
+
+ '''
+# ---
+# name: test_build_nested
+ '''
+ digraph {
+ bgcolor=white margin=0 nodesep=0.15 rankdir="" ranksep=0.1
+ subgraph cluster0 {
+ 1 [label=<
+
+
+
+
+
+ |
+
+
+
+
+
+ |
+
+
+
+ > shape=plain]
+ 2 [label=<
+
+
+
+
+
+ |
+
+
+
+
+
+ |
+
+
+
+ > shape=plain]
+ subgraph cluster3 {
+ 4 [label=<
+
+
+
+
+
+ |
+
+
+
+
+
+ |
+
+
+
+ > shape=plain]
+ 5 [label=<
+
+
+
+
+
+ |
+
+
+
+
+
+ |
+
+
+
+ > shape=plain]
+ 6 [label=<
+
+
+
+
+
+ |
+
+
+
+
+
+ |
+
+
+
+
+
+ |
+
+
+
+ > shape=plain]
+ 3 [label=<
+
+
+
+
+
+ |
+
+
+
+
+
+ |
+
+
+
+
+
+ |
+
+
+
+ > shape=plain]
+ color="#1CADE4" label="" margin=10
+ }
+ 0 [label=<
+
+
+
+
+
+ |
+
+
+
+
+
+ |
+
+
+
+ > shape=plain]
+ color="#1CADE4" label="" margin=10
+ }
+ 1:"out.0" -> 3:"in.0" [label=Bool arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5]
+ 4:"out.0" -> 6:"in.0" [label=Bool arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5]
+ 6:"out.0" -> 5:"in.0" [label=Bool arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5]
+ 3:"out.0" -> 2:"in.0" [label=Bool arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5]
+ }
+
+ '''
+# ---
+# name: test_insert_nested
+ '''
+ digraph {
+ bgcolor=white margin=0 nodesep=0.15 rankdir="" ranksep=0.1
+ subgraph cluster0 {
+ 1 [label=<
+
+
+
+
+
+ |
+
+
+
+
+
+ |
+
+
+
+ > shape=plain]
+ 2 [label=<
+
+
+
+
+
+ |
+
+
+
+
+
+ |
+
+
+
+ > shape=plain]
+ subgraph cluster3 {
+ 4 [label=<
+
+
+
+
+
+ |
+
+
+
+
+
+ |
+
+
+
+ > shape=plain]
+ 5 [label=<
+
+
+
+
+
+ |
+
+
+
+
+
+ |
+
+
+
+ > shape=plain]
+ 6 [label=<
+
+
+
+
+
+ |
+
+
+
+
+
+ |
+
+
+
+
+
+ |
+
+
+
+ > shape=plain]
+ 3 [label=<
+
+
+
+
+
+ |
+
+
+
+
+
+ |
+
+
+
+
+
+ |
+
+
+
+ > shape=plain]
+ color="#1CADE4" label="" margin=10
+ }
+ 0 [label=<
+
+
+
+
+
+ |
+
+
+
+
+
+ |
+
+
+
+ > shape=plain]
+ color="#1CADE4" label="" margin=10
+ }
+ 4:"out.0" -> 6:"in.0" [label=Bool arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5]
+ 6:"out.0" -> 5:"in.0" [label=Bool arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5]
+ 1:"out.0" -> 3:"in.0" [label=Bool arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5]
+ 3:"out.0" -> 2:"in.0" [label=Bool arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5]
+ }
+
+ '''
+# ---
+# name: test_metadata
+ '''
+ digraph simple_id {
+ bgcolor=white margin=0 nodesep=0.15 rankdir="" ranksep=0.1
+ subgraph cluster0 {
+ 1 [label=<
+
+
+
+
+
+ |
+
+
+
+
+
+ |
+
+
+
+ > shape=plain]
+ 2 [label=<
+
+
+
+
+
+ |
+
+
+
+
+
+ |
+
+
+
+ > shape=plain]
+ 3 [label=<
+
+
+
+
+
+ |
+
+
+
+
+
+ |
+
+
+
+
+
+ |
+
+
+
+ > shape=plain]
+ 0 [label=<
+
+
+
+
+
+ DFG(inputs=[Bool])
name: simple_id |
+
+ |
+
+
+
+
+
+ |
+
+
+
+ > shape=plain]
+ color="#1CADE4" label="" margin=10
+ }
+ 1:"out.0" -> 3:"in.0" [label=Bool arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5]
+ 3:"out.0" -> 2:"in.0" [label=Bool arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5]
+ }
+
+ '''
+# ---
+# name: test_multi_out
+ '''
+ digraph {
+ bgcolor=white margin=0 nodesep=0.15 rankdir="" ranksep=0.1
+ subgraph cluster0 {
+ 1 [label=<
+
+
+
+
+
+ Input(types=[ExtType(type_def=TypeDef(name='int', description='integral value of a given bit width', params=[BoundedNatParam(upper_bound=7)], bound=ExplicitBound(bound=)), args=[BoundedNatArg(n=5)]), ExtType(type_def=TypeDef(name='int', description='integral value of a given bit width', params=[BoundedNatParam(upper_bound=7)], bound=ExplicitBound(bound=)), args=[BoundedNatArg(n=5)])]) |
+
+ |
+
+
+
+
+
+ |
+
+
+
+ > shape=plain]
+ 2 [label=<
+
+
+
+
+
+ |
+
+
+
+
+
+ |
+
+
+
+ > shape=plain]
+ 3 [label=<
+
+
+
+
+
+ |
+
+
+
+
+
+ |
+
+
+
+
+
+ |
+
+
+
+ > shape=plain]
+ 0 [label=<
+
+
+
+
+
+ DFG(inputs=[ExtType(type_def=TypeDef(name='int', description='integral value of a given bit width', params=[BoundedNatParam(upper_bound=7)], bound=ExplicitBound(bound=)), args=[BoundedNatArg(n=5)]), ExtType(type_def=TypeDef(name='int', description='integral value of a given bit width', params=[BoundedNatParam(upper_bound=7)], bound=ExplicitBound(bound=)), args=[BoundedNatArg(n=5)])]) |
+
+ |
+
+
+
+
+
+ |
+
+
+
+ > shape=plain]
+ color="#1CADE4" label="" margin=10
+ }
+ 1:"out.0" -> 3:"in.0" [label="ExtType(type_def=TypeDef(name='int', description='integral value of a given bit width', params=[BoundedNatParam(upper_bound=7)], bound=ExplicitBound(bound=)), args=[BoundedNatArg(n=5)])" arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5]
+ 1:"out.1" -> 3:"in.1" [label="ExtType(type_def=TypeDef(name='int', description='integral value of a given bit width', params=[BoundedNatParam(upper_bound=7)], bound=ExplicitBound(bound=)), args=[BoundedNatArg(n=5)])" arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5]
+ 3:"out.0" -> 2:"in.0" [label="ExtType(type_def=TypeDef(name='int', description='integral value of a given bit width', params=[BoundedNatParam(upper_bound=7)], bound=ExplicitBound(bound=)), args=[BoundedNatArg(n=5)])" arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5]
+ 3:"out.1" -> 2:"in.1" [label="ExtType(type_def=TypeDef(name='int', description='integral value of a given bit width', params=[BoundedNatParam(upper_bound=7)], bound=ExplicitBound(bound=)), args=[BoundedNatArg(n=5)])" arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5]
+ }
+
+ '''
+# ---
+# name: test_multiport
+ '''
+ digraph {
+ bgcolor=white margin=0 nodesep=0.15 rankdir="" ranksep=0.1
+ subgraph cluster0 {
+ 1 [label=<
+
+
+
+
+
+ |
+
+
+
+
+
+ |
+
+
+
+ > shape=plain]
+ 2 [label=<
+
+
+
+
+
+ |
+
+
+
+
+
+ |
+
+
+
+ > shape=plain]
+ 0 [label=<
+
+
+
+
+
+ |
+
+
+
+
+
+ |
+
+
+
+ > shape=plain]
+ color="#1CADE4" label="" margin=10
+ }
+ 1:"out.0" -> 2:"in.0" [label=Bool arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5]
+ 1:"out.0" -> 2:"in.1" [label=Bool arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5]
+ }
+
+ '''
+# ---
+# name: test_recursive_function
+ '''
+ digraph {
+ bgcolor=white margin=0 nodesep=0.15 rankdir="" ranksep=0.1
+ subgraph cluster0 {
+ subgraph cluster1 {
+ 2 [label=<
+
+
+
+
+
+ |
+
+
+
+
+
+ |
+
+
+
+ > shape=plain]
+ 3 [label=<
+
+
+
+
+
+ |
+
+
+
+
+
+ |
+
+
+
+ > shape=plain]
+ 4 [label=<
+
+
+
+
+
+ |
+
+
+
+
+
+ Call(signature=PolyFuncType(params=[], body=FunctionType([Qubit], [Qubit])), instantiation=FunctionType([Qubit], [Qubit]), type_args=[]) |
+
+ |
+
+
+
+
+
+ |
+
+
+
+ > shape=plain]
+ 1 [label=<
+
+
+
+
+
+ FuncDefn(name='recurse', inputs=[Qubit], params=[]) |
+
+ |
+
+
+
+
+
+ |
+
+
+
+ > shape=plain]
+ color="#1CADE4" label="" margin=10
+ }
+ 0 [label=<
+
+ > shape=plain]
+ color="#1CADE4" label="" margin=10
+ }
+ 1:"out.0" -> 4:"in.1" [label="" arrowhead=none arrowsize=1.0 color="#77CEEF" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5]
+ 2:"out.0" -> 4:"in.0" [label=Qubit arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5]
+ 4:"out.0" -> 3:"in.0" [label=Qubit arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5]
+ }
+
+ '''
+# ---
+# name: test_simple_id
+ '''
+ digraph {
+ bgcolor=white margin=0 nodesep=0.15 rankdir="" ranksep=0.1
+ subgraph cluster0 {
+ 1 [label=<
+
+
+
+
+
+ Input(types=[Qubit, Qubit]) |
+
+ |
+
+
+
+
+
+ |
+
+
+
+ > shape=plain]
+ 2 [label=<
+
+
+
+
+
+ |
+
+
+
+
+
+ |
+
+
+
+ > shape=plain]
+ 0 [label=<
+
+
+
+
+
+ DFG(inputs=[Qubit, Qubit]) |
+
+ |
+
+
+
+
+
+ |
+
+
+
+ > shape=plain]
+ color="#1CADE4" label="" margin=10
+ }
+ 1:"out.0" -> 2:"in.0" [label=Qubit arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5]
+ 1:"out.1" -> 2:"in.1" [label=Qubit arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5]
+ }
+
+ '''
+# ---
+# name: test_tuple
+ '''
+ digraph {
+ bgcolor=white margin=0 nodesep=0.15 rankdir="" ranksep=0.1
+ subgraph cluster0 {
+ 1 [label=<
+
+
+
+
+
+ Input(types=[Bool, Qubit]) |
+
+ |
+
+
+
+
+
+ |
+
+
+
+ > shape=plain]
+ 2 [label=<
+
+
+
+
+
+ |
+
+
+
+
+
+ |
+
+
+
+ > shape=plain]
+ 3 [label=<
+
+
+
+
+
+ |
+
+
+
+
+
+ MakeTuple([Bool, Qubit]) |
+
+ |
+
+
+
+
+
+ |
+
+
+
+ > shape=plain]
+ 4 [label=<
+
+
+
+
+
+ |
+
+
+
+
+
+ |
+
+
+
+
+
+ |
+
+
+
+ > shape=plain]
+ 0 [label=<
+
+
+
+
+
+ DFG(inputs=[Bool, Qubit]) |
+
+ |
+
+
+
+
+
+ |
+
+
+
+ > shape=plain]
+ color="#1CADE4" label="" margin=10
+ }
+ 1:"out.0" -> 3:"in.0" [label=Bool arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5]
+ 1:"out.1" -> 3:"in.1" [label=Qubit arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5]
+ 3:"out.0" -> 4:"in.0" [label="Tuple(Bool, Qubit)" arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5]
+ 4:"out.0" -> 2:"in.0" [label=Bool arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5]
+ 4:"out.1" -> 2:"in.1" [label=Qubit arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5]
+ }
+
+ '''
+# ---
diff --git a/hugr-py/tests/conftest.py b/hugr-py/tests/conftest.py
index af74710d3..daf1f72a2 100644
--- a/hugr-py/tests/conftest.py
+++ b/hugr-py/tests/conftest.py
@@ -17,6 +17,8 @@
from hugr.std.float import FLOAT_T
if TYPE_CHECKING:
+ from syrupy.assertion import SnapshotAssertion
+
from hugr.ops import ComWire
QUANTUM_EXT = ext.Extension("pytest.quantum,", ext.Version(0, 1, 0))
@@ -130,7 +132,20 @@ def mermaid(h: Hugr):
_run_hugr_cmd(h.to_serial().to_json(), cmd)
-def validate(h: Hugr | ext.Package, roundtrip: bool = True):
+def validate(
+ h: Hugr | ext.Package,
+ *,
+ roundtrip: bool = True,
+ snap: SnapshotAssertion | None = None,
+):
+ """Validate a HUGR or package.
+
+ args:
+ h: The HUGR or package to validate.
+ roundtrip: Whether to roundtrip the HUGR through the CLI.
+ snapshot: A hugr render snapshot. If not None, it will be compared against the
+ rendered HUGR. Pass `--snapshot-update` to pytest to update the snapshot file.
+ """
cmd = [*_base_command(), "validate", "-"]
serial = h.to_json()
_run_hugr_cmd(serial, cmd)
@@ -146,6 +161,10 @@ def validate(h: Hugr | ext.Package, roundtrip: bool = True):
roundtrip_json = json.loads(h2.to_serial().to_json())
assert roundtrip_json == starting_json
+ if snap is not None:
+ dot = h.render_dot()
+ assert snap == dot.source
+
def _run_hugr_cmd(serial: str, cmd: list[str]):
try:
diff --git a/hugr-py/tests/test_hugr_build.py b/hugr-py/tests/test_hugr_build.py
index cac950127..b9f2d997e 100644
--- a/hugr-py/tests/test_hugr_build.py
+++ b/hugr-py/tests/test_hugr_build.py
@@ -65,11 +65,12 @@ def simple_id() -> Dfg:
return h
-def test_simple_id():
- validate(simple_id().hugr)
+def test_simple_id(snapshot):
+ hugr = simple_id().hugr
+ validate(hugr, snap=snapshot)
-def test_metadata():
+def test_metadata(snapshot):
h = Dfg(tys.Bool)
h.metadata["name"] = "simple_id"
@@ -77,10 +78,10 @@ def test_metadata():
b = h.add_op(Not, b, metadata={"name": "not"})
h.set_outputs(b)
- validate(h.hugr)
+ validate(h.hugr, snap=snapshot)
-def test_multiport():
+def test_multiport(snapshot):
h = Dfg(tys.Bool)
(a,) = h.inputs()
h.set_outputs(a, a)
@@ -100,19 +101,19 @@ def test_multiport():
]
assert list(h.hugr.linked_ports(ou_n.inp(0))) == [in_n.out(0)]
- validate(h.hugr)
+ validate(h.hugr, snap=snapshot)
-def test_add_op():
+def test_add_op(snapshot):
h = Dfg(tys.Bool)
(a,) = h.inputs()
nt = h.add_op(Not, a)
h.set_outputs(nt)
- validate(h.hugr)
+ validate(h.hugr, snap=snapshot)
-def test_tuple():
+def test_tuple(snapshot):
row = [tys.Bool, tys.Qubit]
h = Dfg(*row)
a, b = h.inputs()
@@ -120,7 +121,7 @@ def test_tuple():
a, b = h.add(ops.UnpackTuple()(t))
h.set_outputs(a, b)
- validate(h.hugr)
+ validate(h.hugr, snap=snapshot)
h1 = Dfg(*row)
a, b = h1.inputs()
@@ -131,12 +132,12 @@ def test_tuple():
assert h.hugr.to_serial() == h1.hugr.to_serial()
-def test_multi_out():
+def test_multi_out(snapshot):
h = Dfg(INT_T, INT_T)
a, b = h.inputs()
a, b = h.add(DivMod(a, b))
h.set_outputs(a, b)
- validate(h.hugr)
+ validate(h.hugr, snap=snapshot)
def test_insert():
@@ -152,7 +153,7 @@ def test_insert():
assert mapping == {new_h.root: Node(4)}
-def test_insert_nested():
+def test_insert_nested(snapshot):
h1 = Dfg(tys.Bool)
(a1,) = h1.inputs()
nt = h1.add(Not(a1))
@@ -163,10 +164,10 @@ def test_insert_nested():
nested = h.insert_nested(h1, a)
h.set_outputs(nested)
assert len(h.hugr.children(nested)) == 3
- validate(h.hugr)
+ validate(h.hugr, snap=snapshot)
-def test_build_nested():
+def test_build_nested(snapshot):
h = Dfg(tys.Bool)
(a,) = h.inputs()
@@ -178,10 +179,10 @@ def test_build_nested():
assert len(h.hugr.children(nested)) == 3
h.set_outputs(nested)
- validate(h.hugr)
+ validate(h.hugr, snap=snapshot)
-def test_build_inter_graph():
+def test_build_inter_graph(snapshot):
h = Dfg(tys.Bool, tys.Bool)
(a, b) = h.inputs()
with h.add_nested() as nested:
@@ -190,7 +191,7 @@ def test_build_inter_graph():
h.set_outputs(nested, b)
- validate(h.hugr)
+ validate(h.hugr, snap=snapshot)
assert _SubPort(h.input_node.out(-1)) in h.hugr._links
assert h.hugr.num_outgoing(h.input_node) == 2 # doesn't count state order
@@ -275,7 +276,7 @@ def test_mono_function(direct_call: bool) -> None:
validate(mod.hugr)
-def test_recursive_function() -> None:
+def test_recursive_function(snapshot) -> None:
mod = Module()
f_recursive = mod.define_function("recurse", [tys.Qubit])
@@ -283,7 +284,7 @@ def test_recursive_function() -> None:
call = f_recursive.call(f_recursive, f_recursive.input_node[0])
f_recursive.set_outputs(call)
- validate(mod.hugr)
+ validate(mod.hugr, snap=snapshot)
def test_invalid_recursive_function() -> None:
diff --git a/justfile b/justfile
index 2f8beaec6..91564d2b4 100644
--- a/justfile
+++ b/justfile
@@ -55,6 +55,10 @@ update-schema:
poetry update
poetry run python scripts/generate_schema.py specification/schema/
+# Update snapshots used in the pytest tests.
+update-pytest-snapshots:
+ poetry run pytest --snapshot-update
+
# Generate serialized declarations for the standard extensions and prelude.
gen-extensions:
cargo run -p hugr-cli gen-extensions -o specification/std_extensions
diff --git a/poetry.lock b/poetry.lock
index 01c758543..ae9b5fa16 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -161,6 +161,22 @@ docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1
testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-asyncio (>=0.21)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)", "virtualenv (>=20.26.2)"]
typing = ["typing-extensions (>=4.8)"]
+[[package]]
+name = "graphviz"
+version = "0.20.3"
+description = "Simple Python interface for Graphviz"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "graphviz-0.20.3-py3-none-any.whl", hash = "sha256:81f848f2904515d8cd359cc611faba817598d2feaac4027b266aa3eda7b3dde5"},
+ {file = "graphviz-0.20.3.zip", hash = "sha256:09d6bc81e6a9fa392e7ba52135a9d49f1ed62526f96499325930e87ca1b5925d"},
+]
+
+[package.extras]
+dev = ["flake8", "pep8-naming", "tox (>=3)", "twine", "wheel"]
+docs = ["sphinx (>=5,<7)", "sphinx-autodoc-typehints", "sphinx-rtd-theme"]
+test = ["coverage", "pytest (>=7,<8.1)", "pytest-cov", "pytest-mock (>=3)"]
+
[[package]]
name = "hugr"
version = "0.7.0"
@@ -171,6 +187,7 @@ files = []
develop = true
[package.dependencies]
+graphviz = "^0.20.3"
pydantic = ">=2.7,<2.9"
pydantic-extra-types = "^2.9.0"
semver = "^3.0.2"
@@ -618,6 +635,20 @@ files = [
{file = "semver-3.0.2.tar.gz", hash = "sha256:6253adb39c70f6e51afed2fa7152bcd414c411286088fb4b9effb133885ab4cc"},
]
+[[package]]
+name = "syrupy"
+version = "4.6.4"
+description = "Pytest Snapshot Test Utility"
+optional = false
+python-versions = ">=3.8.1"
+files = [
+ {file = "syrupy-4.6.4-py3-none-any.whl", hash = "sha256:5a0e47b187d32b58555b0de6d25bc7bb875e7d60c7a41bd2721f5d44975dcf85"},
+ {file = "syrupy-4.6.4.tar.gz", hash = "sha256:a6facc6a45f1cff598adacb030d9573ed62863521755abd5c5d6d665f848d6cc"},
+]
+
+[package.dependencies]
+pytest = ">=7.0.0,<9.0.0"
+
[[package]]
name = "toml"
version = "0.10.2"
@@ -674,4 +705,4 @@ test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
-content-hash = "93e7b398cf89bd222858475275e6df051fb1ddb7b0f192bb2d8dc3c07f2c9268"
+content-hash = "be6d8e498af1ca1a8034d11ae10b8eb0905b25359cfba1ec1ea916aafef973c7"
diff --git a/pyproject.toml b/pyproject.toml
index a8d7bbd1c..585ec98bb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -17,6 +17,7 @@ pytest-cov = "^5.0.0"
mypy = "^1.9.0"
ruff = "^0.6.1"
toml = "^0.10.0"
+syrupy = "^4.6.4"
[tool.poetry.group.hugr.dependencies]
hugr = { path = "hugr-py", develop = true }