diff --git a/hugr-py/src/hugr/_cfg.py b/hugr-py/src/hugr/_cfg.py index d26eec4bc..4277edd10 100644 --- a/hugr-py/src/hugr/_cfg.py +++ b/hugr-py/src/hugr/_cfg.py @@ -8,15 +8,16 @@ from ._exceptions import NoSiblingAncestor, NotInSameCfg, MismatchedExit from ._hugr import Hugr, Node, ParentBuilder, ToNode, Wire from ._tys import FunctionType, TypeRow, Type +import hugr._val as val class Block(_DfBase[ops.DataflowBlock]): def set_block_outputs(self, branching: Wire, *other_outputs: Wire) -> None: self.set_outputs(branching, *other_outputs) - def set_single_successor_outputs(self, *outputs: Wire) -> None: - # TODO requires constants - raise NotImplementedError + def set_single_succ_outputs(self, *outputs: Wire) -> None: + u = self.add_load_const(val.Unit) + self.set_outputs(u, *outputs) def _wire_up_port(self, node: Node, offset: int, p: Wire) -> Type: src = p.out_port() diff --git a/hugr-py/src/hugr/_dfg.py b/hugr-py/src/hugr/_dfg.py index 7e89b91bc..8d4d1447f 100644 --- a/hugr-py/src/hugr/_dfg.py +++ b/hugr-py/src/hugr/_dfg.py @@ -1,19 +1,20 @@ from __future__ import annotations + from dataclasses import dataclass, replace from typing import ( - Iterable, TYPE_CHECKING, + Iterable, TypeVar, ) -from ._hugr import Hugr, Node, Wire, OutPort, ParentBuilder from typing_extensions import Self + import hugr._ops as ops -from hugr._tys import TypeRow +import hugr._val as val +from hugr._tys import Type, TypeRow from ._exceptions import NoSiblingAncestor -from ._hugr import ToNode -from hugr._tys import Type +from ._hugr import Hugr, Node, OutPort, ParentBuilder, ToNode, Wire if TYPE_CHECKING: from ._cfg import Cfg @@ -113,6 +114,21 @@ def add_state_order(self, src: Node, dst: Node) -> None: # adds edge to the right of all existing edges self.hugr.add_link(src.out(-1), dst.inp(-1)) + def add_const(self, val: val.Value) -> Node: + return self.hugr.add_const(val, self.parent_node) + + def load_const(self, const_node: ToNode) -> Node: + const_op = self.hugr._get_typed_op(const_node, ops.Const) + load_op = ops.LoadConst(const_op.val.type_()) + + load = self.add(load_op()) + self.hugr.add_link(const_node.out_port(), load.inp(0)) + + return load + + def add_load_const(self, val: val.Value) -> Node: + return self.load_const(self.add_const(val)) + def _wire_up(self, node: Node, ports: Iterable[Wire]): tys = [self._wire_up_port(node, i, p) for i, p in enumerate(ports)] if isinstance(op := self.hugr[node].op, ops.PartialOp): diff --git a/hugr-py/src/hugr/_hugr.py b/hugr-py/src/hugr/_hugr.py index a54d35e6a..a13cb3d1d 100644 --- a/hugr-py/src/hugr/_hugr.py +++ b/hugr-py/src/hugr/_hugr.py @@ -17,8 +17,9 @@ from typing_extensions import Self -from hugr._ops import Op, DataflowOp +from hugr._ops import Op, DataflowOp, Const from hugr._tys import Type, Kind +from hugr._val import Value from hugr.serialization.ops import OpType as SerialOp from hugr.serialization.serial_hugr import SerialHugr from hugr.utils import BiMap @@ -228,6 +229,9 @@ def add_node( parent = parent or self.root return self._add_node(op, parent, num_outs) + def add_const(self, value: Value, parent: ToNode | None = None) -> Node: + return self.add_node(Const(value), parent) + def delete_node(self, node: ToNode) -> NodeData | None: node = node.to_node() parent = self[node].parent diff --git a/hugr-py/src/hugr/_ops.py b/hugr-py/src/hugr/_ops.py index 1a390eeb1..89c6d285c 100644 --- a/hugr-py/src/hugr/_ops.py +++ b/hugr-py/src/hugr/_ops.py @@ -1,11 +1,12 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Generic, Protocol, TypeVar, TYPE_CHECKING, runtime_checkable +from typing import Protocol, TYPE_CHECKING, runtime_checkable, TypeVar from hugr.serialization.ops import BaseOp import hugr.serialization.ops as sops from hugr.utils import ser_it import hugr._tys as tys +import hugr._val as val from ._exceptions import IncompleteOp if TYPE_CHECKING: @@ -55,23 +56,6 @@ class Command: incoming: list[Wire] -T = TypeVar("T", bound=BaseOp) - - -@dataclass() -class SerWrap(Op, Generic[T]): - # catch all for serial ops that don't have a corresponding Op class - _serial_op: T - - def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> T: - root = self._serial_op.model_copy() - root.parent = parent.idx - return root - - def port_kind(self, port: InPort | OutPort) -> tys.Kind: - raise NotImplementedError - - @dataclass() class Input(DataflowOp): types: tys.TypeRow @@ -304,9 +288,7 @@ def sum_rows(self) -> list[tys.TypeRow]: @property def other_outputs(self) -> tys.TypeRow: - if self._other_outputs is None: - raise IncompleteOp() - return self._other_outputs + return _check_complete(self._other_outputs) @property def num_out(self) -> int | None: @@ -359,3 +341,35 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.ExitBlock: def port_kind(self, port: InPort | OutPort) -> tys.Kind: return tys.CFKind() + + +@dataclass +class Const(Op): + val: val.Value + num_out: int | None = 1 + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Const: + return sops.Const( + parent=parent.idx, + v=self.val.to_serial_root(), + ) + + def port_kind(self, port: InPort | OutPort) -> tys.Kind: + return tys.ConstKind(self.val.type_()) + + +@dataclass +class LoadConst(DataflowOp): + typ: tys.Type | None = None + + def type_(self) -> tys.Type: + return _check_complete(self.typ) + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.LoadConstant: + return sops.LoadConstant( + parent=parent.idx, + datatype=self.type_().to_serial_root(), + ) + + def outer_signature(self) -> tys.FunctionType: + return tys.FunctionType(input=[], output=[self.type_()]) diff --git a/hugr-py/src/hugr/_val.py b/hugr-py/src/hugr/_val.py new file mode 100644 index 000000000..d2a6277eb --- /dev/null +++ b/hugr-py/src/hugr/_val.py @@ -0,0 +1,103 @@ +from __future__ import annotations +from dataclasses import dataclass, field +from typing import Any, Protocol, runtime_checkable, TYPE_CHECKING +import hugr.serialization.ops as sops +import hugr.serialization.tys as stys +import hugr._tys as tys +from hugr.utils import ser_it + +if TYPE_CHECKING: + from hugr._hugr import Hugr + + +@runtime_checkable +class Value(Protocol): + def to_serial(self) -> sops.BaseValue: ... + def to_serial_root(self) -> sops.Value: + return sops.Value(root=self.to_serial()) # type: ignore[arg-type] + + def type_(self) -> tys.Type: ... + + +@dataclass +class Sum(Value): + tag: int + typ: tys.Sum + vals: list[Value] + + def type_(self) -> tys.Sum: + return self.typ + + def to_serial(self) -> sops.SumValue: + return sops.SumValue( + tag=self.tag, + typ=stys.SumType(root=self.type_().to_serial()), + vs=ser_it(self.vals), + ) + + +def bool_value(b: bool) -> Sum: + return Sum( + tag=int(b), + typ=tys.Bool, + vals=[], + ) + + +Unit = Sum(0, tys.Unit, []) +TRUE = bool_value(True) +FALSE = bool_value(False) + + +@dataclass +class Tuple(Value): + vals: list[Value] + + def type_(self) -> tys.Tuple: + return tys.Tuple(*(v.type_() for v in self.vals)) + + def to_serial(self) -> sops.TupleValue: + return sops.TupleValue( + vs=ser_it(self.vals), + ) + + +@dataclass +class Function(Value): + body: Hugr + + def type_(self) -> tys.FunctionType: + return self.body.root_op().inner_signature() + + def to_serial(self) -> sops.FunctionValue: + return sops.FunctionValue( + hugr=self.body.to_serial(), + ) + + +@dataclass +class Extension(Value): + name: str + typ: tys.Type + val: Any + extensions: tys.ExtensionSet = field(default_factory=tys.ExtensionSet) + + def type_(self) -> tys.Type: + return self.typ + + def to_serial(self) -> sops.ExtensionValue: + return sops.ExtensionValue( + typ=self.typ.to_serial_root(), + value=sops.CustomConst(c=self.name, v=self.val), + extensions=self.extensions, + ) + + +class ExtensionValue(Value, Protocol): + def to_value(self) -> Extension: ... + + def type_(self) -> tys.Type: + return self.to_value().type_() + + def to_serial(self) -> sops.ExtensionValue: + return self.to_value().to_serial() diff --git a/hugr-py/src/hugr/serialization/ops.py b/hugr-py/src/hugr/serialization/ops.py index 049aa1d9f..bb4ee2882 100644 --- a/hugr-py/src/hugr/serialization/ops.py +++ b/hugr-py/src/hugr/serialization/ops.py @@ -1,7 +1,7 @@ from __future__ import annotations import inspect import sys -from abc import ABC +from abc import ABC, abstractmethod from typing import Any, Literal from pydantic import Field, RootModel, ConfigDict @@ -22,6 +22,7 @@ ) from hugr.utils import deser_it + NodeID = int @@ -44,7 +45,7 @@ def display_name(self) -> str: def deserialize(self) -> _ops.Op: """Deserializes the model into the corresponding Op.""" - return _ops.SerWrap(self) + raise NotImplementedError # ---------------------------------------------------------- @@ -80,36 +81,54 @@ class CustomConst(ConfiguredBaseModel): v: Any -class ExtensionValue(ConfiguredBaseModel): +class BaseValue(ABC, ConfiguredBaseModel): + @abstractmethod + def deserialize(self) -> _val.Value: ... + + +class ExtensionValue(BaseValue): """An extension constant value, that can check it is of a given [CustomType].""" - v: Literal["Extension"] = Field("Extension", title="ValueTag") + v: Literal["Extension"] = Field(default="Extension", title="ValueTag") extensions: ExtensionSet typ: Type value: CustomConst + def deserialize(self) -> _val.Value: + return _val.Extension(self.value.c, self.typ.deserialize(), self.value.v) -class FunctionValue(ConfiguredBaseModel): + +class FunctionValue(BaseValue): """A higher-order function value.""" - v: Literal["Function"] = Field("Function", title="ValueTag") - hugr: Any # TODO + v: Literal["Function"] = Field(default="Function", title="ValueTag") + hugr: Any + + def deserialize(self) -> _val.Value: + from hugr._hugr import Hugr + from hugr.serialization.serial_hugr import SerialHugr + + # pydantic stores the serialized dictionary because of the "Any" annotation + return _val.Function(Hugr.from_serial(SerialHugr(**self.hugr))) -class TupleValue(ConfiguredBaseModel): +class TupleValue(BaseValue): """A constant tuple value.""" - v: Literal["Tuple"] = Field("Tuple", title="ValueTag") + v: Literal["Tuple"] = Field(default="Tuple", title="ValueTag") vs: list["Value"] + def deserialize(self) -> _val.Value: + return _val.Tuple(deser_it((v.root for v in self.vs))) + -class SumValue(ConfiguredBaseModel): +class SumValue(BaseValue): """A Sum variant For any Sum type where this value meets the type of the variant indicated by the tag """ - v: Literal["Sum"] = Field("Sum", title="ValueTag") + v: Literal["Sum"] = Field(default="Sum", title="ValueTag") tag: int typ: SumType vs: list["Value"] @@ -122,6 +141,11 @@ class SumValue(ConfiguredBaseModel): } ) + def deserialize(self) -> _val.Value: + return _val.Sum( + self.tag, self.typ.deserialize(), deser_it((v.root for v in self.vs)) + ) + class Value(RootModel): """A constant Value.""" @@ -132,6 +156,9 @@ class Value(RootModel): model_config = ConfigDict(json_schema_extra={"required": ["v"]}) + def deserialize(self) -> _val.Value: + return self.root.deserialize() + class Const(BaseOp): """A Const operation definition.""" @@ -139,6 +166,9 @@ class Const(BaseOp): op: Literal["Const"] = "Const" v: Value = Field() + def deserialize(self) -> _ops.Const: + return _ops.Const(self.v.deserialize()) + # ----------------------------------------------- # --------------- BasicBlock types ------------------ @@ -173,6 +203,13 @@ def insert_child_dfg_signature(self, inputs: TypeRow, outputs: TypeRow) -> None: # Needed to avoid random '\n's in the pydantic description + def deserialize(self) -> _ops.DataflowBlock: + return _ops.DataflowBlock( + inputs=deser_it(self.inputs), + _sum_rows=[deser_it(r) for r in self.sum_rows], + _other_outputs=deser_it(self.other_outputs), + ) + model_config = ConfigDict( json_schema_extra={ "description": "A CFG basic block node. The signature is that of the internal Dataflow graph.", @@ -194,6 +231,9 @@ class ExitBlock(BaseOp): } ) + def deserialize(self) -> _ops.ExitBlock: + return _ops.ExitBlock(deser_it(self.cfg_outputs)) + # --------------------------------------------- # --------------- DataflowOp ------------------ @@ -282,6 +322,9 @@ class LoadConstant(DataflowOp): op: Literal["LoadConstant"] = "LoadConstant" datatype: Type + def deserialize(self) -> _ops.LoadConst: + return _ops.LoadConst(self.datatype.deserialize()) + class LoadFunction(DataflowOp): """Load a static function in to the local dataflow graph.""" @@ -382,6 +425,9 @@ def insert_port_types(self, inputs: TypeRow, outputs: TypeRow) -> None: input=list(inputs), output=list(outputs), extension_reqs=ExtensionSet([]) ) + def deserialize(self) -> _ops.CFG: + return _ops.CFG(self.signature.deserialize()) + ControlFlowOp = Conditional | TailLoop | CFG @@ -471,6 +517,12 @@ class Tag(DataflowOp): tag: int # The variant to create. variants: list[TypeRow] # The variants of the sum type. + def deserialize(self) -> _ops.Tag: + return _ops.Tag( + tag=self.tag, + variants=[deser_it(v) for v in self.variants], + ) + class Lift(DataflowOp): """Fixes some TypeParams of a polymorphic type by providing TypeArgs.""" @@ -559,4 +611,6 @@ class OpDef(ConfiguredBaseModel, populate_by_name=True): tys_model_rebuild(dict(classes)) -from hugr import _ops # noqa: E402 # needed to avoid circular imports +# needed to avoid circular imports +from hugr import _ops # noqa: E402 +from hugr import _val # noqa: E402 diff --git a/hugr-py/tests/test_cfg.py b/hugr-py/tests/test_cfg.py index c8fc5f511..a5beed4b9 100644 --- a/hugr-py/tests/test_cfg.py +++ b/hugr-py/tests/test_cfg.py @@ -1,34 +1,35 @@ from hugr._cfg import Cfg import hugr._tys as tys +import hugr._val as val from hugr._dfg import Dfg import hugr._ops as ops -from .test_hugr_build import _validate, INT_T, DivMod +from .test_hugr_build import _validate, INT_T, DivMod, IntVal def build_basic_cfg(cfg: Cfg) -> None: entry = cfg.add_entry() - entry.set_block_outputs(*entry.inputs()) + entry.set_single_succ_outputs(*entry.inputs()) cfg.branch(entry[0], cfg.exit) def test_basic_cfg() -> None: - cfg = Cfg([tys.Unit, tys.Bool]) + cfg = Cfg([tys.Bool]) build_basic_cfg(cfg) _validate(cfg.hugr) def test_branch() -> None: - cfg = Cfg([tys.Bool, tys.Unit, INT_T]) + cfg = Cfg([tys.Bool, INT_T]) entry = cfg.add_entry() entry.set_block_outputs(*entry.inputs()) middle_1 = cfg.add_successor(entry[0]) - middle_1.set_block_outputs(*middle_1.inputs()) + middle_1.set_single_succ_outputs(*middle_1.inputs()) middle_2 = cfg.add_successor(entry[1]) - u, i = middle_2.inputs() + (i,) = middle_2.inputs() n = middle_2.add(DivMod(i, i)) - middle_2.set_block_outputs(u, n[0]) + middle_2.set_single_succ_outputs(n[0]) cfg.branch_exit(middle_1[0]) cfg.branch_exit(middle_2[0]) @@ -37,9 +38,9 @@ def test_branch() -> None: def test_nested_cfg() -> None: - dfg = Dfg(tys.Unit, tys.Bool) + dfg = Dfg(tys.Bool) - cfg = dfg.add_cfg([tys.Unit, tys.Bool], *dfg.inputs()) + cfg = dfg.add_cfg([tys.Bool], *dfg.inputs()) build_basic_cfg(cfg) dfg.set_outputs(cfg) @@ -68,16 +69,16 @@ def test_dom_edge() -> None: def test_asymm_types() -> None: # test different types going to entry block's susccessors - cfg = Cfg([tys.Bool, tys.Unit, INT_T]) + cfg = Cfg([]) entry = cfg.add_entry() - b, u, i = entry.inputs() - tagged_int = entry.add(ops.Tag(0, [[INT_T], [tys.Bool]])(i)) + int_load = entry.add_load_const(IntVal(34)) + tagged_int = entry.add(ops.Tag(0, [[INT_T], [tys.Bool]])(int_load)) entry.set_block_outputs(tagged_int) middle = cfg.add_successor(entry[0]) # discard the int and return the bool from entry - middle.set_block_outputs(u, b) + middle.set_single_succ_outputs(middle.add_load_const(val.TRUE)) # middle expects an int and exit expects a bool cfg.branch_exit(entry[1]) diff --git a/hugr-py/tests/test_hugr_build.py b/hugr-py/tests/test_hugr_build.py index ff7a2759b..04d908f64 100644 --- a/hugr-py/tests/test_hugr_build.py +++ b/hugr-py/tests/test_hugr_build.py @@ -9,6 +9,7 @@ import hugr._ops as ops from hugr.serialization import SerialHugr import hugr._tys as tys +import hugr._val as val import pytest import json @@ -25,6 +26,14 @@ def int_t(width: int) -> tys.Opaque: INT_T = int_t(5) +@dataclass +class IntVal(val.ExtensionValue): + v: int + + def to_value(self) -> val.Extension: + return val.Extension("int", INT_T, self.v) + + @dataclass class LogicOps(Custom): extension: tys.ExtensionId = "logic" @@ -119,12 +128,15 @@ def test_stable_indices(): assert h._free_nodes == [] -def test_simple_id(): +def simple_id() -> Dfg: h = Dfg(tys.Qubit, tys.Qubit) a, b = h.inputs() h.set_outputs(a, b) + return h - _validate(h.hugr) + +def test_simple_id(): + _validate(simple_id().hugr) def test_multiport(): @@ -257,3 +269,18 @@ def test_ancestral_sibling(): nt = nested.add(Not(a)) assert _ancestral_sibling(h.hugr, h.input_node, nt) == nested.parent_node + + +@pytest.mark.parametrize( + "val", + [ + val.Function(simple_id().hugr), + val.Sum(1, tys.Sum([[INT_T], [tys.Bool, INT_T]]), [IntVal(34)]), + val.Tuple([val.TRUE, IntVal(23)]), + ], +) +def test_vals(val: val.Value): + d = Dfg() + d.set_outputs(d.add_load_const(val)) + + _validate(d.hugr)