From 881d9c377a83ad7be6be5845eefbe183d06c13dd Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Wed, 19 Jun 2024 15:44:30 +0100 Subject: [PATCH 01/12] add_cfg without inputs and make cfg outputs optional --- hugr-py/src/hugr/_cfg.py | 12 +++++------- hugr-py/src/hugr/_dfg.py | 8 +++++--- hugr-py/src/hugr/_ops.py | 13 +++++++++++-- hugr-py/src/hugr/serialization/ops.py | 3 ++- hugr-py/tests/test_cfg.py | 2 +- 5 files changed, 24 insertions(+), 14 deletions(-) diff --git a/hugr-py/src/hugr/_cfg.py b/hugr-py/src/hugr/_cfg.py index 4277edd10..aadef527a 100644 --- a/hugr-py/src/hugr/_cfg.py +++ b/hugr-py/src/hugr/_cfg.py @@ -1,13 +1,13 @@ from __future__ import annotations -from dataclasses import dataclass, replace +from dataclasses import dataclass import hugr._ops as ops from ._dfg import _DfBase from ._exceptions import NoSiblingAncestor, NotInSameCfg, MismatchedExit from ._hugr import Hugr, Node, ParentBuilder, ToNode, Wire -from ._tys import FunctionType, TypeRow, Type +from ._tys import TypeRow, Type import hugr._val as val @@ -47,7 +47,7 @@ class Cfg(ParentBuilder[ops.CFG]): exit: Node def __init__(self, input_types: TypeRow) -> None: - root_op = ops.CFG(FunctionType(input=input_types, output=[])) + root_op = ops.CFG(inputs=input_types) hugr = Hugr(root_op) self._init_impl(hugr, hugr.root, input_types) @@ -68,7 +68,7 @@ def new_nested( ) -> Cfg: new = cls.__new__(cls) root = hugr.add_node( - ops.CFG(FunctionType(input=input_types, output=[])), + ops.CFG(inputs=input_types), parent or hugr.root, ) new._init_impl(hugr, root, input_types) @@ -125,6 +125,4 @@ def branch_exit(self, src: Wire) -> None: raise MismatchedExit(src.node.idx) else: self._exit_op._cfg_outputs = out_types - self.parent_op.signature = replace( - self.parent_op.signature, output=out_types - ) + self.parent_op._outputs = out_types diff --git a/hugr-py/src/hugr/_dfg.py b/hugr-py/src/hugr/_dfg.py index 8d4d1447f..b3f9a8468 100644 --- a/hugr-py/src/hugr/_dfg.py +++ b/hugr-py/src/hugr/_dfg.py @@ -92,11 +92,12 @@ def add_nested( def add_cfg( self, - input_types: TypeRow, *args: Wire, ) -> Cfg: from ._cfg import Cfg + input_types = [self._get_dataflow_type(w) for w in args] + cfg = Cfg.new_nested(input_types, self.hugr, self.parent_node) self._wire_up(cfg.parent_node, args) return cfg @@ -129,10 +130,11 @@ def load_const(self, const_node: ToNode) -> Node: def add_load_const(self, val: val.Value) -> Node: return self.load_const(self.add_const(val)) - def _wire_up(self, node: Node, ports: Iterable[Wire]): + def _wire_up(self, node: Node, ports: Iterable[Wire]) -> TypeRow: tys = [self._wire_up_port(node, i, p) for i, p in enumerate(ports)] if isinstance(op := self.hugr[node].op, ops.PartialOp): op.set_in_types(tys) + return tys def _get_dataflow_type(self, wire: Wire) -> Type: port = wire.out_port() @@ -141,7 +143,7 @@ def _get_dataflow_type(self, wire: Wire) -> Type: raise ValueError(f"Port {port} is not a dataflow port.") return ty - def _wire_up_port(self, node: Node, offset: int, p: Wire) -> Type: + def _wire_up_port(self, node: Node, offset: int, p: Wire): src = p.out_port() node_ancestor = _ancestral_sibling(self.hugr, src.node, node) if node_ancestor is None: diff --git a/hugr-py/src/hugr/_ops.py b/hugr-py/src/hugr/_ops.py index 89c6d285c..d38d3fe30 100644 --- a/hugr-py/src/hugr/_ops.py +++ b/hugr-py/src/hugr/_ops.py @@ -259,11 +259,20 @@ def _inputs(self) -> tys.TypeRow: @dataclass() class CFG(DataflowOp): - signature: tys.FunctionType = field(default_factory=tys.FunctionType.empty) + inputs: tys.TypeRow + _outputs: tys.TypeRow | None = None + + @property + def outputs(self) -> tys.TypeRow: + return _check_complete(self._outputs) + + @property + def signature(self) -> tys.FunctionType: + return tys.FunctionType(self.inputs, self.outputs) @property def num_out(self) -> int | None: - return len(self.signature.output) + return len(self.outputs) def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.CFG: return sops.CFG( diff --git a/hugr-py/src/hugr/serialization/ops.py b/hugr-py/src/hugr/serialization/ops.py index bb4ee2882..24af6f9e2 100644 --- a/hugr-py/src/hugr/serialization/ops.py +++ b/hugr-py/src/hugr/serialization/ops.py @@ -426,7 +426,8 @@ def insert_port_types(self, inputs: TypeRow, outputs: TypeRow) -> None: ) def deserialize(self) -> _ops.CFG: - return _ops.CFG(self.signature.deserialize()) + sig = self.signature.deserialize() + return _ops.CFG(inputs=sig.input, _outputs=sig.output) ControlFlowOp = Conditional | TailLoop | CFG diff --git a/hugr-py/tests/test_cfg.py b/hugr-py/tests/test_cfg.py index a5beed4b9..7a40edeee 100644 --- a/hugr-py/tests/test_cfg.py +++ b/hugr-py/tests/test_cfg.py @@ -40,7 +40,7 @@ def test_branch() -> None: def test_nested_cfg() -> None: dfg = Dfg(tys.Bool) - cfg = dfg.add_cfg([tys.Bool], *dfg.inputs()) + cfg = dfg.add_cfg(*dfg.inputs()) build_basic_cfg(cfg) dfg.set_outputs(cfg) From d0bc8665add854fc225da37ca21a99153692c02d Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Wed, 19 Jun 2024 15:48:00 +0100 Subject: [PATCH 02/12] make dfg consistent with other ops --- hugr-py/src/hugr/_dfg.py | 11 +++++------ hugr-py/src/hugr/_ops.py | 22 ++++++++++------------ hugr-py/src/hugr/serialization/ops.py | 3 ++- 3 files changed, 17 insertions(+), 19 deletions(-) diff --git a/hugr-py/src/hugr/_dfg.py b/hugr-py/src/hugr/_dfg.py index b3f9a8468..ea1d027af 100644 --- a/hugr-py/src/hugr/_dfg.py +++ b/hugr-py/src/hugr/_dfg.py @@ -83,22 +83,21 @@ def add_nested( ) -> Dfg: from ._dfg import Dfg - input_types = [self._get_dataflow_type(w) for w in args] - - parent_op = ops.DFG(list(input_types)) + parent_op = ops.DFG(self._wire_types(args)) dfg = Dfg.new_nested(parent_op, self.hugr, self.parent_node) self._wire_up(dfg.parent_node, args) return dfg + def _wire_types(self, args: Iterable[Wire]) -> TypeRow: + return [self._get_dataflow_type(w) for w in args] + def add_cfg( self, *args: Wire, ) -> Cfg: from ._cfg import Cfg - input_types = [self._get_dataflow_type(w) for w in args] - - cfg = Cfg.new_nested(input_types, self.hugr, self.parent_node) + cfg = Cfg.new_nested(self._wire_types(args), self.hugr, self.parent_node) self._wire_up(cfg.parent_node, args) return cfg diff --git a/hugr-py/src/hugr/_ops.py b/hugr-py/src/hugr/_ops.py index d38d3fe30..97c3b2134 100644 --- a/hugr-py/src/hugr/_ops.py +++ b/hugr-py/src/hugr/_ops.py @@ -219,15 +219,18 @@ def _set_out_types(self, types: tys.TypeRow) -> None: ... def _inputs(self) -> tys.TypeRow: ... -@dataclass() +@dataclass class DFG(DfParentOp, DataflowOp): - _signature: tys.TypeRow | tys.FunctionType + inputs: tys.TypeRow + _outputs: tys.TypeRow | None = None + + @property + def outputs(self) -> tys.TypeRow: + return _check_complete(self._outputs) @property def signature(self) -> tys.FunctionType: - if isinstance(self._signature, tys.FunctionType): - return self._signature - raise IncompleteOp() + return tys.FunctionType(self.inputs, self.outputs) @property def num_out(self) -> int | None: @@ -246,15 +249,10 @@ def outer_signature(self) -> tys.FunctionType: return self.signature def _set_out_types(self, types: tys.TypeRow) -> None: - assert isinstance(self._signature, list), "Signature has already been set." - self._signature = tys.FunctionType(self._signature, types) + self._outputs = types def _inputs(self) -> tys.TypeRow: - match self._signature: - case tys.FunctionType(input, _): - return input - case list(_): - return self._signature + return self.inputs @dataclass() diff --git a/hugr-py/src/hugr/serialization/ops.py b/hugr-py/src/hugr/serialization/ops.py index 24af6f9e2..05f13959b 100644 --- a/hugr-py/src/hugr/serialization/ops.py +++ b/hugr-py/src/hugr/serialization/ops.py @@ -347,7 +347,8 @@ def insert_child_dfg_signature(self, inputs: TypeRow, outputs: TypeRow) -> None: ) def deserialize(self) -> _ops.DFG: - return _ops.DFG(self.signature.deserialize()) + sig = self.signature.deserialize() + return _ops.DFG(sig.input, sig.output) # ------------------------------------------------ From b1665d945f079a9b397f831c8182062750f9c00c Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Wed, 19 Jun 2024 17:07:40 +0100 Subject: [PATCH 03/12] feat(hugr-py): add conditional building also common up some functions and use tys.Sum where applicable --- hugr-py/src/hugr/_cfg.py | 2 + hugr-py/src/hugr/_cond_loop.py | 71 +++++++++++++++++++++ hugr-py/src/hugr/_dfg.py | 27 +++++--- hugr-py/src/hugr/_ops.py | 90 +++++++++++++++++++++++---- hugr-py/src/hugr/_tys.py | 6 ++ hugr-py/src/hugr/serialization/ops.py | 16 ++++- hugr-py/tests/test_cfg.py | 4 +- hugr-py/tests/test_cond_loop.py | 37 +++++++++++ 8 files changed, 230 insertions(+), 23 deletions(-) create mode 100644 hugr-py/src/hugr/_cond_loop.py create mode 100644 hugr-py/tests/test_cond_loop.py diff --git a/hugr-py/src/hugr/_cfg.py b/hugr-py/src/hugr/_cfg.py index aadef527a..fae40cc9d 100644 --- a/hugr-py/src/hugr/_cfg.py +++ b/hugr-py/src/hugr/_cfg.py @@ -97,6 +97,8 @@ def add_block(self, input_types: TypeRow) -> Block: ) return new_block + # TODO insert_block + def add_successor(self, pred: Wire) -> Block: b = self.add_block(self._nth_outputs(pred)) diff --git a/hugr-py/src/hugr/_cond_loop.py b/hugr-py/src/hugr/_cond_loop.py new file mode 100644 index 000000000..282e4cca8 --- /dev/null +++ b/hugr-py/src/hugr/_cond_loop.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +from dataclasses import dataclass + +import hugr._ops as ops + +from ._dfg import _DfBase +from ._hugr import Hugr, Node, ParentBuilder, ToNode, Wire +from ._tys import Sum, TypeRow + + +class Case(_DfBase[ops.Case]): + _parent: Conditional | None = None + + def set_outputs(self, *outputs: Wire) -> None: + super().set_outputs(*outputs) + if self._parent is not None: + self._parent._update_outputs(self._wire_types(outputs)) + + +@dataclass +class Conditional(ParentBuilder[ops.Conditional]): + hugr: Hugr + parent_node: Node + cases: dict[int, Node | None] + + def __init__(self, sum_ty: Sum, other_inputs: TypeRow) -> None: + root_op = ops.Conditional(sum_ty, other_inputs) + hugr = Hugr(root_op) + self._init_impl(hugr, hugr.root, len(sum_ty.variant_rows)) + + def _init_impl(self: Conditional, hugr: Hugr, root: Node, n_cases: int) -> None: + self.hugr = hugr + self.parent_node = root + self.cases = {i: None for i in range(n_cases)} + + @classmethod + def new_nested( + cls, + sum_ty: Sum, + other_inputs: TypeRow, + hugr: Hugr, + parent: ToNode | None = None, + ) -> Conditional: + new = cls.__new__(cls) + root = hugr.add_node( + ops.Conditional(sum_ty, other_inputs), + parent or hugr.root, + ) + new._init_impl(hugr, root, len(sum_ty.variant_rows)) + return new + + def _update_outputs(self, outputs: TypeRow) -> None: + if self.parent_op._outputs is None: + self.parent_op._outputs = outputs + else: + assert outputs == self.parent_op._outputs, "Mismatched case outputs." + + def add_case(self, case_id: int) -> Case: + assert case_id in self.cases, f"Case {case_id} out of possible range." + input_types = self.parent_op.nth_inputs(case_id) + new_case = Case.new_nested( + ops.Case(input_types), + self.hugr, + self.parent_node, + ) + new_case._parent = self + self.cases[case_id] = new_case.parent_node + return new_case + + # TODO insert_case diff --git a/hugr-py/src/hugr/_dfg.py b/hugr-py/src/hugr/_dfg.py index ea1d027af..bbdaef277 100644 --- a/hugr-py/src/hugr/_dfg.py +++ b/hugr-py/src/hugr/_dfg.py @@ -11,13 +11,14 @@ import hugr._ops as ops import hugr._val as val -from hugr._tys import Type, TypeRow +from hugr._tys import Type, TypeRow, get_first_sum from ._exceptions import NoSiblingAncestor from ._hugr import Hugr, Node, OutPort, ParentBuilder, ToNode, Wire if TYPE_CHECKING: from ._cfg import Cfg + from ._cond_loop import Conditional DP = TypeVar("DP", bound=ops.DfParentOp) @@ -72,10 +73,13 @@ def add_op(self, op: ops.DataflowOp, /, *args: Wire) -> Node: def add(self, com: ops.Command) -> Node: return self.add_op(com.op, *com.incoming) + def _insert_nested_impl(self, builder: ParentBuilder, *args: Wire) -> Node: + mapping = self.hugr.insert_hugr(builder.hugr, self.parent_node) + self._wire_up(mapping[builder.parent_node], args) + return mapping[builder.parent_node] + def insert_nested(self, dfg: Dfg, *args: Wire) -> Node: - mapping = self.hugr.insert_hugr(dfg.hugr, self.parent_node) - self._wire_up(mapping[dfg.parent_node], args) - return mapping[dfg.parent_node] + return self._insert_nested_impl(dfg, *args) def add_nested( self, @@ -102,9 +106,18 @@ def add_cfg( return cfg def insert_cfg(self, cfg: Cfg, *args: Wire) -> Node: - mapping = self.hugr.insert_hugr(cfg.hugr, self.parent_node) - self._wire_up(mapping[cfg.parent_node], args) - return mapping[cfg.parent_node] + return self._insert_nested_impl(cfg, *args) + + def add_conditional(self, *args: Wire) -> Conditional: + from ._cond_loop import Conditional + + (sum_, other_inputs) = get_first_sum(self._wire_types(args)) + cond = Conditional.new_nested(sum_, other_inputs, self.hugr, self.parent_node) + self._wire_up(cond.parent_node, args) + return cond + + def insert_conditional(self, cond: Conditional, *args: Wire) -> Node: + return self._insert_nested_impl(cond, *args) def set_outputs(self, *args: Wire) -> None: self._wire_up(self.output_node, args) diff --git a/hugr-py/src/hugr/_ops.py b/hugr-py/src/hugr/_ops.py index 97c3b2134..800547692 100644 --- a/hugr-py/src/hugr/_ops.py +++ b/hugr-py/src/hugr/_ops.py @@ -192,19 +192,19 @@ def set_in_types(self, types: tys.TypeRow) -> None: @dataclass() class Tag(DataflowOp): tag: int - variants: list[tys.TypeRow] + sum_ty: tys.Sum num_out: int | None = 1 def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Tag: return sops.Tag( parent=parent.idx, tag=self.tag, - variants=[ser_it(r) for r in self.variants], + variants=[ser_it(r) for r in self.sum_ty.variant_rows], ) def outer_signature(self) -> tys.FunctionType: return tys.FunctionType( - input=self.variants[self.tag], output=[tys.Sum(self.variants)] + input=self.sum_ty.variant_rows[self.tag], output=[self.sum_ty] ) def __call__(self, value: Wire) -> Command: @@ -285,13 +285,13 @@ def outer_signature(self) -> tys.FunctionType: @dataclass class DataflowBlock(DfParentOp): inputs: tys.TypeRow - _sum_rows: list[tys.TypeRow] | None = None + _sum: tys.Sum | None = None _other_outputs: tys.TypeRow | None = None extension_delta: tys.ExtensionSet = field(default_factory=list) @property - def sum_rows(self) -> list[tys.TypeRow]: - return _check_complete(self._sum_rows) + def sum_ty(self) -> tys.Sum: + return _check_complete(self._sum) @property def other_outputs(self) -> tys.TypeRow: @@ -299,36 +299,35 @@ def other_outputs(self) -> tys.TypeRow: @property def num_out(self) -> int | None: - return len(self.sum_rows) + return len(self.sum_ty.variant_rows) def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.DataflowBlock: return sops.DataflowBlock( parent=parent.idx, inputs=ser_it(self.inputs), - sum_rows=list(map(ser_it, self.sum_rows)), + sum_rows=list(map(ser_it, self.sum_ty.variant_rows)), other_outputs=ser_it(self.other_outputs), extension_delta=self.extension_delta, ) def inner_signature(self) -> tys.FunctionType: return tys.FunctionType( - input=self.inputs, output=[tys.Sum(self.sum_rows), *self.other_outputs] + input=self.inputs, output=[self.sum_ty, *self.other_outputs] ) def port_kind(self, port: InPort | OutPort) -> tys.Kind: return tys.CFKind() def _set_out_types(self, types: tys.TypeRow) -> None: - (sum_, *other) = types - assert isinstance(sum_, tys.Sum), f"Expected Sum, got {sum_}" - self._sum_rows = sum_.variant_rows + (sum_, other) = tys.get_first_sum(types) + self._sum = sum_ self._other_outputs = other def _inputs(self) -> tys.TypeRow: return self.inputs def nth_outputs(self, n: int) -> tys.TypeRow: - return [*self.sum_rows[n], *self.other_outputs] + return [*self.sum_ty.variant_rows[n], *self.other_outputs] @dataclass @@ -380,3 +379,68 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.LoadConstant: def outer_signature(self) -> tys.FunctionType: return tys.FunctionType(input=[], output=[self.type_()]) + + +@dataclass() +class Conditional(DataflowOp): + sum_ty: tys.Sum + other_inputs: tys.TypeRow + _outputs: tys.TypeRow | None = None + + @property + def outputs(self) -> tys.TypeRow: + return _check_complete(self._outputs) + + @property + def signature(self) -> tys.FunctionType: + inputs = [self.sum_ty, *self.other_inputs] + return tys.FunctionType(inputs, self.outputs) + + @property + def num_out(self) -> int | None: + return len(self.outputs) + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Conditional: + return sops.Conditional( + parent=parent.idx, + sum_rows=[ser_it(r) for r in self.sum_ty.variant_rows], + other_inputs=ser_it(self.other_inputs), + outputs=ser_it(self.outputs), + ) + + def outer_signature(self) -> tys.FunctionType: + return self.signature + + def nth_inputs(self, n: int) -> tys.TypeRow: + return [*self.sum_ty.variant_rows[n], *self.other_inputs] + + +@dataclass +class Case(DfParentOp): + inputs: tys.TypeRow + _outputs: tys.TypeRow | None = None + + @property + def outputs(self) -> tys.TypeRow: + return _check_complete(self._outputs) + + @property + def num_out(self) -> int | None: + return 0 + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Case: + return sops.Case( + parent=parent.idx, signature=self.inner_signature().to_serial() + ) + + def inner_signature(self) -> tys.FunctionType: + return tys.FunctionType(self.inputs, self.outputs) + + def port_kind(self, port: InPort | OutPort) -> tys.Kind: + raise NotImplementedError("Case nodes have no external ports.") + + def _set_out_types(self, types: tys.TypeRow) -> None: + self._outputs = types + + def _inputs(self) -> tys.TypeRow: + return self.inputs diff --git a/hugr-py/src/hugr/_tys.py b/hugr-py/src/hugr/_tys.py index 7b0fa2335..bb7f609eb 100644 --- a/hugr-py/src/hugr/_tys.py +++ b/hugr-py/src/hugr/_tys.py @@ -303,3 +303,9 @@ class OrderKind: ... Kind = ValueKind | ConstKind | FunctionKind | CFKind | OrderKind + + +def get_first_sum(types: TypeRow) -> tuple[Sum, TypeRow]: + (sum_, *other) = types + assert isinstance(sum_, Sum), f"Expected Sum, got {sum_}" + return sum_, other diff --git a/hugr-py/src/hugr/serialization/ops.py b/hugr-py/src/hugr/serialization/ops.py index 05f13959b..54fe9b373 100644 --- a/hugr-py/src/hugr/serialization/ops.py +++ b/hugr-py/src/hugr/serialization/ops.py @@ -206,7 +206,7 @@ def insert_child_dfg_signature(self, inputs: TypeRow, outputs: TypeRow) -> None: def deserialize(self) -> _ops.DataflowBlock: return _ops.DataflowBlock( inputs=deser_it(self.inputs), - _sum_rows=[deser_it(r) for r in self.sum_rows], + _sum=_tys.Sum([deser_it(r) for r in self.sum_rows]), _other_outputs=deser_it(self.other_outputs), ) @@ -384,6 +384,13 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: self.other_inputs = list(in_types[1:]) self.outputs = list(out_types) + def deserialize(self) -> _ops.Conditional: + return _ops.Conditional( + _tys.Sum([deser_it(r) for r in self.sum_rows]), + deser_it(self.other_inputs), + deser_it(self.outputs), + ) + class Case(BaseOp): """Case ops - nodes valid inside Conditional nodes.""" @@ -397,6 +404,10 @@ def insert_child_dfg_signature(self, inputs: TypeRow, outputs: TypeRow) -> None: input=list(inputs), output=list(outputs), extension_reqs=ExtensionSet([]) ) + def deserialize(self) -> _ops.Case: + sig = self.signature.deserialize() + return _ops.Case(inputs=sig.input, _outputs=sig.output) + class TailLoop(DataflowOp): """Tail-controlled loop.""" @@ -522,7 +533,7 @@ class Tag(DataflowOp): def deserialize(self) -> _ops.Tag: return _ops.Tag( tag=self.tag, - variants=[deser_it(v) for v in self.variants], + sum_ty=_tys.Sum([deser_it(v) for v in self.variants]), ) @@ -616,3 +627,4 @@ class OpDef(ConfiguredBaseModel, populate_by_name=True): # needed to avoid circular imports from hugr import _ops # noqa: E402 from hugr import _val # noqa: E402 +from hugr import _tys # noqa: E402 diff --git a/hugr-py/tests/test_cfg.py b/hugr-py/tests/test_cfg.py index 7a40edeee..daeb72220 100644 --- a/hugr-py/tests/test_cfg.py +++ b/hugr-py/tests/test_cfg.py @@ -73,7 +73,9 @@ def test_asymm_types() -> None: entry = cfg.add_entry() int_load = entry.add_load_const(IntVal(34)) - tagged_int = entry.add(ops.Tag(0, [[INT_T], [tys.Bool]])(int_load)) + + sum_ty = tys.Sum([[INT_T], [tys.Bool]]) + tagged_int = entry.add(ops.Tag(0, sum_ty)(int_load)) entry.set_block_outputs(tagged_int) middle = cfg.add_successor(entry[0]) diff --git a/hugr-py/tests/test_cond_loop.py b/hugr-py/tests/test_cond_loop.py new file mode 100644 index 000000000..aa1ec3801 --- /dev/null +++ b/hugr-py/tests/test_cond_loop.py @@ -0,0 +1,37 @@ +from hugr._cond_loop import Conditional +from hugr._dfg import Dfg +import hugr._tys as tys +import hugr._ops as ops +import pytest +from .test_hugr_build import INT_T, _validate, IntVal + +SUM_T = tys.Sum([[tys.Qubit], [tys.Qubit, INT_T]]) + + +def build_cond(h: Conditional) -> None: + with pytest.raises(AssertionError): + h.add_case(2) + + case0 = h.add_case(0) + q, b = case0.inputs() + case0.set_outputs(q, b) + + case1 = h.add_case(1) + q, _i, b = case1.inputs() + case1.set_outputs(q, b) + + +def test_cond() -> None: + h = Conditional(SUM_T, [tys.Bool]) + build_cond(h) + _validate(h.hugr) + + +def test_nested_cond() -> None: + h = Dfg(tys.Qubit) + (q,) = h.inputs() + tagged_q = h.add(ops.Tag(0, SUM_T)(q)) + cond = h.add_conditional(tagged_q, h.add_load_const(IntVal(1))) + build_cond(cond) + h.set_outputs(*cond[:2]) + _validate(h.hugr) From 532a82bcd2cb30b0f2d39d8087e77e73dc2579e6 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Thu, 20 Jun 2024 14:35:01 +0100 Subject: [PATCH 04/12] make constant loading easier --- hugr-py/src/hugr/_cfg.py | 2 +- hugr-py/src/hugr/_dfg.py | 11 +++++------ hugr-py/tests/test_cfg.py | 4 ++-- hugr-py/tests/test_cond_loop.py | 2 +- hugr-py/tests/test_hugr_build.py | 2 +- 5 files changed, 10 insertions(+), 11 deletions(-) diff --git a/hugr-py/src/hugr/_cfg.py b/hugr-py/src/hugr/_cfg.py index fae40cc9d..a1ed1417c 100644 --- a/hugr-py/src/hugr/_cfg.py +++ b/hugr-py/src/hugr/_cfg.py @@ -16,7 +16,7 @@ def set_block_outputs(self, branching: Wire, *other_outputs: Wire) -> None: self.set_outputs(branching, *other_outputs) def set_single_succ_outputs(self, *outputs: Wire) -> None: - u = self.add_load_const(val.Unit) + u = self.load(val.Unit) self.set_outputs(u, *outputs) def _wire_up_port(self, node: Node, offset: int, p: Wire) -> Type: diff --git a/hugr-py/src/hugr/_dfg.py b/hugr-py/src/hugr/_dfg.py index bbdaef277..fb60382e4 100644 --- a/hugr-py/src/hugr/_dfg.py +++ b/hugr-py/src/hugr/_dfg.py @@ -130,18 +130,17 @@ def add_state_order(self, src: Node, dst: Node) -> None: def add_const(self, val: val.Value) -> Node: return self.hugr.add_const(val, self.parent_node) - def load_const(self, const_node: ToNode) -> Node: - const_op = self.hugr._get_typed_op(const_node, ops.Const) + def load(self, const: ToNode | val.Value) -> Node: + if isinstance(const, val.Value): + const = self.add_const(const) + const_op = self.hugr._get_typed_op(const, ops.Const) load_op = ops.LoadConst(const_op.val.type_()) load = self.add(load_op()) - self.hugr.add_link(const_node.out_port(), load.inp(0)) + self.hugr.add_link(const.out_port(), load.inp(0)) return load - def add_load_const(self, val: val.Value) -> Node: - return self.load_const(self.add_const(val)) - def _wire_up(self, node: Node, ports: Iterable[Wire]) -> TypeRow: tys = [self._wire_up_port(node, i, p) for i, p in enumerate(ports)] if isinstance(op := self.hugr[node].op, ops.PartialOp): diff --git a/hugr-py/tests/test_cfg.py b/hugr-py/tests/test_cfg.py index daeb72220..269c48285 100644 --- a/hugr-py/tests/test_cfg.py +++ b/hugr-py/tests/test_cfg.py @@ -72,7 +72,7 @@ def test_asymm_types() -> None: cfg = Cfg([]) entry = cfg.add_entry() - int_load = entry.add_load_const(IntVal(34)) + int_load = entry.load(IntVal(34)) sum_ty = tys.Sum([[INT_T], [tys.Bool]]) tagged_int = entry.add(ops.Tag(0, sum_ty)(int_load)) @@ -80,7 +80,7 @@ def test_asymm_types() -> None: middle = cfg.add_successor(entry[0]) # discard the int and return the bool from entry - middle.set_single_succ_outputs(middle.add_load_const(val.TRUE)) + middle.set_single_succ_outputs(middle.load(val.TRUE)) # middle expects an int and exit expects a bool cfg.branch_exit(entry[1]) diff --git a/hugr-py/tests/test_cond_loop.py b/hugr-py/tests/test_cond_loop.py index aa1ec3801..459c5da72 100644 --- a/hugr-py/tests/test_cond_loop.py +++ b/hugr-py/tests/test_cond_loop.py @@ -31,7 +31,7 @@ def test_nested_cond() -> None: h = Dfg(tys.Qubit) (q,) = h.inputs() tagged_q = h.add(ops.Tag(0, SUM_T)(q)) - cond = h.add_conditional(tagged_q, h.add_load_const(IntVal(1))) + cond = h.add_conditional(tagged_q, h.load(IntVal(1))) build_cond(cond) h.set_outputs(*cond[:2]) _validate(h.hugr) diff --git a/hugr-py/tests/test_hugr_build.py b/hugr-py/tests/test_hugr_build.py index 04d908f64..6ad264efd 100644 --- a/hugr-py/tests/test_hugr_build.py +++ b/hugr-py/tests/test_hugr_build.py @@ -281,6 +281,6 @@ def test_ancestral_sibling(): ) def test_vals(val: val.Value): d = Dfg() - d.set_outputs(d.add_load_const(val)) + d.set_outputs(d.load(val)) _validate(d.hugr) From d0f1ec52e7c79ea626c9f181b08e11e1340a3e64 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Thu, 20 Jun 2024 14:35:35 +0100 Subject: [PATCH 05/12] endo constructor for functiontype --- hugr-py/src/hugr/_tys.py | 4 ++++ hugr-py/tests/test_hugr_build.py | 4 +--- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/hugr-py/src/hugr/_tys.py b/hugr-py/src/hugr/_tys.py index bb7f609eb..48dd43b08 100644 --- a/hugr-py/src/hugr/_tys.py +++ b/hugr-py/src/hugr/_tys.py @@ -237,6 +237,10 @@ def to_serial(self) -> stys.FunctionType: def empty(cls) -> FunctionType: return cls(input=[], output=[]) + @classmethod + def endo(cls, tys: TypeRow) -> FunctionType: + return cls(input=tys, output=tys) + def flip(self) -> FunctionType: return FunctionType(input=list(self.output), output=list(self.input)) diff --git a/hugr-py/tests/test_hugr_build.py b/hugr-py/tests/test_hugr_build.py index 6ad264efd..03ec92eed 100644 --- a/hugr-py/tests/test_hugr_build.py +++ b/hugr-py/tests/test_hugr_build.py @@ -44,9 +44,7 @@ class LogicOps(Custom): class NotDef(LogicOps): num_out: int | None = 1 op_name: str = "Not" - signature: tys.FunctionType = field( - default_factory=lambda: tys.FunctionType(input=[tys.Bool], output=[tys.Bool]) - ) + signature: tys.FunctionType = tys.FunctionType.endo([tys.Bool]) def __call__(self, a: Wire) -> Command: return super().__call__(a) From e2d66442110404224edda4c3b329cda01bc285b6 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Thu, 20 Jun 2024 15:19:33 +0100 Subject: [PATCH 06/12] feat: convenient constructors for if-else workflows --- hugr-py/src/hugr/_cond_loop.py | 32 ++++++++++++++++++++++++++++---- hugr-py/src/hugr/_dfg.py | 11 +++++++++-- hugr-py/tests/test_cond_loop.py | 20 +++++++++++++++++++- hugr-py/tests/test_hugr_build.py | 18 ++++++++++++++++++ 4 files changed, 74 insertions(+), 7 deletions(-) diff --git a/hugr-py/src/hugr/_cond_loop.py b/hugr-py/src/hugr/_cond_loop.py index 282e4cca8..853b6311a 100644 --- a/hugr-py/src/hugr/_cond_loop.py +++ b/hugr-py/src/hugr/_cond_loop.py @@ -10,12 +10,36 @@ class Case(_DfBase[ops.Case]): - _parent: Conditional | None = None + _parent_cond: Conditional | None = None def set_outputs(self, *outputs: Wire) -> None: super().set_outputs(*outputs) - if self._parent is not None: - self._parent._update_outputs(self._wire_types(outputs)) + if self._parent_cond is not None: + self._parent_cond._update_outputs(self._wire_types(outputs)) + + +@dataclass +class _IfElse(Case): + def __init__(self, case: Case) -> None: + self.hugr = case.hugr + self.parent_node = case.parent_node + self.input_node = case.input_node + self.output_node = case.output_node + self._parent_cond = case._parent_cond + + def _parent_conditional(self) -> Conditional: + assert self._parent_cond is not None, "If must have a parent conditional." + return self._parent_cond + + +class If(_IfElse): + def add_else(self) -> Else: + return Else(self._parent_conditional().add_case(0)) + + +class Else(_IfElse): + def finish(self) -> Node: + return self._parent_conditional().parent_node @dataclass @@ -64,7 +88,7 @@ def add_case(self, case_id: int) -> Case: self.hugr, self.parent_node, ) - new_case._parent = self + new_case._parent_cond = self self.cases[case_id] = new_case.parent_node return new_case diff --git a/hugr-py/src/hugr/_dfg.py b/hugr-py/src/hugr/_dfg.py index fb60382e4..72cf28f8b 100644 --- a/hugr-py/src/hugr/_dfg.py +++ b/hugr-py/src/hugr/_dfg.py @@ -18,7 +18,7 @@ if TYPE_CHECKING: from ._cfg import Cfg - from ._cond_loop import Conditional + from ._cond_loop import Conditional, If DP = TypeVar("DP", bound=ops.DfParentOp) @@ -108,9 +108,10 @@ def add_cfg( def insert_cfg(self, cfg: Cfg, *args: Wire) -> Node: return self._insert_nested_impl(cfg, *args) - def add_conditional(self, *args: Wire) -> Conditional: + def add_conditional(self, cond: Wire, *args: Wire) -> Conditional: from ._cond_loop import Conditional + args = (cond, *args) (sum_, other_inputs) = get_first_sum(self._wire_types(args)) cond = Conditional.new_nested(sum_, other_inputs, self.hugr, self.parent_node) self._wire_up(cond.parent_node, args) @@ -119,6 +120,12 @@ def add_conditional(self, *args: Wire) -> Conditional: def insert_conditional(self, cond: Conditional, *args: Wire) -> Node: return self._insert_nested_impl(cond, *args) + def add_if(self, cond: Wire, *args: Wire) -> If: + from ._cond_loop import If + + conditional = self.add_conditional(cond, *args) + return If(conditional.add_case(1)) + def set_outputs(self, *args: Wire) -> None: self._wire_up(self.output_node, args) self.parent_op._set_out_types(self._output_op().types) diff --git a/hugr-py/tests/test_cond_loop.py b/hugr-py/tests/test_cond_loop.py index 459c5da72..20ba43b8c 100644 --- a/hugr-py/tests/test_cond_loop.py +++ b/hugr-py/tests/test_cond_loop.py @@ -2,8 +2,9 @@ from hugr._dfg import Dfg import hugr._tys as tys import hugr._ops as ops +import hugr._val as val import pytest -from .test_hugr_build import INT_T, _validate, IntVal +from .test_hugr_build import INT_T, _validate, IntVal, H SUM_T = tys.Sum([[tys.Qubit], [tys.Qubit, INT_T]]) @@ -35,3 +36,20 @@ def test_nested_cond() -> None: build_cond(cond) h.set_outputs(*cond[:2]) _validate(h.hugr) + + +def test_if_else() -> None: + # apply an H if a bool is true. + h = Dfg(tys.Qubit) + (q,) = h.inputs() + if_ = h.add_if(h.load(val.TRUE), q) + + if_.set_outputs(if_.add(H(if_.input_node[0]))) + + else_ = if_.add_else() + else_.set_outputs(else_.input_node[0]) + + cond = else_.finish() + h.set_outputs(cond) + + _validate(h.hugr, True) diff --git a/hugr-py/tests/test_hugr_build.py b/hugr-py/tests/test_hugr_build.py index 03ec92eed..27ec5cc95 100644 --- a/hugr-py/tests/test_hugr_build.py +++ b/hugr-py/tests/test_hugr_build.py @@ -53,6 +53,24 @@ def __call__(self, a: Wire) -> Command: Not = NotDef() +@dataclass +class OneQbGate(Custom): + num_out: int | None = 1 + extension: tys.ExtensionId = "tket2.quantum" + signature: tys.FunctionType = tys.FunctionType.endo([tys.Qubit]) + + def __call__(self, q: Wire) -> Command: + return super().__call__(q) + + +@dataclass +class HDef(OneQbGate): + op_name: str = "H" + + +H = HDef() + + @dataclass class IntOps(Custom): extension: tys.ExtensionId = "arithmetic.int" From 9863375a65cff69c30c0485d77bf8f1f4b205dd6 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Thu, 20 Jun 2024 15:58:38 +0100 Subject: [PATCH 07/12] feat: add tail loop building --- hugr-py/src/hugr/_cond_loop.py | 8 ++++-- hugr-py/src/hugr/_dfg.py | 19 +++++++++++++- hugr-py/src/hugr/_ops.py | 44 ++++++++++++++++++++++++++++++++ hugr-py/tests/test_cond_loop.py | 16 +++++++++++- hugr-py/tests/test_hugr_build.py | 23 +++++++++++++---- 5 files changed, 101 insertions(+), 9 deletions(-) diff --git a/hugr-py/src/hugr/_cond_loop.py b/hugr-py/src/hugr/_cond_loop.py index 853b6311a..f7d81bc77 100644 --- a/hugr-py/src/hugr/_cond_loop.py +++ b/hugr-py/src/hugr/_cond_loop.py @@ -44,8 +44,6 @@ def finish(self) -> Node: @dataclass class Conditional(ParentBuilder[ops.Conditional]): - hugr: Hugr - parent_node: Node cases: dict[int, Node | None] def __init__(self, sum_ty: Sum, other_inputs: TypeRow) -> None: @@ -93,3 +91,9 @@ def add_case(self, case_id: int) -> Case: return new_case # TODO insert_case + + +@dataclass +class TailLoop(_DfBase[ops.TailLoop]): + def set_loop_outputs(self, sum_wire: Wire, *rest: Wire) -> None: + self.set_outputs(sum_wire, *rest) diff --git a/hugr-py/src/hugr/_dfg.py b/hugr-py/src/hugr/_dfg.py index 72cf28f8b..ac5ec7f14 100644 --- a/hugr-py/src/hugr/_dfg.py +++ b/hugr-py/src/hugr/_dfg.py @@ -4,6 +4,7 @@ from typing import ( TYPE_CHECKING, Iterable, + Sequence, TypeVar, ) @@ -18,7 +19,7 @@ if TYPE_CHECKING: from ._cfg import Cfg - from ._cond_loop import Conditional, If + from ._cond_loop import Conditional, If, TailLoop DP = TypeVar("DP", bound=ops.DfParentOp) @@ -126,6 +127,22 @@ def add_if(self, cond: Wire, *args: Wire) -> If: conditional = self.add_conditional(cond, *args) return If(conditional.add_case(1)) + def add_tail_loop( + self, just_inputs: Sequence[Wire], rest: Sequence[Wire] + ) -> TailLoop: + from ._cond_loop import TailLoop + + rest = rest or [] + just_input_types = self._wire_types(just_inputs) + rest_types = self._wire_types(rest) + parent_op = ops.TailLoop(just_input_types, rest_types) + tl = TailLoop.new_nested(parent_op, self.hugr, self.parent_node) + self._wire_up(tl.parent_node, (*just_inputs, *rest)) + return tl + + def insert_tail_loop(self, tl: TailLoop, *args: Wire) -> Node: + return self._insert_nested_impl(tl, *args) + def set_outputs(self, *args: Wire) -> None: self._wire_up(self.output_node, args) self.parent_op._set_out_types(self._output_op().types) diff --git a/hugr-py/src/hugr/_ops.py b/hugr-py/src/hugr/_ops.py index 800547692..37219ca6b 100644 --- a/hugr-py/src/hugr/_ops.py +++ b/hugr-py/src/hugr/_ops.py @@ -444,3 +444,47 @@ def _set_out_types(self, types: tys.TypeRow) -> None: def _inputs(self) -> tys.TypeRow: return self.inputs + + +@dataclass +class TailLoop(DfParentOp, DataflowOp): + just_inputs: tys.TypeRow + rest: tys.TypeRow + _just_outputs: tys.TypeRow | None = None + extension_delta: tys.ExtensionSet = field(default_factory=list) + + @property + def just_outputs(self) -> tys.TypeRow: + return _check_complete(self._just_outputs) + + @property + def num_out(self) -> int | None: + return len(self.just_outputs) + len(self.rest) + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.TailLoop: + return sops.TailLoop( + parent=parent.idx, + just_inputs=ser_it(self.just_inputs), + just_outputs=ser_it(self.just_outputs), + rest=ser_it(self.rest), + extension_delta=self.extension_delta, + ) + + def inner_signature(self) -> tys.FunctionType: + return tys.FunctionType( + self._inputs(), [tys.Sum([self.just_inputs, self.just_outputs]), *self.rest] + ) + + def outer_signature(self) -> tys.FunctionType: + return tys.FunctionType(self._inputs(), self.just_outputs + self.rest) + + def _set_out_types(self, types: tys.TypeRow) -> None: + (sum_, other) = tys.get_first_sum(types) + just_ins, just_outs = sum_.variant_rows + assert ( + just_ins == self.just_inputs + ), "First sum variant rows don't match TailLoop inputs." + self._just_outputs = just_outs + + def _inputs(self) -> tys.TypeRow: + return self.just_inputs + self.rest diff --git a/hugr-py/tests/test_cond_loop.py b/hugr-py/tests/test_cond_loop.py index 20ba43b8c..9882e2a49 100644 --- a/hugr-py/tests/test_cond_loop.py +++ b/hugr-py/tests/test_cond_loop.py @@ -4,7 +4,7 @@ import hugr._ops as ops import hugr._val as val import pytest -from .test_hugr_build import INT_T, _validate, IntVal, H +from .test_hugr_build import INT_T, _validate, IntVal, H, Measure SUM_T = tys.Sum([[tys.Qubit], [tys.Qubit, INT_T]]) @@ -53,3 +53,17 @@ def test_if_else() -> None: h.set_outputs(cond) _validate(h.hugr, True) + + +def test_tail_loop() -> None: + # apply H while measure is true + + h = Dfg(tys.Qubit) + (q,) = h.inputs() + + tl = h.add_tail_loop([], [q]) + q, b = tl.add(Measure(tl.add(H(tl.input_node[0]))))[:] + + tl.set_loop_outputs(b, q) + + h.set_outputs(tl) diff --git a/hugr-py/tests/test_hugr_build.py b/hugr-py/tests/test_hugr_build.py index 27ec5cc95..14a44d12b 100644 --- a/hugr-py/tests/test_hugr_build.py +++ b/hugr-py/tests/test_hugr_build.py @@ -54,21 +54,34 @@ def __call__(self, a: Wire) -> Command: @dataclass -class OneQbGate(Custom): - num_out: int | None = 1 +class QuantumOps(Custom): extension: tys.ExtensionId = "tket2.quantum" + + +@dataclass +class OneQbGate(QuantumOps): + op_name: str + num_out: int | None = 1 signature: tys.FunctionType = tys.FunctionType.endo([tys.Qubit]) def __call__(self, q: Wire) -> Command: return super().__call__(q) +H = OneQbGate("H") + + @dataclass -class HDef(OneQbGate): - op_name: str = "H" +class MeasureDef(QuantumOps): + op_name: str = "Measure" + num_out: int | None = 2 + signature: tys.FunctionType = tys.FunctionType([tys.Qubit], [tys.Qubit, tys.Bool]) + + def __call__(self, q: Wire) -> Command: + return super().__call__(q) -H = HDef() +Measure = MeasureDef() @dataclass From 7c5656054b501044942576a61fd8b80901c5cf13 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Thu, 20 Jun 2024 16:16:40 +0100 Subject: [PATCH 08/12] add complex loop test --- hugr-py/src/hugr/_ops.py | 3 --- hugr-py/src/hugr/serialization/ops.py | 8 ++++++ hugr-py/tests/test_cond_loop.py | 35 ++++++++++++++++++++++++++- 3 files changed, 42 insertions(+), 4 deletions(-) diff --git a/hugr-py/src/hugr/_ops.py b/hugr-py/src/hugr/_ops.py index 37219ca6b..2241eff1d 100644 --- a/hugr-py/src/hugr/_ops.py +++ b/hugr-py/src/hugr/_ops.py @@ -207,9 +207,6 @@ def outer_signature(self) -> tys.FunctionType: input=self.sum_ty.variant_rows[self.tag], output=[self.sum_ty] ) - def __call__(self, value: Wire) -> Command: - return super().__call__(value) - class DfParentOp(Op, Protocol): def inner_signature(self) -> tys.FunctionType: ... diff --git a/hugr-py/src/hugr/serialization/ops.py b/hugr-py/src/hugr/serialization/ops.py index 54fe9b373..d2fedf027 100644 --- a/hugr-py/src/hugr/serialization/ops.py +++ b/hugr-py/src/hugr/serialization/ops.py @@ -425,6 +425,14 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: # self.just_outputs = list(out_types) self.rest = list(in_types) + def deserialize(self) -> _ops.TailLoop: + return _ops.TailLoop( + just_inputs=deser_it(self.just_inputs), + _just_outputs=deser_it(self.just_outputs), + rest=deser_it(self.rest), + extension_delta=self.extension_delta, + ) + class CFG(DataflowOp): """A dataflow node which is defined by a child CFG.""" diff --git a/hugr-py/tests/test_cond_loop.py b/hugr-py/tests/test_cond_loop.py index 9882e2a49..37a830dca 100644 --- a/hugr-py/tests/test_cond_loop.py +++ b/hugr-py/tests/test_cond_loop.py @@ -52,7 +52,7 @@ def test_if_else() -> None: cond = else_.finish() h.set_outputs(cond) - _validate(h.hugr, True) + _validate(h.hugr) def test_tail_loop() -> None: @@ -67,3 +67,36 @@ def test_tail_loop() -> None: tl.set_loop_outputs(b, q) h.set_outputs(tl) + + _validate(h.hugr) + + +def test_complex_tail_loop() -> None: + h = Dfg(tys.Qubit) + (q,) = h.inputs() + + # loop passes qubit to itself, and a bool as in-out + tl = h.add_tail_loop([q], [h.load(val.TRUE)]) + q, b = tl.inputs() + + # if b is true, return first variant (just qubit) + if_ = tl.add_if(b, q) + (q,) = if_.inputs() + tagged_q = if_.add(ops.Tag(0, SUM_T)(q)) + if_.set_outputs(tagged_q) + + # else return second variant (qubit, int) + else_ = if_.add_else() + (q,) = else_.inputs() + tagged_q_i = else_.add(ops.Tag(1, SUM_T)(q, else_.load(IntVal(1)))) + else_.set_outputs(tagged_q_i) + + # finish with Sum output from if-else, and bool from inputs + tl.set_loop_outputs(else_.finish(), b) + + # loop returns [qubit, int, bool] + h.set_outputs(*tl[:3]) + + _validate(h.hugr, True) + + # TODO rewrite with context managers From 939772c241fb63f902b0acb4f1583bfe3bfdd555 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Thu, 20 Jun 2024 16:23:15 +0100 Subject: [PATCH 09/12] revert accidental annotation removal --- hugr-py/src/hugr/_dfg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hugr-py/src/hugr/_dfg.py b/hugr-py/src/hugr/_dfg.py index ac5ec7f14..bea0c0cee 100644 --- a/hugr-py/src/hugr/_dfg.py +++ b/hugr-py/src/hugr/_dfg.py @@ -178,7 +178,7 @@ def _get_dataflow_type(self, wire: Wire) -> Type: raise ValueError(f"Port {port} is not a dataflow port.") return ty - def _wire_up_port(self, node: Node, offset: int, p: Wire): + def _wire_up_port(self, node: Node, offset: int, p: Wire) -> Type: src = p.out_port() node_ancestor = _ancestral_sibling(self.hugr, src.node, node) if node_ancestor is None: From c8105bbff05d82781780a07fb95aee0a05e5fd75 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Fri, 21 Jun 2024 12:05:25 +0100 Subject: [PATCH 10/12] refactor: turn conditional asserts in to exceptions --- hugr-py/src/hugr/_cond_loop.py | 13 ++++++++++--- hugr-py/tests/test_cond_loop.py | 4 ++-- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/hugr-py/src/hugr/_cond_loop.py b/hugr-py/src/hugr/_cond_loop.py index f7d81bc77..f233823b0 100644 --- a/hugr-py/src/hugr/_cond_loop.py +++ b/hugr-py/src/hugr/_cond_loop.py @@ -18,6 +18,10 @@ def set_outputs(self, *outputs: Wire) -> None: self._parent_cond._update_outputs(self._wire_types(outputs)) +class ConditionalError(Exception): + pass + + @dataclass class _IfElse(Case): def __init__(self, case: Case) -> None: @@ -28,7 +32,8 @@ def __init__(self, case: Case) -> None: self._parent_cond = case._parent_cond def _parent_conditional(self) -> Conditional: - assert self._parent_cond is not None, "If must have a parent conditional." + if self._parent_cond is None: + raise ConditionalError("If must have a parent conditional.") return self._parent_cond @@ -76,10 +81,12 @@ def _update_outputs(self, outputs: TypeRow) -> None: if self.parent_op._outputs is None: self.parent_op._outputs = outputs else: - assert outputs == self.parent_op._outputs, "Mismatched case outputs." + if outputs != self.parent_op._outputs: + raise ConditionalError("Mismatched case outputs.") def add_case(self, case_id: int) -> Case: - assert case_id in self.cases, f"Case {case_id} out of possible range." + if case_id not in self.cases: + raise ConditionalError(f"Case {case_id} out of possible range.") input_types = self.parent_op.nth_inputs(case_id) new_case = Case.new_nested( ops.Case(input_types), diff --git a/hugr-py/tests/test_cond_loop.py b/hugr-py/tests/test_cond_loop.py index 37a830dca..d03df0eeb 100644 --- a/hugr-py/tests/test_cond_loop.py +++ b/hugr-py/tests/test_cond_loop.py @@ -1,4 +1,4 @@ -from hugr._cond_loop import Conditional +from hugr._cond_loop import Conditional, ConditionalError from hugr._dfg import Dfg import hugr._tys as tys import hugr._ops as ops @@ -10,7 +10,7 @@ def build_cond(h: Conditional) -> None: - with pytest.raises(AssertionError): + with pytest.raises(ConditionalError, match="Case 2 out of possible range."): h.add_case(2) case0 = h.add_case(0) From 9166f7f47227fac8c28e949b1e3bc790a7ddf31d Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Fri, 21 Jun 2024 12:06:14 +0100 Subject: [PATCH 11/12] remove vestigial line --- hugr-py/src/hugr/_dfg.py | 1 - 1 file changed, 1 deletion(-) diff --git a/hugr-py/src/hugr/_dfg.py b/hugr-py/src/hugr/_dfg.py index bea0c0cee..b47c81993 100644 --- a/hugr-py/src/hugr/_dfg.py +++ b/hugr-py/src/hugr/_dfg.py @@ -132,7 +132,6 @@ def add_tail_loop( ) -> TailLoop: from ._cond_loop import TailLoop - rest = rest or [] just_input_types = self._wire_types(just_inputs) rest_types = self._wire_types(rest) parent_op = ops.TailLoop(just_input_types, rest_types) From 4b5cd90686ea2502dc9dcbb7636648845d34c476 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Fri, 21 Jun 2024 12:20:26 +0100 Subject: [PATCH 12/12] add insertion tests --- hugr-py/src/hugr/_cond_loop.py | 4 ++++ hugr-py/tests/test_cond_loop.py | 32 +++++++++++++++++++++++++++----- 2 files changed, 31 insertions(+), 5 deletions(-) diff --git a/hugr-py/src/hugr/_cond_loop.py b/hugr-py/src/hugr/_cond_loop.py index f233823b0..ebf7f3f14 100644 --- a/hugr-py/src/hugr/_cond_loop.py +++ b/hugr-py/src/hugr/_cond_loop.py @@ -102,5 +102,9 @@ def add_case(self, case_id: int) -> Case: @dataclass class TailLoop(_DfBase[ops.TailLoop]): + def __init__(self, just_inputs: TypeRow, rest: TypeRow) -> None: + root_op = ops.TailLoop(just_inputs, rest) + super().__init__(root_op) + def set_loop_outputs(self, sum_wire: Wire, *rest: Wire) -> None: self.set_outputs(sum_wire, *rest) diff --git a/hugr-py/tests/test_cond_loop.py b/hugr-py/tests/test_cond_loop.py index d03df0eeb..24381e06c 100644 --- a/hugr-py/tests/test_cond_loop.py +++ b/hugr-py/tests/test_cond_loop.py @@ -1,4 +1,4 @@ -from hugr._cond_loop import Conditional, ConditionalError +from hugr._cond_loop import Conditional, ConditionalError, TailLoop from hugr._dfg import Dfg import hugr._tys as tys import hugr._ops as ops @@ -32,11 +32,22 @@ def test_nested_cond() -> None: h = Dfg(tys.Qubit) (q,) = h.inputs() tagged_q = h.add(ops.Tag(0, SUM_T)(q)) - cond = h.add_conditional(tagged_q, h.load(IntVal(1))) + cond = h.add_conditional(tagged_q, h.load(val.TRUE)) build_cond(cond) h.set_outputs(*cond[:2]) _validate(h.hugr) + # build then insert + con = Conditional(SUM_T, [tys.Bool]) + build_cond(con) + + h = Dfg(tys.Qubit) + (q,) = h.inputs() + tagged_q = h.add(ops.Tag(0, SUM_T)(q)) + cond_n = h.insert_conditional(con, tagged_q, h.load(val.TRUE)) + h.set_outputs(*cond_n[:2]) + _validate(h.hugr) + def test_if_else() -> None: # apply an H if a bool is true. @@ -57,17 +68,28 @@ def test_if_else() -> None: def test_tail_loop() -> None: # apply H while measure is true + def build_tl(tl: TailLoop) -> None: + q, b = tl.add(Measure(tl.add(H(tl.input_node[0]))))[:] + + tl.set_loop_outputs(b, q) h = Dfg(tys.Qubit) (q,) = h.inputs() tl = h.add_tail_loop([], [q]) - q, b = tl.add(Measure(tl.add(H(tl.input_node[0]))))[:] + build_tl(tl) + h.set_outputs(tl) - tl.set_loop_outputs(b, q) + _validate(h.hugr) - h.set_outputs(tl) + # build then insert + tl = TailLoop([], [tys.Qubit]) + build_tl(tl) + h = Dfg(tys.Qubit) + (q,) = h.inputs() + tl_n = h.insert_tail_loop(tl, q) + h.set_outputs(tl_n) _validate(h.hugr)