Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor!: Make _DfBase and _DefinitionBuilder public #1461

Merged
merged 2 commits into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions hugr-py/src/hugr/cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from hugr import ops, tys, val

from .dfg import _DfBase
from .dfg import DfBase
from .exceptions import MismatchedExit, NoSiblingAncestor, NotInSameCfg
from .hugr import Hugr, ParentBuilder

Expand All @@ -19,7 +19,7 @@
from .tys import Type, TypeRow


class Block(_DfBase[ops.DataflowBlock]):
class Block(DfBase[ops.DataflowBlock]):
"""Builder class for a basic block in a HUGR control flow graph."""

def set_outputs(self, *outputs: Wire) -> None:
Expand Down
6 changes: 3 additions & 3 deletions hugr-py/src/hugr/cond_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@
from hugr import ops
from hugr.tys import Sum

from .dfg import _DfBase
from .dfg import DfBase
from .hugr import Hugr, ParentBuilder

if TYPE_CHECKING:
from .node_port import Node, ToNode, Wire
from .tys import TypeRow


class Case(_DfBase[ops.Case]):
class Case(DfBase[ops.Case]):
"""Dataflow graph builder for a case in a conditional."""

_parent_cond: Conditional | None = None
Expand Down Expand Up @@ -202,7 +202,7 @@ def add_case(self, case_id: int) -> Case:


@dataclass
class TailLoop(_DfBase[ops.TailLoop]):
class TailLoop(DfBase[ops.TailLoop]):
"""Builder for a tail-controlled loop.
Args:
Expand Down
8 changes: 4 additions & 4 deletions hugr-py/src/hugr/dfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@


@dataclass()
class _DefinitionBuilder(Generic[OpVar]):
class DefinitionBuilder(Generic[OpVar]):
"""Base class for builders that can define functions, constants, and aliases.
As this class may be a root node, it does not extend `ParentBuilder`.
Expand Down Expand Up @@ -95,7 +95,7 @@ def add_alias_defn(self, name: str, ty: Type, parent: ToNode | None = None) -> N


@dataclass()
class _DfBase(ParentBuilder[DP], _DefinitionBuilder, AbstractContextManager):
class DfBase(ParentBuilder[DP], DefinitionBuilder, AbstractContextManager):
"""Base class for dataflow graph builders.
Args:
Expand Down Expand Up @@ -636,7 +636,7 @@ def _wire_up_port(self, node: Node, offset: PortOffset, p: Wire) -> tys.Type:
return self._get_dataflow_type(src)


class Dfg(_DfBase[ops.DFG]):
class Dfg(DfBase[ops.DFG]):
"""Builder for a simple nested Dataflow graph, with root node of type
:class:`DFG <hugr.ops.DFG>`.
Expand Down Expand Up @@ -672,7 +672,7 @@ def _ancestral_sibling(h: Hugr, src: Node, tgt: Node) -> Node | None:


@dataclass
class Function(_DfBase[ops.FuncDefn]):
class Function(DfBase[ops.FuncDefn]):
"""Build a function definition as a HUGR dataflow graph.
Args:
Expand Down
4 changes: 2 additions & 2 deletions hugr-py/src/hugr/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import TYPE_CHECKING

from . import ops
from .dfg import Function, _DefinitionBuilder
from .dfg import DefinitionBuilder, Function
from .hugr import Hugr

if TYPE_CHECKING:
Expand All @@ -17,7 +17,7 @@


@dataclass
class Module(_DefinitionBuilder[ops.Module]):
class Module(DefinitionBuilder[ops.Module]):
"""Build a top-level HUGR module.
Examples:
Expand Down
12 changes: 6 additions & 6 deletions hugr-py/src/hugr/hugr.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def to_serial(self, node: Node) -> SerialOp:
P = TypeVar("P", InPort, OutPort)
K = TypeVar("K", InPort, OutPort)
OpVar = TypeVar("OpVar", bound=Op)
OpVar2 = TypeVar("OpVar2", bound=Op)
OpVarCov = TypeVar("OpVarCov", bound=Op, covariant=True)


class ParentBuilder(ToNode, Protocol[OpVar]):
Expand All @@ -85,7 +85,7 @@ def parent_op(self) -> OpVar:


@dataclass()
class Hugr(Mapping[Node, NodeData], Generic[OpVar]):
class Hugr(Mapping[Node, NodeData], Generic[OpVarCov]):
"""The core HUGR datastructure.
Args:
Expand All @@ -108,7 +108,7 @@ class Hugr(Mapping[Node, NodeData], Generic[OpVar]):
# List of free node indices, populated when nodes are deleted.
_free_nodes: list[Node]

def __init__(self, root_op: OpVar | None = None) -> None:
def __init__(self, root_op: OpVarCov | None = None) -> None:
self._free_nodes = []
self._links = BiMap()
self._nodes = []
Expand All @@ -134,7 +134,7 @@ def __iter__(self) -> Iterator[Node]:
def __len__(self) -> int:
return self.num_nodes()

def _get_typed_op(self, node: ToNode, cl: type[OpVar2]) -> OpVar2:
def _get_typed_op(self, node: ToNode, cl: type[OpVar]) -> OpVar:
op = self[node].op
assert isinstance(op, cl)
return op
Expand Down Expand Up @@ -329,15 +329,15 @@ def delete_link(self, src: OutPort, dst: InPort) -> None:
return
# TODO make sure sub-offset is handled correctly

def root_op(self) -> OpVar:
def root_op(self) -> OpVarCov:
"""The operation of the root node.
Examples:
>>> h = Hugr()
>>> h.root_op()
Module()
"""
return cast(OpVar, self[self.root].op)
return cast(OpVarCov, self[self.root].op)

def num_nodes(self) -> int:
"""The number of nodes in the HUGR.
Expand Down
Loading