Skip to content

Commit

Permalink
fix: Loading custom polymorphic function defs as values
Browse files Browse the repository at this point in the history
  • Loading branch information
mark-koch committed Jun 24, 2024
1 parent b2901d8 commit 8b4ad23
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 17 deletions.
27 changes: 10 additions & 17 deletions guppylang/definition/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from guppylang.definition.common import ParsableDef
from guppylang.definition.value import CompiledCallableDef
from guppylang.error import GuppyError, InternalGuppyError
from guppylang.hugr_builder.hugr import Hugr, Node, OutPortV
from guppylang.hugr_builder.hugr import Hugr, OutPortV
from guppylang.nodes import GlobalCall
from guppylang.tys.subst import Inst, Subst
from guppylang.tys.ty import FunctionType, NoneType, Type, type_to_row
Expand Down Expand Up @@ -145,27 +145,20 @@ def load_with_args(
node,
)
assert len(self.ty.params) == len(type_args)

# Find the module node by walking up the hierarchy
module: Node = dfg.node
while not isinstance(module.op, ops.Module):
if module.parent is None:
raise InternalGuppyError(
"Encountered node that is not contained in a module."
)
module = module.parent
ty = self.ty.instantiate(type_args)

# We create a `FunctionDef` that takes some inputs, compiles a call to the
# function, and returns the results.
def_node = graph.add_def(self.ty, module, self.name)
_, inp_ports = graph.add_input_with_ports(list(self.ty.inputs), def_node)
returns = self.compile_call(
inp_ports, type_args, DFContainer(def_node, {}), graph, globals, node
)
graph.add_output(returns, parent=def_node)
def_node = graph.add_def(ty, dfg.node, self.name)
with graph.parent(def_node):
_, inp_ports = graph.add_input_with_ports(list(ty.inputs))
returns = self.compile_call(
inp_ports, type_args, DFContainer(def_node, {}), graph, globals, node
)
graph.add_output(returns)

# Finally, load the function into the local DFG
return graph.add_load_constant(def_node.out_port(0), dfg.node).out_port(0)
return graph.add_load_function(def_node.out_port(0), [], dfg.node).out_port(0)


class CustomCallChecker(ABC):
Expand Down
19 changes: 19 additions & 0 deletions tests/integration/test_poly.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import pytest

from guppylang.decorator import guppy
from guppylang.definition.custom import CustomCallCompiler
from guppylang.hugr_builder.hugr import OutPortV
from guppylang.module import GuppyModule
from guppylang.prelude.quantum import qubit

Expand Down Expand Up @@ -261,6 +263,23 @@ def main() -> None:
validate(module.compile())


def test_custom_higher_order():
class CustomCompiler(CustomCallCompiler):
def compile(self, args: list[OutPortV]) -> list[OutPortV]:
return args

module = GuppyModule("test")
T = guppy.type_var(module, "T")

@guppy.custom(module, CustomCompiler())
def foo(x: T) -> T: ...

@guppy(module)
def main(x: int) -> int:
f: Callable[[int], int] = foo
return f(x)


@pytest.mark.skip("Not yet supported")
def test_higher_order_value(validate):
module = GuppyModule("test")
Expand Down

0 comments on commit 8b4ad23

Please sign in to comment.