Skip to content

Commit

Permalink
feat: Bring in the pure-python renderer from guppy (#1462)
Browse files Browse the repository at this point in the history
Closes #1407.

Ports @mark-koch's rendering code from guppy, since it got deleted from
there in the last hugr builder update.
Adds a `Hugr.render_dot` and a `Hugr.store_dot` method.

I'm using `syrupy`—a snapshot pytest extension—for testing it. I'm not
sure if we want to be this strict with the generated output though.
  • Loading branch information
aborgna-q committed Aug 22, 2024
1 parent ea9cca0 commit 001e66a
Show file tree
Hide file tree
Showing 10 changed files with 1,875 additions and 26 deletions.
10 changes: 6 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,18 @@ repos:
exclude: |
(?x)^(
specification/schema/.*|
.*.snap|
.*.snap.new|
.*\.snap|
.*\.snap\.new|
.*\.ambr|
.release-please-manifest.json
)$
- id: trailing-whitespace
exclude: |
(?x)^(
specification/schema/.*|
.*.snap|
.*.snap.new
.*\.snap|
.*\.ambr|
.*\.snap\.new
)$
- id: fix-byte-order-marker
- id: mixed-line-ending
Expand Down
1 change: 1 addition & 0 deletions hugr-py/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ python = ">=3.10"
pydantic = ">=2.7,<2.9"
pydantic-extra-types = "^2.9.0"
semver = "^3.0.2"
graphviz = { version = "^0.20.3" }

[tool.poetry.group.docs]
optional = true
Expand Down
40 changes: 40 additions & 0 deletions hugr-py/src/hugr/hugr.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
from .exceptions import ParentBeforeChild

if TYPE_CHECKING:
import graphviz as gv # type: ignore[import-untyped]

from hugr import ext
from hugr.val import Value

Expand Down Expand Up @@ -143,6 +145,14 @@ def nodes(self) -> Iterable[tuple[Node, NodeData]]:
"""Iterator over nodes of the hugr and their data."""
return self.items()

def links(self) -> Iterator[tuple[OutPort, InPort]]:
"""Iterator over all the links in the HUGR.
Returns:
Iterator of pairs of outgoing port and the incoming ports.
"""
return ((src.port, tgt.port) for src, tgt in self._links.items())

def children(self, node: ToNode | None = None) -> list[Node]:
"""The child nodes of a given `node`.
Expand Down Expand Up @@ -683,3 +693,33 @@ def load_json(cls, json_str: str) -> Hugr:
json_dict = json.loads(json_str)
serial = SerialHugr.load_json(json_dict)
return cls.from_serial(serial)

def render_dot(self, palette: str | None = None) -> gv.Digraph:
"""Render the HUGR to a graphviz Digraph.
Args:
palette: The palette to use for rendering. See :obj:`PALETTE` for the
included options.
Returns:
The graphviz Digraph.
"""
from .render import DotRenderer

return DotRenderer(palette).render(self)

def store_dot(
self, filename: str, format: str = "svg", palette: str | None = None
) -> None:
"""Render the HUGR to a graphviz dot file.
Args:
filename: The file to render to.
format: The format used for rendering ('pdf', 'png', etc.).
Defaults to SVG.
palette: The palette to use for rendering. See :obj:`PALETTE` for the
included options.
"""
from .render import DotRenderer

DotRenderer(palette).store(self, filename=filename, format=format)
279 changes: 279 additions & 0 deletions hugr-py/src/hugr/render.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,279 @@
"""Visualise HUGR using graphviz."""

from collections.abc import Iterable
from dataclasses import dataclass

import graphviz as gv # type: ignore[import-untyped]
from graphviz import Digraph
from typing_extensions import assert_never

from hugr.hugr import Hugr
from hugr.tys import CFKind, ConstKind, FunctionKind, Kind, OrderKind, ValueKind

from .node_port import InPort, Node, OutPort


@dataclass(frozen=True)
class Palette:
"""A set of colours used for rendering."""

background: str
node: str
edge: str
dark: str
const: str
discard: str
node_border: str
port_border: str


PALETTE: dict[str, Palette] = {
"default": Palette(
background="white",
node="#ACCBF9",
edge="#1CADE4",
dark="black",
const="#77CEEF",
discard="#ff8888",
node_border="white",
port_border="#1CADE4",
),
"nb": Palette(
background="white",
node="#7952B3",
edge="#FFC107",
dark="#343A40",
const="#7c55b4",
discard="#ff8888",
node_border="#9d80c7",
port_border="#ffd966",
),
"zx": Palette(
background="white",
node="#629DD1",
edge="#297FD5",
dark="#112D4E",
const="#a1eea1",
discard="#ff8888",
node_border="#D8F8D8",
port_border="#E8A5A5",
),
}


class DotRenderer:
"""Render a HUGR to a graphviz dot file.
Args:
palette: The palette to use for rendering. See :obj:`PALETTE` for the
included options.
"""

palette: Palette

def __init__(self, palette: Palette | str | None = None) -> None:
if palette is None:
palette = "default"
if isinstance(palette, str):
palette = PALETTE[palette]
self.palette = palette

def render(self, hugr: Hugr) -> Digraph:
"""Render a HUGR to a graphviz dot object."""
graph_attr = {
"rankdir": "",
"ranksep": "0.1",
"nodesep": "0.15",
"margin": "0",
"bgcolor": self.palette.background,
}
if not (name := hugr[hugr.root].metadata.get("name", None)):
name = ""

graph = gv.Digraph(name, strict=False)
graph.attr(**graph_attr)

self._viz_node(hugr.root, hugr, graph)

for src_port, tgt_port in hugr.links():
kind = hugr.port_kind(src_port)
self._viz_link(src_port, tgt_port, kind, graph)

return graph

def store(self, hugr: Hugr, filename: str, format: str = "svg") -> None:
"""Render a HUGR and save it to a file.
Args:
hugr: The HUGR to render.
filename: Filename for saving the rendered graph.
format: The format used for rendering ('pdf', 'png', etc.).
Defaults to SVG.
"""
gv_graph = self.render(hugr)
gv_graph.render(filename, format=format)

_FONTFACE = "monospace"

_HTML_LABEL_TEMPLATE = """
<TABLE BORDER="{border_width}" CELLBORDER="0" CELLSPACING="1" CELLPADDING="1"
BGCOLOR="{node_back_color}" COLOR="{border_colour}">
{inputs_row}
<TR>
<TD>
<TABLE BORDER="0" CELLBORDER="0">
<TR><TD><FONT POINT-SIZE="{fontsize}" FACE="{fontface}"
COLOR="{label_color}"><B>{node_label}</B>{node_data}</FONT></TD></TR>
</TABLE>
</TD>
</TR>
{outputs_row}
</TABLE>
"""

_HTML_PORTS_ROW_TEMPLATE = """
<TR>
<TD>
<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="3" CELLPADDING="2">
<TR>
{port_cells}
</TR>
</TABLE>
</TD>
</TR>
"""

_HTML_PORT_TEMPLATE = (
'<TD BGCOLOR="{back_colour}" COLOR="{border_colour}"'
' PORT="{port_id}" BORDER="{border_width}">'
'<FONT POINT-SIZE="10.0" FACE="{fontface}" COLOR="{font_colour}">{port}</FONT>'
"</TD>"
)

_INPUT_PREFIX = "in."
_OUTPUT_PREFIX = "out."

def _format_html_label(self, **kwargs: str) -> str:
_HTML_LABEL_DEFAULTS = {
"label_color": self.palette.dark,
"node_back_color": self.palette.node,
"inputs_row": "",
"outputs_row": "",
"border_colour": self.palette.port_border,
"border_width": "1",
"fontface": self._FONTFACE,
"fontsize": 11.0,
}
return self._HTML_LABEL_TEMPLATE.format(**{**_HTML_LABEL_DEFAULTS, **kwargs})

def _html_ports(self, ports: Iterable[str], id_prefix: str) -> str:
return self._HTML_PORTS_ROW_TEMPLATE.format(
port_cells="".join(
self._HTML_PORT_TEMPLATE.format(
port=port,
# differentiate input and output node identifiers
# with a prefix
port_id=id_prefix + port,
back_colour=self.palette.background,
font_colour=self.palette.dark,
border_width="1",
border_colour=self.palette.port_border,
fontface=self._FONTFACE,
)
for port in ports
)
)

def _in_port_name(self, p: InPort) -> str:
return f"{p.node.idx}:{self._INPUT_PREFIX}{p.offset}"

def _out_port_name(self, p: OutPort) -> str:
return f"{p.node.idx}:{self._OUTPUT_PREFIX}{p.offset}"

def _in_order_name(self, n: Node) -> str:
return f"{n.idx}:{self._INPUT_PREFIX}None"

def _out_order_name(self, n: Node) -> str:
return f"{n.idx}:{self._OUTPUT_PREFIX}None"

def _viz_node(self, node: Node, hugr: Hugr, graph: Digraph) -> None:
"""Render a (possibly nested) node to a graphviz graph."""
meta = hugr[node].metadata
if len(meta) > 0:
data = "<BR/><BR/>" + "<BR/>".join(
f"{key}: {value}" for key, value in meta.items()
)
else:
data = ""

in_ports = [str(i) for i in range(hugr.num_in_ports(node))]
out_ports = [str(i) for i in range(hugr.num_out_ports(node))]
inputs_row = (
self._html_ports(in_ports, self._INPUT_PREFIX) if len(in_ports) > 0 else ""
)
outputs_row = (
self._html_ports(out_ports, self._OUTPUT_PREFIX)
if len(out_ports) > 0
else ""
)

op = hugr[node].op

if hugr.children(node):
with graph.subgraph(name=f"cluster{node.idx}") as sub:
for child in hugr.children(node):
self._viz_node(child, hugr, sub)
html_label = self._format_html_label(
node_back_color=self.palette.edge,
node_label=str(op),
node_data=data,
border_colour=self.palette.port_border,
inputs_row=inputs_row,
outputs_row=outputs_row,
)
sub.node(f"{node.idx}", shape="plain", label=f"<{html_label}>")
sub.attr(label="", margin="10", color=self.palette.edge)
else:
html_label = self._format_html_label(
node_back_color=self.palette.node,
node_label=str(op),
node_data=data,
inputs_row=inputs_row,
outputs_row=outputs_row,
border_colour=self.palette.background,
)
graph.node(f"{node.idx}", label=f"<{html_label}>", shape="plain")

def _viz_link(
self, src_port: OutPort, tgt_port: InPort, kind: Kind, graph: Digraph
) -> None:
edge_attr = {
"penwidth": "1.5",
"arrowhead": "none",
"arrowsize": "1.0",
"fontname": self._FONTFACE,
"fontsize": "9",
"fontcolor": "black",
}

label = ""
match kind:
case ValueKind(ty):
label = str(ty)
color = self.palette.edge
case OrderKind():
color = self.palette.dark
case ConstKind() | FunctionKind():
color = self.palette.const
case CFKind():
color = self.palette.dark
case _:
assert_never(kind)

graph.edge(
self._out_port_name(src_port),
self._in_port_name(tgt_port),
label=label,
color=color,
**edge_attr,
)
Loading

0 comments on commit 001e66a

Please sign in to comment.