Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(hugr-py): add CallIndirect, LoadFunction, Lift, Alias #1218

Merged
merged 3 commits into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 39 additions & 11 deletions hugr-py/src/hugr/_dfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,16 @@

import hugr._ops as ops
import hugr._val as val
from hugr._tys import Type, TypeRow, get_first_sum, FunctionType, TypeArg, FunctionKind
from hugr._tys import (
Type,
TypeRow,
get_first_sum,
FunctionType,
TypeArg,
FunctionKind,
PolyFuncType,
ExtensionSet,
)

from ._exceptions import NoSiblingAncestor
from ._hugr import Hugr, ParentBuilder
Expand Down Expand Up @@ -170,15 +179,9 @@
func: ToNode,
*args: Wire,
instantiation: FunctionType | None = None,
type_args: list[TypeArg] | None = None,
type_args: Sequence[TypeArg] | None = None,
) -> Node:
f_op = self.hugr[func]
f_kind = f_op.op.port_kind(func.out(0))
match f_kind:
case FunctionKind(sig):
signature = sig
case _:
raise ValueError("Expected 'func' to be a function")
signature = self._fn_sig(func)
call_op = ops.Call(signature, instantiation, type_args)
call_n = self.hugr.add_node(call_op, self.parent_node, call_op.num_out)
self.hugr.add_link(func.out(0), call_n.inp(call_op.function_port_offset()))
Expand All @@ -187,6 +190,29 @@

return call_n

def load_function(
self,
func: ToNode,
instantiation: FunctionType | None = None,
type_args: Sequence[TypeArg] | None = None,
) -> Node:
signature = self._fn_sig(func)
load_op = ops.LoadFunc(signature, instantiation, type_args)
load_n = self.hugr.add_node(load_op, self.parent_node)
self.hugr.add_link(func.out(0), load_n.inp(0))

return load_n

def _fn_sig(self, func: ToNode) -> PolyFuncType:
f_op = self.hugr[func]
f_kind = f_op.op.port_kind(func.out(0))
match f_kind:
case FunctionKind(sig):
signature = sig
case _:
raise ValueError("Expected 'func' to be a function")

Check warning on line 213 in hugr-py/src/hugr/_dfg.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_dfg.py#L212-L213

Added lines #L212 - L213 were not covered by tests
return signature

def _wire_up(self, node: Node, ports: Iterable[Wire]) -> TypeRow:
tys = [self._wire_up_port(node, i, p) for i, p in enumerate(ports)]
if isinstance(op := self.hugr[node].op, ops.PartialOp):
Expand All @@ -212,8 +238,10 @@


class Dfg(_DfBase[ops.DFG]):
def __init__(self, *input_types: Type) -> None:
parent_op = ops.DFG(list(input_types))
def __init__(
self, *input_types: Type, extension_delta: ExtensionSet | None = None
) -> None:
parent_op = ops.DFG(list(input_types), None, extension_delta or [])
super().__init__(parent_op)


Expand Down
8 changes: 7 additions & 1 deletion hugr-py/src/hugr/_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ._dfg import _DfBase
from hugr._node_port import Node
from ._hugr import Hugr
from ._tys import TypeRow, TypeParam, PolyFuncType
from ._tys import TypeRow, TypeParam, PolyFuncType, Type, TypeBound


@dataclass
Expand Down Expand Up @@ -47,3 +47,9 @@ def declare_function(self, name: str, signature: PolyFuncType) -> Node:

def add_const(self, value: val.Value) -> Node:
return self.hugr.add_node(ops.Const(value), self.hugr.root)

def add_alias_defn(self, name: str, ty: Type) -> Node:
return self.hugr.add_node(ops.AliasDefn(name, ty), self.hugr.root)

