-
Notifications
You must be signed in to change notification settings - Fork 3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix: Loading custom polymorphic function defs as values #260
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,7 +12,7 @@ | |
from guppylang.definition.common import ParsableDef | ||
from guppylang.definition.value import CompiledCallableDef | ||
from guppylang.error import GuppyError, InternalGuppyError | ||
from guppylang.hugr_builder.hugr import Hugr, Node, OutPortV | ||
from guppylang.hugr_builder.hugr import Hugr, OutPortV | ||
from guppylang.nodes import GlobalCall | ||
from guppylang.tys.subst import Inst, Subst | ||
from guppylang.tys.ty import FunctionType, NoneType, Type, type_to_row | ||
|
@@ -146,26 +146,22 @@ def load_with_args( | |
) | ||
assert len(self.ty.params) == len(type_args) | ||
|
||
# Find the module node by walking up the hierarchy | ||
module: Node = dfg.node | ||
while not isinstance(module.op, ops.Module): | ||
if module.parent is None: | ||
raise InternalGuppyError( | ||
"Encountered node that is not contained in a module." | ||
) | ||
module = module.parent | ||
|
||
# We create a `FunctionDef` that takes some inputs, compiles a call to the | ||
# function, and returns the results. | ||
def_node = graph.add_def(self.ty, module, self.name) | ||
_, inp_ports = graph.add_input_with_ports(list(self.ty.inputs), def_node) | ||
returns = self.compile_call( | ||
inp_ports, type_args, DFContainer(def_node, {}), graph, globals, node | ||
) | ||
graph.add_output(returns, parent=def_node) | ||
# function, and returns the results. If the function signature is polymorphic, | ||
# we explicitly monomorphise here and invoke the call compiler with the | ||
# inferred type args. | ||
fun_ty = self.ty.instantiate(type_args) | ||
def_node = graph.add_def(fun_ty, dfg.node, self.name) | ||
with graph.parent(def_node): | ||
_, inp_ports = graph.add_input_with_ports(list(fun_ty.inputs)) | ||
returns = self.compile_call( | ||
inp_ports, type_args, DFContainer(def_node, {}), graph, globals, node | ||
) | ||
graph.add_output(returns) | ||
|
||
# Finally, load the function into the local DFG | ||
return graph.add_load_constant(def_node.out_port(0), dfg.node).out_port(0) | ||
# Finally, load the function into the local DFG. We already monomorphised, so we | ||
# can load with empty type args | ||
return graph.add_load_function(def_node.out_port(0), [], dfg.node).out_port(0) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We already monomorphised, so we can load with empty type args There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this comment should be included! Maybe lower the |
||
|
||
|
||
class CustomCallChecker(ABC): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,8 @@ | |
import pytest | ||
|
||
from guppylang.decorator import guppy | ||
from guppylang.definition.custom import CustomCallCompiler | ||
from guppylang.hugr_builder.hugr import OutPortV | ||
from guppylang.module import GuppyModule | ||
from guppylang.prelude.quantum import qubit | ||
|
||
|
@@ -261,6 +263,23 @@ def main() -> None: | |
validate(module.compile()) | ||
|
||
|
||
def test_custom_higher_order(): | ||
class CustomCompiler(CustomCallCompiler): | ||
def compile(self, args: list[OutPortV]) -> list[OutPortV]: | ||
return args | ||
|
||
module = GuppyModule("test") | ||
T = guppy.type_var(module, "T") | ||
|
||
@guppy.custom(module, CustomCompiler()) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are users expected to write a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. At the moment custom compilers are an escape hatch to build Hugr stuff that can't be directly expressed in Guppy. Not sure how much we want to expose that to users in the future? I created issue #269 for discussion |
||
def foo(x: T) -> T: ... | ||
|
||
@guppy(module) | ||
def main(x: int) -> int: | ||
f: Callable[[int], int] = foo | ||
return f(x) | ||
|
||
|
||
@pytest.mark.skip("Not yet supported") | ||
def test_higher_order_value(validate): | ||
module = GuppyModule("test") | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can put the FunctionDef anywhere, no need to walk up to the module