Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Loading custom polymorphic function defs as values #260

Merged
merged 2 commits into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 15 additions & 19 deletions guppylang/definition/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from guppylang.definition.common import ParsableDef
from guppylang.definition.value import CompiledCallableDef
from guppylang.error import GuppyError, InternalGuppyError
from guppylang.hugr_builder.hugr import Hugr, Node, OutPortV
from guppylang.hugr_builder.hugr import Hugr, OutPortV
from guppylang.nodes import GlobalCall
from guppylang.tys.subst import Inst, Subst
from guppylang.tys.ty import FunctionType, NoneType, Type, type_to_row
Expand Down Expand Up @@ -146,26 +146,22 @@ def load_with_args(
)
assert len(self.ty.params) == len(type_args)

# Find the module node by walking up the hierarchy
module: Node = dfg.node
Comment on lines -149 to -150
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can put the FunctionDef anywhere, no need to walk up to the module

while not isinstance(module.op, ops.Module):
if module.parent is None:
raise InternalGuppyError(
"Encountered node that is not contained in a module."
)
module = module.parent

# We create a `FunctionDef` that takes some inputs, compiles a call to the
# function, and returns the results.
def_node = graph.add_def(self.ty, module, self.name)
_, inp_ports = graph.add_input_with_ports(list(self.ty.inputs), def_node)
returns = self.compile_call(
inp_ports, type_args, DFContainer(def_node, {}), graph, globals, node
)
graph.add_output(returns, parent=def_node)
# function, and returns the results. If the function signature is polymorphic,
# we explicitly monomorphise here and invoke the call compiler with the
# inferred type args.
fun_ty = self.ty.instantiate(type_args)
def_node = graph.add_def(fun_ty, dfg.node, self.name)
with graph.parent(def_node):
_, inp_ports = graph.add_input_with_ports(list(fun_ty.inputs))
returns = self.compile_call(
inp_ports, type_args, DFContainer(def_node, {}), graph, globals, node
)
graph.add_output(returns)

# Finally, load the function into the local DFG
return graph.add_load_constant(def_node.out_port(0), dfg.node).out_port(0)
# Finally, load the function into the local DFG. We already monomorphised, so we
# can load with empty type args
return graph.add_load_function(def_node.out_port(0), [], dfg.node).out_port(0)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We already monomorphised, so we can load with empty type args

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this comment should be included! Maybe lower the ty = ... line below you're We create a .. comment, and add "instantiate the monomorphised type of the function" to that comment



class CustomCallChecker(ABC):
Expand Down
19 changes: 19 additions & 0 deletions tests/integration/test_poly.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import pytest

from guppylang.decorator import guppy
from guppylang.definition.custom import CustomCallCompiler
from guppylang.hugr_builder.hugr import OutPortV
from guppylang.module import GuppyModule
from guppylang.prelude.quantum import qubit

Expand Down Expand Up @@ -261,6 +263,23 @@ def main() -> None:
validate(module.compile())


def test_custom_higher_order():
class CustomCompiler(CustomCallCompiler):
def compile(self, args: list[OutPortV]) -> list[OutPortV]:
return args

module = GuppyModule("test")
T = guppy.type_var(module, "T")

@guppy.custom(module, CustomCompiler())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are users expected to write a CustomCompiler? Needs much better docs and examples, and probably more helper functions. Happy to approve without, but recommend you add an issue to flesh this out if you haven't already.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At the moment custom compilers are an escape hatch to build Hugr stuff that can't be directly expressed in Guppy. Not sure how much we want to expose that to users in the future?

I created issue #269 for discussion

def foo(x: T) -> T: ...

@guppy(module)
def main(x: int) -> int:
f: Callable[[int], int] = foo
return f(x)


@pytest.mark.skip("Not yet supported")
def test_higher_order_value(validate):
module = GuppyModule("test")
Expand Down
Loading