def add_alias_decl(self, name: str, bound: TypeBound) -> Node:
return self.hugr.add_node(ops.AliasDecl(name, bound), self.hugr.root)
202 changes: 185 additions & 17 deletions hugr-py/src/hugr/_ops.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Protocol, TYPE_CHECKING, runtime_checkable, TypeVar
from typing import Protocol, TYPE_CHECKING, Sequence, runtime_checkable, TypeVar
from hugr.serialization.ops import BaseOp
import hugr.serialization.ops as sops
from hugr.utils import ser_it
Expand Down Expand Up @@ -233,14 +233,15 @@
class DFG(DfParentOp, DataflowOp):
inputs: tys.TypeRow
_outputs: tys.TypeRow | None = None
extension_delta: tys.ExtensionSet = field(default_factory=list)

@property
def outputs(self) -> tys.TypeRow:
return _check_complete(self._outputs)

@property
def signature(self) -> tys.FunctionType:
return tys.FunctionType(self.inputs, self.outputs)
return tys.FunctionType(self.inputs, self.outputs, self.extension_delta)

@property
def num_out(self) -> int | None:
Expand Down Expand Up @@ -381,6 +382,7 @@
@dataclass
class LoadConst(DataflowOp):
typ: tys.Type | None = None
num_out: int | None = 1

def type_(self) -> tys.Type:
return _check_complete(self.typ)
Expand Down Expand Up @@ -588,6 +590,25 @@
pass


def _fn_instantiation(
signature: tys.PolyFuncType,
instantiation: tys.FunctionType | None = None,
type_args: Sequence[tys.TypeArg] | None = None,
) -> tuple[tys.FunctionType, list[tys.TypeArg]]:
if len(signature.params) == 0:
return signature.body, []

else:
# TODO substitute type args into signature to get instantiation
if instantiation is None:
raise NoConcreteFunc("Missing instantiation for polymorphic function.")
type_args = type_args or []

if len(signature.params) != len(type_args):
raise NoConcreteFunc("Mismatched number of type arguments.")

Check warning on line 608 in hugr-py/src/hugr/_ops.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_ops.py#L608

Added line #L608 was not covered by tests
return instantiation, list(type_args)


@dataclass
class Call(Op):
signature: tys.PolyFuncType
Expand All @@ -598,23 +619,12 @@
self,
signature: tys.PolyFuncType,
instantiation: tys.FunctionType | None = None,
type_args: list[tys.TypeArg] | None = None,
type_args: Sequence[tys.TypeArg] | None = None,
) -> None:
self.signature = signature
if len(signature.params) == 0:
self.instantiation = signature.body
self.type_args = []

else:
# TODO substitute type args into signature to get instantiation
if instantiation is None:
raise NoConcreteFunc("Missing instantiation for polymorphic function.")
type_args = type_args or []

