diff --git a/hugr-py/src/hugr/cfg.py b/hugr-py/src/hugr/cfg.py index 7c46c9abd..9ee06d49a 100644 --- a/hugr-py/src/hugr/cfg.py +++ b/hugr-py/src/hugr/cfg.py @@ -10,7 +10,7 @@ from hugr import ops, tys, val -from .dfg import _DfBase +from .dfg import DfBase from .exceptions import MismatchedExit, NoSiblingAncestor, NotInSameCfg from .hugr import Hugr, ParentBuilder @@ -19,7 +19,7 @@ from .tys import Type, TypeRow -class Block(_DfBase[ops.DataflowBlock]): +class Block(DfBase[ops.DataflowBlock]): """Builder class for a basic block in a HUGR control flow graph.""" def set_outputs(self, *outputs: Wire) -> None: diff --git a/hugr-py/src/hugr/cond_loop.py b/hugr-py/src/hugr/cond_loop.py index 9548f2a10..ca286fcc4 100644 --- a/hugr-py/src/hugr/cond_loop.py +++ b/hugr-py/src/hugr/cond_loop.py @@ -13,7 +13,7 @@ from hugr import ops from hugr.tys import Sum -from .dfg import _DfBase +from .dfg import DfBase from .hugr import Hugr, ParentBuilder if TYPE_CHECKING: @@ -21,7 +21,7 @@ from .tys import TypeRow -class Case(_DfBase[ops.Case]): +class Case(DfBase[ops.Case]): """Dataflow graph builder for a case in a conditional.""" _parent_cond: Conditional | None = None @@ -202,7 +202,7 @@ def add_case(self, case_id: int) -> Case: @dataclass -class TailLoop(_DfBase[ops.TailLoop]): +class TailLoop(DfBase[ops.TailLoop]): """Builder for a tail-controlled loop. Args: diff --git a/hugr-py/src/hugr/dfg.py b/hugr-py/src/hugr/dfg.py index 76ca42b2a..a2b4dbb0e 100644 --- a/hugr-py/src/hugr/dfg.py +++ b/hugr-py/src/hugr/dfg.py @@ -30,7 +30,7 @@ @dataclass() -class _DefinitionBuilder(Generic[OpVar]): +class DefinitionBuilder(Generic[OpVar]): """Base class for builders that can define functions, constants, and aliases. As this class may be a root node, it does not extend `ParentBuilder`. @@ -95,7 +95,7 @@ def add_alias_defn(self, name: str, ty: Type, parent: ToNode | None = None) -> N @dataclass() -class _DfBase(ParentBuilder[DP], _DefinitionBuilder, AbstractContextManager): +class DfBase(ParentBuilder[DP], DefinitionBuilder, AbstractContextManager): """Base class for dataflow graph builders. Args: @@ -636,7 +636,7 @@ def _wire_up_port(self, node: Node, offset: PortOffset, p: Wire) -> tys.Type: return self._get_dataflow_type(src) -class Dfg(_DfBase[ops.DFG]): +class Dfg(DfBase[ops.DFG]): """Builder for a simple nested Dataflow graph, with root node of type :class:`DFG `. @@ -672,7 +672,7 @@ def _ancestral_sibling(h: Hugr, src: Node, tgt: Node) -> Node | None: @dataclass -class Function(_DfBase[ops.FuncDefn]): +class Function(DfBase[ops.FuncDefn]): """Build a function definition as a HUGR dataflow graph. Args: diff --git a/hugr-py/src/hugr/function.py b/hugr-py/src/hugr/function.py index e95042655..8e31b50b8 100644 --- a/hugr-py/src/hugr/function.py +++ b/hugr-py/src/hugr/function.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING from . import ops -from .dfg import Function, _DefinitionBuilder +from .dfg import DefinitionBuilder, Function from .hugr import Hugr if TYPE_CHECKING: @@ -17,7 +17,7 @@ @dataclass -class Module(_DefinitionBuilder[ops.Module]): +class Module(DefinitionBuilder[ops.Module]): """Build a top-level HUGR module. Examples: diff --git a/hugr-py/src/hugr/hugr.py b/hugr-py/src/hugr/hugr.py index 2873cb258..7ebf9a6a8 100644 --- a/hugr-py/src/hugr/hugr.py +++ b/hugr-py/src/hugr/hugr.py @@ -64,7 +64,7 @@ def to_serial(self, node: Node) -> SerialOp: P = TypeVar("P", InPort, OutPort) K = TypeVar("K", InPort, OutPort) OpVar = TypeVar("OpVar", bound=Op) -OpVar2 = TypeVar("OpVar2", bound=Op) +OpVarCov = TypeVar("OpVarCov", bound=Op, covariant=True) class ParentBuilder(ToNode, Protocol[OpVar]): @@ -85,7 +85,7 @@ def parent_op(self) -> OpVar: @dataclass() -class Hugr(Mapping[Node, NodeData], Generic[OpVar]): +class Hugr(Mapping[Node, NodeData], Generic[OpVarCov]): """The core HUGR datastructure. Args: @@ -108,7 +108,7 @@ class Hugr(Mapping[Node, NodeData], Generic[OpVar]): # List of free node indices, populated when nodes are deleted. _free_nodes: list[Node] - def __init__(self, root_op: OpVar | None = None) -> None: + def __init__(self, root_op: OpVarCov | None = None) -> None: self._free_nodes = [] self._links = BiMap() self._nodes = [] @@ -134,7 +134,7 @@ def __iter__(self) -> Iterator[Node]: def __len__(self) -> int: return self.num_nodes() - def _get_typed_op(self, node: ToNode, cl: type[OpVar2]) -> OpVar2: + def _get_typed_op(self, node: ToNode, cl: type[OpVar]) -> OpVar: op = self[node].op assert isinstance(op, cl) return op @@ -329,7 +329,7 @@ def delete_link(self, src: OutPort, dst: InPort) -> None: return # TODO make sure sub-offset is handled correctly - def root_op(self) -> OpVar: + def root_op(self) -> OpVarCov: """The operation of the root node. Examples: @@ -337,7 +337,7 @@ def root_op(self) -> OpVar: >>> h.root_op() Module() """ - return cast(OpVar, self[self.root].op) + return cast(OpVarCov, self[self.root].op) def num_nodes(self) -> int: """The number of nodes in the HUGR.