Skip to content

Commit

Permalink
feat(hugr-py): add values and constants (#1203)
Browse files Browse the repository at this point in the history
Closes #1202
  • Loading branch information
ss2165 authored Jun 20, 2024
1 parent 04ed329 commit f7ea178
Show file tree
Hide file tree
Showing 8 changed files with 277 additions and 57 deletions.
7 changes: 4 additions & 3 deletions hugr-py/src/hugr/_cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,16 @@
from ._exceptions import NoSiblingAncestor, NotInSameCfg, MismatchedExit
from ._hugr import Hugr, Node, ParentBuilder, ToNode, Wire
from ._tys import FunctionType, TypeRow, Type
import hugr._val as val


class Block(_DfBase[ops.DataflowBlock]):
def set_block_outputs(self, branching: Wire, *other_outputs: Wire) -> None:
self.set_outputs(branching, *other_outputs)

def set_single_successor_outputs(self, *outputs: Wire) -> None:
# TODO requires constants
raise NotImplementedError
def set_single_succ_outputs(self, *outputs: Wire) -> None:
u = self.add_load_const(val.Unit)
self.set_outputs(u, *outputs)

def _wire_up_port(self, node: Node, offset: int, p: Wire) -> Type:
src = p.out_port()
Expand Down
26 changes: 21 additions & 5 deletions hugr-py/src/hugr/_dfg.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
from __future__ import annotations

from dataclasses import dataclass, replace
from typing import (
Iterable,
TYPE_CHECKING,
Iterable,
TypeVar,
)
from ._hugr import Hugr, Node, Wire, OutPort, ParentBuilder

from typing_extensions import Self

import hugr._ops as ops
from hugr._tys import TypeRow
import hugr._val as val
from hugr._tys import Type, TypeRow

from ._exceptions import NoSiblingAncestor
from ._hugr import ToNode
from hugr._tys import Type
from ._hugr import Hugr, Node, OutPort, ParentBuilder, ToNode, Wire

if TYPE_CHECKING:
from ._cfg import Cfg
Expand Down Expand Up @@ -113,6 +114,21 @@ def add_state_order(self, src: Node, dst: Node) -> None:
# adds edge to the right of all existing edges
self.hugr.add_link(src.out(-1), dst.inp(-1))

def add_const(self, val: val.Value) -> Node:
return self.hugr.add_const(val, self.parent_node)

def load_const(self, const_node: ToNode) -> Node:
const_op = self.hugr._get_typed_op(const_node, ops.Const)
load_op = ops.LoadConst(const_op.val.type_())

load = self.add(load_op())
self.hugr.add_link(const_node.out_port(), load.inp(0))

return load

def add_load_const(self, val: val.Value) -> Node:
return self.load_const(self.add_const(val))

def _wire_up(self, node: Node, ports: Iterable[Wire]):
tys = [self._wire_up_port(node, i, p) for i, p in enumerate(ports)]
if isinstance(op := self.hugr[node].op, ops.PartialOp):
Expand Down
6 changes: 5 additions & 1 deletion hugr-py/src/hugr/_hugr.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@

from typing_extensions import Self

from hugr._ops import Op, DataflowOp
from hugr._ops import Op, DataflowOp, Const
from hugr._tys import Type, Kind
from hugr._val import Value
from hugr.serialization.ops import OpType as SerialOp
from hugr.serialization.serial_hugr import SerialHugr
from hugr.utils import BiMap
Expand Down Expand Up @@ -228,6 +229,9 @@ def add_node(
parent = parent or self.root
return self._add_node(op, parent, num_outs)

def add_const(self, value: Value, parent: ToNode | None = None) -> Node:
return self.add_node(Const(value), parent)

def delete_node(self, node: ToNode) -> NodeData | None:
node = node.to_node()
parent = self[node].parent
Expand Down
56 changes: 35 additions & 21 deletions hugr-py/src/hugr/_ops.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Generic, Protocol, TypeVar, TYPE_CHECKING, runtime_checkable
from typing import Protocol, TYPE_CHECKING, runtime_checkable, TypeVar
from hugr.serialization.ops import BaseOp
import hugr.serialization.ops as sops
from hugr.utils import ser_it
import hugr._tys as tys
import hugr._val as val
from ._exceptions import IncompleteOp

if TYPE_CHECKING:
Expand Down Expand Up @@ -55,23 +56,6 @@ class Command:
incoming: list[Wire]


T = TypeVar("T", bound=BaseOp)


@dataclass()
class SerWrap(Op, Generic[T]):
# catch all for serial ops that don't have a corresponding Op class
_serial_op: T

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> T:
root = self._serial_op.model_copy()
root.parent = parent.idx
return root

def port_kind(self, port: InPort | OutPort) -> tys.Kind:
raise NotImplementedError


@dataclass()
class Input(DataflowOp):
types: tys.TypeRow
Expand Down Expand Up @@ -304,9 +288,7 @@ def sum_rows(self) -> list[tys.TypeRow]:

@property
def other_outputs(self) -> tys.TypeRow:
if self._other_outputs is None:
raise IncompleteOp()
return self._other_outputs
return _check_complete(self._other_outputs)

@property
def num_out(self) -> int | None:
Expand Down Expand Up @@ -359,3 +341,35 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.ExitBlock:

def port_kind(self, port: InPort | OutPort) -> tys.Kind:
return tys.CFKind()


@dataclass
class Const(Op):
val: val.Value
num_out: int | None = 1

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Const:
return sops.Const(
parent=parent.idx,
v=self.val.to_serial_root(),
)

def port_kind(self, port: InPort | OutPort) -> tys.Kind:
return tys.ConstKind(self.val.type_())


@dataclass
class LoadConst(DataflowOp):
typ: tys.Type | None = None

def type_(self) -> tys.Type:
return _check_complete(self.typ)

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.LoadConstant:
return sops.LoadConstant(
parent=parent.idx,
datatype=self.type_().to_serial_root(),
)

def outer_signature(self) -> tys.FunctionType:
return tys.FunctionType(input=[], output=[self.type_()])
103 changes: 103 additions & 0 deletions hugr-py/src/hugr/_val.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Protocol, runtime_checkable, TYPE_CHECKING
import hugr.serialization.ops as sops
import hugr.serialization.tys as stys
import hugr._tys as tys
from hugr.utils import ser_it

if TYPE_CHECKING:
from hugr._hugr import Hugr


@runtime_checkable
class Value(Protocol):
def to_serial(self) -> sops.BaseValue: ...
def to_serial_root(self) -> sops.Value:
return sops.Value(root=self.to_serial()) # type: ignore[arg-type]

def type_(self) -> tys.Type: ...


@dataclass
class Sum(Value):
tag: int
typ: tys.Sum
vals: list[Value]

def type_(self) -> tys.Sum:
return self.typ

def to_serial(self) -> sops.SumValue:
return sops.SumValue(
tag=self.tag,
typ=stys.SumType(root=self.type_().to_serial()),
vs=ser_it(self.vals),
)


def bool_value(b: bool) -> Sum:
return Sum(
tag=int(b),
typ=tys.Bool,
vals=[],
)


Unit = Sum(0, tys.Unit, [])
TRUE = bool_value(True)
FALSE = bool_value(False)


@dataclass
class Tuple(Value):
vals: list[Value]

def type_(self) -> tys.Tuple:
return tys.Tuple(*(v.type_() for v in self.vals))

def to_serial(self) -> sops.TupleValue:
return sops.TupleValue(
vs=ser_it(self.vals),
)


@dataclass
class Function(Value):
body: Hugr

def type_(self) -> tys.FunctionType:
return self.body.root_op().inner_signature()

def to_serial(self) -> sops.FunctionValue:
return sops.FunctionValue(
hugr=self.body.to_serial(),
)


@dataclass
class Extension(Value):
name: str
typ: tys.Type
val: Any
extensions: tys.ExtensionSet = field(default_factory=tys.ExtensionSet)

def type_(self) -> tys.Type:
return self.typ

def to_serial(self) -> sops.ExtensionValue:
return sops.ExtensionValue(
typ=self.typ.to_serial_root(),
value=sops.CustomConst(c=self.name, v=self.val),
extensions=self.extensions,
)


class ExtensionValue(Value, Protocol):
def to_value(self) -> Extension: ...

def type_(self) -> tys.Type:
return self.to_value().type_()

def to_serial(self) -> sops.ExtensionValue:
return self.to_value().to_serial()
Loading

0 comments on commit f7ea178

Please sign in to comment.