diff --git a/hugr-py/src/hugr/_cfg.py b/hugr-py/src/hugr/_cfg.py index 4277edd10..a1ed1417c 100644 --- a/hugr-py/src/hugr/_cfg.py +++ b/hugr-py/src/hugr/_cfg.py @@ -1,13 +1,13 @@ from __future__ import annotations -from dataclasses import dataclass, replace +from dataclasses import dataclass import hugr._ops as ops from ._dfg import _DfBase from ._exceptions import NoSiblingAncestor, NotInSameCfg, MismatchedExit from ._hugr import Hugr, Node, ParentBuilder, ToNode, Wire -from ._tys import FunctionType, TypeRow, Type +from ._tys import TypeRow, Type import hugr._val as val @@ -16,7 +16,7 @@ def set_block_outputs(self, branching: Wire, *other_outputs: Wire) -> None: self.set_outputs(branching, *other_outputs) def set_single_succ_outputs(self, *outputs: Wire) -> None: - u = self.add_load_const(val.Unit) + u = self.load(val.Unit) self.set_outputs(u, *outputs) def _wire_up_port(self, node: Node, offset: int, p: Wire) -> Type: @@ -47,7 +47,7 @@ class Cfg(ParentBuilder[ops.CFG]): exit: Node def __init__(self, input_types: TypeRow) -> None: - root_op = ops.CFG(FunctionType(input=input_types, output=[])) + root_op = ops.CFG(inputs=input_types) hugr = Hugr(root_op) self._init_impl(hugr, hugr.root, input_types) @@ -68,7 +68,7 @@ def new_nested( ) -> Cfg: new = cls.__new__(cls) root = hugr.add_node( - ops.CFG(FunctionType(input=input_types, output=[])), + ops.CFG(inputs=input_types), parent or hugr.root, ) new._init_impl(hugr, root, input_types) @@ -97,6 +97,8 @@ def add_block(self, input_types: TypeRow) -> Block: ) return new_block + # TODO insert_block + def add_successor(self, pred: Wire) -> Block: b = self.add_block(self._nth_outputs(pred)) @@ -125,6 +127,4 @@ def branch_exit(self, src: Wire) -> None: raise MismatchedExit(src.node.idx) else: self._exit_op._cfg_outputs = out_types - self.parent_op.signature = replace( - self.parent_op.signature, output=out_types - ) + self.parent_op._outputs = out_types diff --git a/hugr-py/src/hugr/_cond_loop.py b/hugr-py/src/hugr/_cond_loop.py new file mode 100644 index 000000000..ebf7f3f14 --- /dev/null +++ b/hugr-py/src/hugr/_cond_loop.py @@ -0,0 +1,110 @@ +from __future__ import annotations + +from dataclasses import dataclass + +import hugr._ops as ops + +from ._dfg import _DfBase +from ._hugr import Hugr, Node, ParentBuilder, ToNode, Wire +from ._tys import Sum, TypeRow + + +class Case(_DfBase[ops.Case]): + _parent_cond: Conditional | None = None + + def set_outputs(self, *outputs: Wire) -> None: + super().set_outputs(*outputs) + if self._parent_cond is not None: + self._parent_cond._update_outputs(self._wire_types(outputs)) + + +class ConditionalError(Exception): + pass + + +@dataclass +class _IfElse(Case): + def __init__(self, case: Case) -> None: + self.hugr = case.hugr + self.parent_node = case.parent_node + self.input_node = case.input_node + self.output_node = case.output_node + self._parent_cond = case._parent_cond + + def _parent_conditional(self) -> Conditional: + if self._parent_cond is None: + raise ConditionalError("If must have a parent conditional.") + return self._parent_cond + + +class If(_IfElse): + def add_else(self) -> Else: + return Else(self._parent_conditional().add_case(0)) + + +class Else(_IfElse): + def finish(self) -> Node: + return self._parent_conditional().parent_node + + +@dataclass +class Conditional(ParentBuilder[ops.Conditional]): + cases: dict[int, Node | None] + + def __init__(self, sum_ty: Sum, other_inputs: TypeRow) -> None: + root_op = ops.Conditional(sum_ty, other_inputs) + hugr = Hugr(root_op) + self._init_impl(hugr, hugr.root, len(sum_ty.variant_rows)) + + def _init_impl(self: Conditional, hugr: Hugr, root: Node, n_cases: int) -> None: + self.hugr = hugr + self.parent_node = root + self.cases = {i: None for i in range(n_cases)} + + @classmethod + def new_nested( + cls, + sum_ty: Sum, + other_inputs: TypeRow, + hugr: Hugr, + parent: ToNode | None = None, + ) -> Conditional: + new = cls.__new__(cls) + root = hugr.add_node( + ops.Conditional(sum_ty, other_inputs), + parent or hugr.root, + ) + new._init_impl(hugr, root, len(sum_ty.variant_rows)) + return new + + def _update_outputs(self, outputs: TypeRow) -> None: + if self.parent_op._outputs is None: + self.parent_op._outputs = outputs + else: + if outputs != self.parent_op._outputs: + raise ConditionalError("Mismatched case outputs.") + + def add_case(self, case_id: int) -> Case: + if case_id not in self.cases: + raise ConditionalError(f"Case {case_id} out of possible range.") + input_types = self.parent_op.nth_inputs(case_id) + new_case = Case.new_nested( + ops.Case(input_types), + self.hugr, + self.parent_node, + ) + new_case._parent_cond = self + self.cases[case_id] = new_case.parent_node + return new_case + + # TODO insert_case + + +@dataclass +class TailLoop(_DfBase[ops.TailLoop]): + def __init__(self, just_inputs: TypeRow, rest: TypeRow) -> None: + root_op = ops.TailLoop(just_inputs, rest) + super().__init__(root_op) + + def set_loop_outputs(self, sum_wire: Wire, *rest: Wire) -> None: + self.set_outputs(sum_wire, *rest) diff --git a/hugr-py/src/hugr/_dfg.py b/hugr-py/src/hugr/_dfg.py index 8d4d1447f..b47c81993 100644 --- a/hugr-py/src/hugr/_dfg.py +++ b/hugr-py/src/hugr/_dfg.py @@ -4,6 +4,7 @@ from typing import ( TYPE_CHECKING, Iterable, + Sequence, TypeVar, ) @@ -11,13 +12,14 @@ import hugr._ops as ops import hugr._val as val -from hugr._tys import Type, TypeRow +from hugr._tys import Type, TypeRow, get_first_sum from ._exceptions import NoSiblingAncestor from ._hugr import Hugr, Node, OutPort, ParentBuilder, ToNode, Wire if TYPE_CHECKING: from ._cfg import Cfg + from ._cond_loop import Conditional, If, TailLoop DP = TypeVar("DP", bound=ops.DfParentOp) @@ -72,10 +74,13 @@ def add_op(self, op: ops.DataflowOp, /, *args: Wire) -> Node: def add(self, com: ops.Command) -> Node: return self.add_op(com.op, *com.incoming) + def _insert_nested_impl(self, builder: ParentBuilder, *args: Wire) -> Node: + mapping = self.hugr.insert_hugr(builder.hugr, self.parent_node) + self._wire_up(mapping[builder.parent_node], args) + return mapping[builder.parent_node] + def insert_nested(self, dfg: Dfg, *args: Wire) -> Node: - mapping = self.hugr.insert_hugr(dfg.hugr, self.parent_node) - self._wire_up(mapping[dfg.parent_node], args) - return mapping[dfg.parent_node] + return self._insert_nested_impl(dfg, *args) def add_nested( self, @@ -83,28 +88,59 @@ def add_nested( ) -> Dfg: from ._dfg import Dfg - input_types = [self._get_dataflow_type(w) for w in args] - - parent_op = ops.DFG(list(input_types)) + parent_op = ops.DFG(self._wire_types(args)) dfg = Dfg.new_nested(parent_op, self.hugr, self.parent_node) self._wire_up(dfg.parent_node, args) return dfg + def _wire_types(self, args: Iterable[Wire]) -> TypeRow: + return [self._get_dataflow_type(w) for w in args] + def add_cfg( self, - input_types: TypeRow, *args: Wire, ) -> Cfg: from ._cfg import Cfg - cfg = Cfg.new_nested(input_types, self.hugr, self.parent_node) + cfg = Cfg.new_nested(self._wire_types(args), self.hugr, self.parent_node) self._wire_up(cfg.parent_node, args) return cfg def insert_cfg(self, cfg: Cfg, *args: Wire) -> Node: - mapping = self.hugr.insert_hugr(cfg.hugr, self.parent_node) - self._wire_up(mapping[cfg.parent_node], args) - return mapping[cfg.parent_node] + return self._insert_nested_impl(cfg, *args) + + def add_conditional(self, cond: Wire, *args: Wire) -> Conditional: + from ._cond_loop import Conditional + + args = (cond, *args) + (sum_, other_inputs) = get_first_sum(self._wire_types(args)) + cond = Conditional.new_nested(sum_, other_inputs, self.hugr, self.parent_node) + self._wire_up(cond.parent_node, args) + return cond + + def insert_conditional(self, cond: Conditional, *args: Wire) -> Node: + return self._insert_nested_impl(cond, *args) + + def add_if(self, cond: Wire, *args: Wire) -> If: + from ._cond_loop import If + + conditional = self.add_conditional(cond, *args) + return If(conditional.add_case(1)) + + def add_tail_loop( + self, just_inputs: Sequence[Wire], rest: Sequence[Wire] + ) -> TailLoop: + from ._cond_loop import TailLoop + + just_input_types = self._wire_types(just_inputs) + rest_types = self._wire_types(rest) + parent_op = ops.TailLoop(just_input_types, rest_types) + tl = TailLoop.new_nested(parent_op, self.hugr, self.parent_node) + self._wire_up(tl.parent_node, (*just_inputs, *rest)) + return tl + + def insert_tail_loop(self, tl: TailLoop, *args: Wire) -> Node: + return self._insert_nested_impl(tl, *args) def set_outputs(self, *args: Wire) -> None: self._wire_up(self.output_node, args) @@ -117,22 +153,22 @@ def add_state_order(self, src: Node, dst: Node) -> None: def add_const(self, val: val.Value) -> Node: return self.hugr.add_const(val, self.parent_node) - def load_const(self, const_node: ToNode) -> Node: - const_op = self.hugr._get_typed_op(const_node, ops.Const) + def load(self, const: ToNode | val.Value) -> Node: + if isinstance(const, val.Value): + const = self.add_const(const) + const_op = self.hugr._get_typed_op(const, ops.Const) load_op = ops.LoadConst(const_op.val.type_()) load = self.add(load_op()) - self.hugr.add_link(const_node.out_port(), load.inp(0)) + self.hugr.add_link(const.out_port(), load.inp(0)) return load - def add_load_const(self, val: val.Value) -> Node: - return self.load_const(self.add_const(val)) - - def _wire_up(self, node: Node, ports: Iterable[Wire]): + def _wire_up(self, node: Node, ports: Iterable[Wire]) -> TypeRow: tys = [self._wire_up_port(node, i, p) for i, p in enumerate(ports)] if isinstance(op := self.hugr[node].op, ops.PartialOp): op.set_in_types(tys) + return tys def _get_dataflow_type(self, wire: Wire) -> Type: port = wire.out_port() diff --git a/hugr-py/src/hugr/_ops.py b/hugr-py/src/hugr/_ops.py index 89c6d285c..2241eff1d 100644 --- a/hugr-py/src/hugr/_ops.py +++ b/hugr-py/src/hugr/_ops.py @@ -192,24 +192,21 @@ def set_in_types(self, types: tys.TypeRow) -> None: @dataclass() class Tag(DataflowOp): tag: int - variants: list[tys.TypeRow] + sum_ty: tys.Sum num_out: int | None = 1 def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Tag: return sops.Tag( parent=parent.idx, tag=self.tag, - variants=[ser_it(r) for r in self.variants], + variants=[ser_it(r) for r in self.sum_ty.variant_rows], ) def outer_signature(self) -> tys.FunctionType: return tys.FunctionType( - input=self.variants[self.tag], output=[tys.Sum(self.variants)] + input=self.sum_ty.variant_rows[self.tag], output=[self.sum_ty] ) - def __call__(self, value: Wire) -> Command: - return super().__call__(value) - class DfParentOp(Op, Protocol): def inner_signature(self) -> tys.FunctionType: ... @@ -219,15 +216,18 @@ def _set_out_types(self, types: tys.TypeRow) -> None: ... def _inputs(self) -> tys.TypeRow: ... -@dataclass() +@dataclass class DFG(DfParentOp, DataflowOp): - _signature: tys.TypeRow | tys.FunctionType + inputs: tys.TypeRow + _outputs: tys.TypeRow | None = None + + @property + def outputs(self) -> tys.TypeRow: + return _check_complete(self._outputs) @property def signature(self) -> tys.FunctionType: - if isinstance(self._signature, tys.FunctionType): - return self._signature - raise IncompleteOp() + return tys.FunctionType(self.inputs, self.outputs) @property def num_out(self) -> int | None: @@ -246,24 +246,28 @@ def outer_signature(self) -> tys.FunctionType: return self.signature def _set_out_types(self, types: tys.TypeRow) -> None: - assert isinstance(self._signature, list), "Signature has already been set." - self._signature = tys.FunctionType(self._signature, types) + self._outputs = types def _inputs(self) -> tys.TypeRow: - match self._signature: - case tys.FunctionType(input, _): - return input - case list(_): - return self._signature + return self.inputs @dataclass() class CFG(DataflowOp): - signature: tys.FunctionType = field(default_factory=tys.FunctionType.empty) + inputs: tys.TypeRow + _outputs: tys.TypeRow | None = None + + @property + def outputs(self) -> tys.TypeRow: + return _check_complete(self._outputs) + + @property + def signature(self) -> tys.FunctionType: + return tys.FunctionType(self.inputs, self.outputs) @property def num_out(self) -> int | None: - return len(self.signature.output) + return len(self.outputs) def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.CFG: return sops.CFG( @@ -278,13 +282,13 @@ def outer_signature(self) -> tys.FunctionType: @dataclass class DataflowBlock(DfParentOp): inputs: tys.TypeRow - _sum_rows: list[tys.TypeRow] | None = None + _sum: tys.Sum | None = None _other_outputs: tys.TypeRow | None = None extension_delta: tys.ExtensionSet = field(default_factory=list) @property - def sum_rows(self) -> list[tys.TypeRow]: - return _check_complete(self._sum_rows) + def sum_ty(self) -> tys.Sum: + return _check_complete(self._sum) @property def other_outputs(self) -> tys.TypeRow: @@ -292,36 +296,35 @@ def other_outputs(self) -> tys.TypeRow: @property def num_out(self) -> int | None: - return len(self.sum_rows) + return len(self.sum_ty.variant_rows) def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.DataflowBlock: return sops.DataflowBlock( parent=parent.idx, inputs=ser_it(self.inputs), - sum_rows=list(map(ser_it, self.sum_rows)), + sum_rows=list(map(ser_it, self.sum_ty.variant_rows)), other_outputs=ser_it(self.other_outputs), extension_delta=self.extension_delta, ) def inner_signature(self) -> tys.FunctionType: return tys.FunctionType( - input=self.inputs, output=[tys.Sum(self.sum_rows), *self.other_outputs] + input=self.inputs, output=[self.sum_ty, *self.other_outputs] ) def port_kind(self, port: InPort | OutPort) -> tys.Kind: return tys.CFKind() def _set_out_types(self, types: tys.TypeRow) -> None: - (sum_, *other) = types - assert isinstance(sum_, tys.Sum), f"Expected Sum, got {sum_}" - self._sum_rows = sum_.variant_rows + (sum_, other) = tys.get_first_sum(types) + self._sum = sum_ self._other_outputs = other def _inputs(self) -> tys.TypeRow: return self.inputs def nth_outputs(self, n: int) -> tys.TypeRow: - return [*self.sum_rows[n], *self.other_outputs] + return [*self.sum_ty.variant_rows[n], *self.other_outputs] @dataclass @@ -373,3 +376,112 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.LoadConstant: def outer_signature(self) -> tys.FunctionType: return tys.FunctionType(input=[], output=[self.type_()]) + + +@dataclass() +class Conditional(DataflowOp): + sum_ty: tys.Sum + other_inputs: tys.TypeRow + _outputs: tys.TypeRow | None = None + + @property + def outputs(self) -> tys.TypeRow: + return _check_complete(self._outputs) + + @property + def signature(self) -> tys.FunctionType: + inputs = [self.sum_ty, *self.other_inputs] + return tys.FunctionType(inputs, self.outputs) + + @property + def num_out(self) -> int | None: + return len(self.outputs) + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Conditional: + return sops.Conditional( + parent=parent.idx, + sum_rows=[ser_it(r) for r in self.sum_ty.variant_rows], + other_inputs=ser_it(self.other_inputs), + outputs=ser_it(self.outputs), + ) + + def outer_signature(self) -> tys.FunctionType: + return self.signature + + def nth_inputs(self, n: int) -> tys.TypeRow: + return [*self.sum_ty.variant_rows[n], *self.other_inputs] + + +@dataclass +class Case(DfParentOp): + inputs: tys.TypeRow + _outputs: tys.TypeRow | None = None + + @property + def outputs(self) -> tys.TypeRow: + return _check_complete(self._outputs) + + @property + def num_out(self) -> int | None: + return 0 + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Case: + return sops.Case( + parent=parent.idx, signature=self.inner_signature().to_serial() + ) + + def inner_signature(self) -> tys.FunctionType: + return tys.FunctionType(self.inputs, self.outputs) + + def port_kind(self, port: InPort | OutPort) -> tys.Kind: + raise NotImplementedError("Case nodes have no external ports.") + + def _set_out_types(self, types: tys.TypeRow) -> None: + self._outputs = types + + def _inputs(self) -> tys.TypeRow: + return self.inputs + + +@dataclass +class TailLoop(DfParentOp, DataflowOp): + just_inputs: tys.TypeRow + rest: tys.TypeRow + _just_outputs: tys.TypeRow | None = None + extension_delta: tys.ExtensionSet = field(default_factory=list) + + @property + def just_outputs(self) -> tys.TypeRow: + return _check_complete(self._just_outputs) + + @property + def num_out(self) -> int | None: + return len(self.just_outputs) + len(self.rest) + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.TailLoop: + return sops.TailLoop( + parent=parent.idx, + just_inputs=ser_it(self.just_inputs), + just_outputs=ser_it(self.just_outputs), + rest=ser_it(self.rest), + extension_delta=self.extension_delta, + ) + + def inner_signature(self) -> tys.FunctionType: + return tys.FunctionType( + self._inputs(), [tys.Sum([self.just_inputs, self.just_outputs]), *self.rest] + ) + + def outer_signature(self) -> tys.FunctionType: + return tys.FunctionType(self._inputs(), self.just_outputs + self.rest) + + def _set_out_types(self, types: tys.TypeRow) -> None: + (sum_, other) = tys.get_first_sum(types) + just_ins, just_outs = sum_.variant_rows + assert ( + just_ins == self.just_inputs + ), "First sum variant rows don't match TailLoop inputs." + self._just_outputs = just_outs + + def _inputs(self) -> tys.TypeRow: + return self.just_inputs + self.rest diff --git a/hugr-py/src/hugr/_tys.py b/hugr-py/src/hugr/_tys.py index 7b0fa2335..48dd43b08 100644 --- a/hugr-py/src/hugr/_tys.py +++ b/hugr-py/src/hugr/_tys.py @@ -237,6 +237,10 @@ def to_serial(self) -> stys.FunctionType: def empty(cls) -> FunctionType: return cls(input=[], output=[]) + @classmethod + def endo(cls, tys: TypeRow) -> FunctionType: + return cls(input=tys, output=tys) + def flip(self) -> FunctionType: return FunctionType(input=list(self.output), output=list(self.input)) @@ -303,3 +307,9 @@ class OrderKind: ... Kind = ValueKind | ConstKind | FunctionKind | CFKind | OrderKind + + +def get_first_sum(types: TypeRow) -> tuple[Sum, TypeRow]: + (sum_, *other) = types + assert isinstance(sum_, Sum), f"Expected Sum, got {sum_}" + return sum_, other diff --git a/hugr-py/src/hugr/serialization/ops.py b/hugr-py/src/hugr/serialization/ops.py index bb4ee2882..d2fedf027 100644 --- a/hugr-py/src/hugr/serialization/ops.py +++ b/hugr-py/src/hugr/serialization/ops.py @@ -206,7 +206,7 @@ def insert_child_dfg_signature(self, inputs: TypeRow, outputs: TypeRow) -> None: def deserialize(self) -> _ops.DataflowBlock: return _ops.DataflowBlock( inputs=deser_it(self.inputs), - _sum_rows=[deser_it(r) for r in self.sum_rows], + _sum=_tys.Sum([deser_it(r) for r in self.sum_rows]), _other_outputs=deser_it(self.other_outputs), ) @@ -347,7 +347,8 @@ def insert_child_dfg_signature(self, inputs: TypeRow, outputs: TypeRow) -> None: ) def deserialize(self) -> _ops.DFG: - return _ops.DFG(self.signature.deserialize()) + sig = self.signature.deserialize() + return _ops.DFG(sig.input, sig.output) # ------------------------------------------------ @@ -383,6 +384,13 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: self.other_inputs = list(in_types[1:]) self.outputs = list(out_types) + def deserialize(self) -> _ops.Conditional: + return _ops.Conditional( + _tys.Sum([deser_it(r) for r in self.sum_rows]), + deser_it(self.other_inputs), + deser_it(self.outputs), + ) + class Case(BaseOp): """Case ops - nodes valid inside Conditional nodes.""" @@ -396,6 +404,10 @@ def insert_child_dfg_signature(self, inputs: TypeRow, outputs: TypeRow) -> None: input=list(inputs), output=list(outputs), extension_reqs=ExtensionSet([]) ) + def deserialize(self) -> _ops.Case: + sig = self.signature.deserialize() + return _ops.Case(inputs=sig.input, _outputs=sig.output) + class TailLoop(DataflowOp): """Tail-controlled loop.""" @@ -413,6 +425,14 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: # self.just_outputs = list(out_types) self.rest = list(in_types) + def deserialize(self) -> _ops.TailLoop: + return _ops.TailLoop( + just_inputs=deser_it(self.just_inputs), + _just_outputs=deser_it(self.just_outputs), + rest=deser_it(self.rest), + extension_delta=self.extension_delta, + ) + class CFG(DataflowOp): """A dataflow node which is defined by a child CFG.""" @@ -426,7 +446,8 @@ def insert_port_types(self, inputs: TypeRow, outputs: TypeRow) -> None: ) def deserialize(self) -> _ops.CFG: - return _ops.CFG(self.signature.deserialize()) + sig = self.signature.deserialize() + return _ops.CFG(inputs=sig.input, _outputs=sig.output) ControlFlowOp = Conditional | TailLoop | CFG @@ -520,7 +541,7 @@ class Tag(DataflowOp): def deserialize(self) -> _ops.Tag: return _ops.Tag( tag=self.tag, - variants=[deser_it(v) for v in self.variants], + sum_ty=_tys.Sum([deser_it(v) for v in self.variants]), ) @@ -614,3 +635,4 @@ class OpDef(ConfiguredBaseModel, populate_by_name=True): # needed to avoid circular imports from hugr import _ops # noqa: E402 from hugr import _val # noqa: E402 +from hugr import _tys # noqa: E402 diff --git a/hugr-py/tests/test_cfg.py b/hugr-py/tests/test_cfg.py index a5beed4b9..269c48285 100644 --- a/hugr-py/tests/test_cfg.py +++ b/hugr-py/tests/test_cfg.py @@ -40,7 +40,7 @@ def test_branch() -> None: def test_nested_cfg() -> None: dfg = Dfg(tys.Bool) - cfg = dfg.add_cfg([tys.Bool], *dfg.inputs()) + cfg = dfg.add_cfg(*dfg.inputs()) build_basic_cfg(cfg) dfg.set_outputs(cfg) @@ -72,13 +72,15 @@ def test_asymm_types() -> None: cfg = Cfg([]) entry = cfg.add_entry() - int_load = entry.add_load_const(IntVal(34)) - tagged_int = entry.add(ops.Tag(0, [[INT_T], [tys.Bool]])(int_load)) + int_load = entry.load(IntVal(34)) + + sum_ty = tys.Sum([[INT_T], [tys.Bool]]) + tagged_int = entry.add(ops.Tag(0, sum_ty)(int_load)) entry.set_block_outputs(tagged_int) middle = cfg.add_successor(entry[0]) # discard the int and return the bool from entry - middle.set_single_succ_outputs(middle.add_load_const(val.TRUE)) + middle.set_single_succ_outputs(middle.load(val.TRUE)) # middle expects an int and exit expects a bool cfg.branch_exit(entry[1]) diff --git a/hugr-py/tests/test_cond_loop.py b/hugr-py/tests/test_cond_loop.py new file mode 100644 index 000000000..24381e06c --- /dev/null +++ b/hugr-py/tests/test_cond_loop.py @@ -0,0 +1,124 @@ +from hugr._cond_loop import Conditional, ConditionalError, TailLoop +from hugr._dfg import Dfg +import hugr._tys as tys +import hugr._ops as ops +import hugr._val as val +import pytest +from .test_hugr_build import INT_T, _validate, IntVal, H, Measure + +SUM_T = tys.Sum([[tys.Qubit], [tys.Qubit, INT_T]]) + + +def build_cond(h: Conditional) -> None: + with pytest.raises(ConditionalError, match="Case 2 out of possible range."): + h.add_case(2) + + case0 = h.add_case(0) + q, b = case0.inputs() + case0.set_outputs(q, b) + + case1 = h.add_case(1) + q, _i, b = case1.inputs() + case1.set_outputs(q, b) + + +def test_cond() -> None: + h = Conditional(SUM_T, [tys.Bool]) + build_cond(h) + _validate(h.hugr) + + +def test_nested_cond() -> None: + h = Dfg(tys.Qubit) + (q,) = h.inputs() + tagged_q = h.add(ops.Tag(0, SUM_T)(q)) + cond = h.add_conditional(tagged_q, h.load(val.TRUE)) + build_cond(cond) + h.set_outputs(*cond[:2]) + _validate(h.hugr) + + # build then insert + con = Conditional(SUM_T, [tys.Bool]) + build_cond(con) + + h = Dfg(tys.Qubit) + (q,) = h.inputs() + tagged_q = h.add(ops.Tag(0, SUM_T)(q)) + cond_n = h.insert_conditional(con, tagged_q, h.load(val.TRUE)) + h.set_outputs(*cond_n[:2]) + _validate(h.hugr) + + +def test_if_else() -> None: + # apply an H if a bool is true. + h = Dfg(tys.Qubit) + (q,) = h.inputs() + if_ = h.add_if(h.load(val.TRUE), q) + + if_.set_outputs(if_.add(H(if_.input_node[0]))) + + else_ = if_.add_else() + else_.set_outputs(else_.input_node[0]) + + cond = else_.finish() + h.set_outputs(cond) + + _validate(h.hugr) + + +def test_tail_loop() -> None: + # apply H while measure is true + def build_tl(tl: TailLoop) -> None: + q, b = tl.add(Measure(tl.add(H(tl.input_node[0]))))[:] + + tl.set_loop_outputs(b, q) + + h = Dfg(tys.Qubit) + (q,) = h.inputs() + + tl = h.add_tail_loop([], [q]) + build_tl(tl) + h.set_outputs(tl) + + _validate(h.hugr) + + # build then insert + tl = TailLoop([], [tys.Qubit]) + build_tl(tl) + + h = Dfg(tys.Qubit) + (q,) = h.inputs() + tl_n = h.insert_tail_loop(tl, q) + h.set_outputs(tl_n) + _validate(h.hugr) + + +def test_complex_tail_loop() -> None: + h = Dfg(tys.Qubit) + (q,) = h.inputs() + + # loop passes qubit to itself, and a bool as in-out + tl = h.add_tail_loop([q], [h.load(val.TRUE)]) + q, b = tl.inputs() + + # if b is true, return first variant (just qubit) + if_ = tl.add_if(b, q) + (q,) = if_.inputs() + tagged_q = if_.add(ops.Tag(0, SUM_T)(q)) + if_.set_outputs(tagged_q) + + # else return second variant (qubit, int) + else_ = if_.add_else() + (q,) = else_.inputs() + tagged_q_i = else_.add(ops.Tag(1, SUM_T)(q, else_.load(IntVal(1)))) + else_.set_outputs(tagged_q_i) + + # finish with Sum output from if-else, and bool from inputs + tl.set_loop_outputs(else_.finish(), b) + + # loop returns [qubit, int, bool] + h.set_outputs(*tl[:3]) + + _validate(h.hugr, True) + + # TODO rewrite with context managers diff --git a/hugr-py/tests/test_hugr_build.py b/hugr-py/tests/test_hugr_build.py index 04d908f64..14a44d12b 100644 --- a/hugr-py/tests/test_hugr_build.py +++ b/hugr-py/tests/test_hugr_build.py @@ -44,9 +44,7 @@ class LogicOps(Custom): class NotDef(LogicOps): num_out: int | None = 1 op_name: str = "Not" - signature: tys.FunctionType = field( - default_factory=lambda: tys.FunctionType(input=[tys.Bool], output=[tys.Bool]) - ) + signature: tys.FunctionType = tys.FunctionType.endo([tys.Bool]) def __call__(self, a: Wire) -> Command: return super().__call__(a) @@ -55,6 +53,37 @@ def __call__(self, a: Wire) -> Command: Not = NotDef() +@dataclass +class QuantumOps(Custom): + extension: tys.ExtensionId = "tket2.quantum" + + +@dataclass +class OneQbGate(QuantumOps): + op_name: str + num_out: int | None = 1 + signature: tys.FunctionType = tys.FunctionType.endo([tys.Qubit]) + + def __call__(self, q: Wire) -> Command: + return super().__call__(q) + + +H = OneQbGate("H") + + +@dataclass +class MeasureDef(QuantumOps): + op_name: str = "Measure" + num_out: int | None = 2 + signature: tys.FunctionType = tys.FunctionType([tys.Qubit], [tys.Qubit, tys.Bool]) + + def __call__(self, q: Wire) -> Command: + return super().__call__(q) + + +Measure = MeasureDef() + + @dataclass class IntOps(Custom): extension: tys.ExtensionId = "arithmetic.int" @@ -281,6 +310,6 @@ def test_ancestral_sibling(): ) def test_vals(val: val.Value): d = Dfg() - d.set_outputs(d.add_load_const(val)) + d.set_outputs(d.load(val)) _validate(d.hugr)