if len(signature.params) != len(type_args):
raise NoConcreteFunc("Mismatched number of type arguments.")
self.instantiation = instantiation
self.type_args = type_args
self.instantiation, self.type_args = _fn_instantiation(
signature, instantiation, type_args
)

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Call:
return sops.Call(
Expand All @@ -637,3 +647,161 @@
return tys.FunctionKind(self.signature)
case _:
return tys.ValueKind(_sig_port_type(self.instantiation, port))


@dataclass()
class CallIndirectDef(DataflowOp, PartialOp):
_signature: tys.FunctionType | None = None

@property
def num_out(self) -> int | None:
return len(self.signature.output)

@property
def signature(self) -> tys.FunctionType:
return _check_complete(self._signature)

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.CallIndirect:
return sops.CallIndirect(
parent=parent.idx,
signature=self.signature.to_serial(),
)

def __call__(self, function: Wire, *args: Wire) -> Command: # type: ignore[override]
return super().__call__(function, *args)

def outer_signature(self) -> tys.FunctionType:
sig = self.signature

return tys.FunctionType(input=[sig, *sig.input], output=sig.output)

def set_in_types(self, types: tys.TypeRow) -> None:
func_sig, *_ = types
assert isinstance(
func_sig, tys.FunctionType
), f"Expected function type, got {func_sig}"
self._signature = func_sig


# rename to eval?
CallIndirect = CallIndirectDef()


@dataclass
class LoadFunc(DataflowOp):
signature: tys.PolyFuncType
instantiation: tys.FunctionType
type_args: list[tys.TypeArg]
num_out: int | None = 1

def __init__(
self,
signature: tys.PolyFuncType,
instantiation: tys.FunctionType | None = None,
type_args: Sequence[tys.TypeArg] | None = None,
) -> None:
self.signature = signature
self.instantiation, self.type_args = _fn_instantiation(
signature, instantiation, type_args
)

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.LoadFunction:
return sops.LoadFunction(
parent=parent.idx,
func_sig=self.signature.to_serial(),
type_args=ser_it(self.type_args),
signature=self.outer_signature().to_serial(),
)

def outer_signature(self) -> tys.FunctionType:
return tys.FunctionType(input=[], output=[self.instantiation])

def port_kind(self, port: InPort | OutPort) -> tys.Kind:
match port:
case InPort(_, 0):
return tys.FunctionKind(self.signature)
case OutPort(_, 0):
return tys.ValueKind(self.instantiation)
case _:
raise InvalidPort(port)

Check warning on line 726 in hugr-py/src/hugr/_ops.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_ops.py#L720-L726

Added lines #L720 - L726 were not covered by tests


@dataclass
class NoopDef(DataflowOp, PartialOp):
_type: tys.Type | None = None
num_out: int | None = 1

@property
def type_(self) -> tys.Type:
return _check_complete(self._type)

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Noop:
return sops.Noop(parent=parent.idx, ty=self.type_.to_serial_root())

def outer_signature(self) -> tys.FunctionType:
return tys.FunctionType.endo([self.type_])

def set_in_types(self, types: tys.TypeRow) -> None:
(t,) = types
self._type = t


Noop = NoopDef()


@dataclass
class Lift(DataflowOp, PartialOp):
new_extension: tys.ExtensionId
_type_row: tys.TypeRow | None = None
num_out: int | None = 1

@property
def type_row(self) -> tys.TypeRow:
return _check_complete(self._type_row)

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Lift:
return sops.Lift(
parent=parent.idx,
new_extension=self.new_extension,
type_row=ser_it(self.type_row),
)

def outer_signature(self) -> tys.FunctionType:
return tys.FunctionType.endo(self.type_row)

def set_in_types(self, types: tys.TypeRow) -> None:
self._type_row = types


@dataclass
class AliasDecl(Op):
name: str
bound: tys.TypeBound
num_out: int | None = 0

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.AliasDecl:
return sops.AliasDecl(
parent=parent.idx,
name=self.name,
bound=self.bound,
)

def port_kind(self, port: InPort | OutPort) -> tys.Kind:
raise InvalidPort(port)

Check warning on line 790 in hugr-py/src/hugr/_ops.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_ops.py#L790

Added line #L790 was not covered by tests


@dataclass
class AliasDefn(Op):
name: str
definition: tys.Type
num_out: int | None = 0

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.AliasDefn:
return sops.AliasDefn(
parent=parent.idx,
name=self.name,
definition=self.definition.to_serial_root(),
)

def port_kind(self, port: InPort | OutPort) -> tys.Kind:
raise InvalidPort(port)

Check warning on line 807 in hugr-py/src/hugr/_ops.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_ops.py#L807

Added line #L807 was not covered by tests
6 changes: 5 additions & 1 deletion hugr-py/src/hugr/_tys.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,11 @@ class FunctionType(Type):
extension_reqs: ExtensionSet = field(default_factory=ExtensionSet)

def to_serial(self) -> stys.FunctionType:
return stys.FunctionType(input=ser_it(self.input), output=ser_it(self.output))
return stys.FunctionType(
input=ser_it(self.input),
output=ser_it(self.output),
extension_reqs=self.extension_reqs,
)

@classmethod
def empty(cls) -> FunctionType:
Expand Down
Loading
Loading