Skip to content

Commit

Permalink
refactor!: Make _DfBase and _DefinitionBuilder public (#1461)
Browse files Browse the repository at this point in the history
Closes #1441.

drive-by: Make `Hugr` covariant.

The plan was to make `ParentBuilder` / `DfBase` covariant too, but they
_must_ be invariant since they contain a mutable `Hugr[OpType]`
attribute.

BREAKING CHANGE: Renamed `_DfBase` to `DfBase` and `_DefinitionBuilder`
to `DefinitionBuilder`
  • Loading branch information
aborgna-q committed Aug 22, 2024
1 parent c7cd840 commit ea9cca0
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 17 deletions.
4 changes: 2 additions & 2 deletions hugr-py/src/hugr/cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from hugr import ops, tys, val

from .dfg import _DfBase
from .dfg import DfBase
from .exceptions import MismatchedExit, NoSiblingAncestor, NotInSameCfg
from .hugr import Hugr, ParentBuilder

Expand All @@ -19,7 +19,7 @@
from .tys import Type, TypeRow


class Block(_DfBase[ops.DataflowBlock]):
class Block(DfBase[ops.DataflowBlock]):
"""Builder class for a basic block in a HUGR control flow graph."""

def set_outputs(self, *outputs: Wire) -> None:
Expand Down
6 changes: 3 additions & 3 deletions hugr-py/src/hugr/cond_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@
from hugr import ops
from hugr.tys import Sum

from .dfg import _DfBase
from .dfg import DfBase
from .hugr import Hugr, ParentBuilder

if TYPE_CHECKING:
from .node_port import Node, ToNode, Wire
from .tys import TypeRow


class Case(_DfBase[ops.Case]):
class Case(DfBase[ops.Case]):
"""Dataflow graph builder for a case in a conditional."""

_parent_cond: Conditional | None = None
Expand Down Expand Up @@ -202,7 +202,7 @@ def add_case(self, case_id: int) -> Case:


@dataclass
class TailLoop(_DfBase[ops.TailLoop]):
class TailLoop(DfBase[ops.TailLoop]):
"""Builder for a tail-controlled loop.
Args:
Expand Down
8 changes: 4 additions & 4 deletions hugr-py/src/hugr/dfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@


@dataclass()
class _DefinitionBuilder(Generic[OpVar]):
class DefinitionBuilder(Generic[OpVar]):
"""Base class for builders that can define functions, constants, and aliases.
As this class may be a root node, it does not extend `ParentBuilder`.
Expand Down Expand Up @@ -95,7 +95,7 @@ def add_alias_defn(self, name: str, ty: Type, parent: ToNode | None = None) -> N


@dataclass()
class _DfBase(ParentBuilder[DP], _DefinitionBuilder, AbstractContextManager):
class DfBase(ParentBuilder[DP], DefinitionBuilder, AbstractContextManager):
"""Base class for dataflow graph builders.
Args:
Expand Down Expand Up @@ -636,7 +636,7 @@ def _wire_up_port(self, node: Node, offset: PortOffset, p: Wire) -> tys.Type:
return self._get_dataflow_type(src)


class Dfg(_DfBase[ops.DFG]):
class Dfg(DfBase[ops.DFG]):
"""Builder for a simple nested Dataflow graph, with root node of type
:class:`DFG <hugr.ops.DFG>`.
Expand Down Expand Up @@ -672,7 +672,7 @@ def _ancestral_sibling(h: Hugr, src: Node, tgt: Node) -> Node | None:


@dataclass
class Function(_DfBase[ops.FuncDefn]):
class Function(DfBase[ops.FuncDefn]):
"""Build a function definition as a HUGR dataflow graph.
Args:
Expand Down
4 changes: 2 additions & 2 deletions hugr-py/src/hugr/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import TYPE_CHECKING

from . import ops
from .dfg import Function, _DefinitionBuilder
from .dfg import DefinitionBuilder, Function
from .hugr import Hugr

if TYPE_CHECKING:
Expand All @@ -17,7 +17,7 @@


@dataclass
class Module(_DefinitionBuilder[ops.Module]):
class Module(DefinitionBuilder[ops.Module]):
"""Build a top-level HUGR module.
Examples:
Expand Down
12 changes: 6 additions & 6 deletions hugr-py/src/hugr/hugr.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def to_serial(self, node: Node) -> SerialOp:
P = TypeVar("P", InPort, OutPort)
K = TypeVar("K", InPort, OutPort)
OpVar = TypeVar("OpVar", bound=Op)
OpVar2 = TypeVar("OpVar2", bound=Op)
OpVarCov = TypeVar("OpVarCov", bound=Op, covariant=True)


class ParentBuilder(ToNode, Protocol[OpVar]):
Expand All @@ -85,7 +85,7 @@ def parent_op(self) -> OpVar:


@dataclass()
class Hugr(Mapping[Node, NodeData], Generic[OpVar]):
class Hugr(Mapping[Node, NodeData], Generic[OpVarCov]):
"""The core HUGR datastructure.
Args:
Expand All @@ -108,7 +108,7 @@ class Hugr(Mapping[Node, NodeData], Generic[OpVar]):
# List of free node indices, populated when nodes are deleted.
_free_nodes: list[Node]

def __init__(self, root_op: OpVar | None = None) -> None:
def __init__(self, root_op: OpVarCov | None = None) -> None:
self._free_nodes = []
self._links = BiMap()
self._nodes = []
Expand All @@ -134,7 +134,7 @@ def __iter__(self) -> Iterator[Node]:
def __len__(self) -> int:
return self.num_nodes()

def _get_typed_op(self, node: ToNode, cl: type[OpVar2]) -> OpVar2:
def _get_typed_op(self, node: ToNode, cl: type[OpVar]) -> OpVar:
op = self[node].op
assert isinstance(op, cl)
return op
Expand Down Expand Up @@ -329,15 +329,15 @@ def delete_link(self, src: OutPort, dst: InPort) -> None:
return
# TODO make sure sub-offset is handled correctly

def root_op(self) -> OpVar:
def root_op(self) -> OpVarCov:
"""The operation of the root node.
Examples:
>>> h = Hugr()
>>> h.root_op()
Module()
"""
return cast(OpVar, self[self.root].op)
return cast(OpVarCov, self[self.root].op)

def num_nodes(self) -> int:
"""The number of nodes in the HUGR.
Expand Down

0 comments on commit ea9cca0

Please sign in to comment.