Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Generate constructor methods for structs #262

Merged
merged 5 commits into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion guppylang/definition/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,16 @@ def compile_call(


@dataclass(frozen=True)
class CustomFunctionDef(RawCustomFunctionDef, CompiledCallableDef):
class CustomFunctionDef(CompiledCallableDef):
"""A custom function with parsed and checked signature."""

defined_at: AstNode
call_checker: "CustomCallChecker"
call_compiler: "CustomCallCompiler"
ty: FunctionType
higher_order_value: bool
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When is this not true?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When the type of the function cannot be expressed in Guppy's type system.

For example, the len(x) function is a CustomFunction that checks if x implements __len__. We can't write down a type for len, so we can't use it as a value


description: str = field(default="function", init=False)

def check_call(
self, args: list[ast.expr], ty: Type, node: AstNode, ctx: Context
Expand Down Expand Up @@ -163,6 +169,19 @@ def load_with_args(
# can load with empty type args
return graph.add_load_function(def_node.out_port(0), [], dfg.node).out_port(0)

def compile_call(
self,
args: list[OutPortV],
type_args: Inst,
dfg: DFContainer,
graph: Hugr,
globals: CompiledGlobals,
node: AstNode,
) -> list[OutPortV]:
"""Compiles a call to the function."""
self.call_compiler._setup(type_args, dfg, graph, globals, node)
return self.call_compiler.compile(args)


class CustomCallChecker(ABC):
"""Abstract base class for custom function call type checkers."""
Expand Down
36 changes: 35 additions & 1 deletion guppylang/definition/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import textwrap
from collections.abc import Sequence
from dataclasses import dataclass
from functools import cached_property
from typing import Any

from guppylang.ast_util import AstNode, annotate_location
Expand All @@ -14,13 +15,19 @@
Definition,
ParsableDef,
)
from guppylang.definition.custom import (
CustomCallCompiler,
CustomFunctionDef,
DefaultCallChecker,
)
from guppylang.definition.parameter import ParamDef
from guppylang.definition.ty import TypeDef
from guppylang.error import GuppyError, InternalGuppyError
from guppylang.hugr_builder.hugr import OutPortV
from guppylang.tys.arg import Argument
from guppylang.tys.param import Parameter, check_all_args
from guppylang.tys.parsing import type_from_ast
from guppylang.tys.ty import StructType, Type
from guppylang.tys.ty import FunctionType, StructType, Type


@dataclass(frozen=True)
Expand Down Expand Up @@ -186,6 +193,33 @@ def check_instantiate(
check_all_args(self.params, args, self.name, loc)
return StructType(args, self)

@cached_property
def generated_methods(self) -> list[CustomFunctionDef]:
"""Auto-generated methods for this struct."""

class ConstructorCompiler(CustomCallCompiler):
"""Compiler for the `__new__` constructor method of a struct."""

def compile(self, args: list[OutPortV]) -> list[OutPortV]:
return [self.graph.add_make_tuple(args).out_port(0)]

constructor_sig = FunctionType(
inputs=[f.ty for f in self.fields],
output=StructType([p.to_bound(i) for i, p in enumerate(self.params)], self),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How does this work, I assume the list is the args from ParametrizedTypeBase, then self is the defn from StructType. So in the default constructors for python dataclasses, you provide the fields for all of the parents, outer to inner?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, exactly. I could also use keyword args if it would be clearer?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes please!

input_names=[f.name for f in self.fields],
params=self.params,
)
constructor_def = CustomFunctionDef(
id=DefId.fresh(self.id.module),
name="__new__",
defined_at=self.defined_at,
ty=constructor_sig,
call_checker=DefaultCallChecker(),
call_compiler=ConstructorCompiler(),
higher_order_value=True,
)
return [constructor_def]


def parse_py_class(cls: type) -> ast.ClassDef:
"""Parses a Python class object into an AST."""
Expand Down
12 changes: 11 additions & 1 deletion guppylang/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from guppylang.definition.declaration import RawFunctionDecl
from guppylang.definition.function import RawFunctionDef
from guppylang.definition.parameter import ParamDef
from guppylang.definition.struct import CheckedStructDef
from guppylang.definition.ty import TypeDef
from guppylang.error import GuppyError, pretty_errors
from guppylang.hugr_builder.hugr import Hugr
Expand Down Expand Up @@ -180,9 +181,18 @@ def compile(self) -> Hugr:
)
self._globals = self._globals.update_defs(type_defs)

# Collect auto-generated methods
generated: dict[DefId, RawDef] = {}
for defn in type_defs.values():
if isinstance(defn, CheckedStructDef):
for method_def in defn.generated_methods:
generated[method_def.id] = method_def
self._globals.impls.setdefault(defn.id, {})
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this wants to happen in the outer loop, so that the impls aren't set to {} each time a method is added

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

setdefault only sets if the key doesn't exist yet

https://docs.python.org/3/library/stdtypes.html#dict.setdefault

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, I think this still wants to be in the outer loop, but is ultimately harmless then

self._globals.impls[defn.id][method_def.name] = method_def.id

# Now, we can check all other definitions
other_defs = self._check_defs(
self._raw_defs, self._imported_globals | self._globals
self._raw_defs | generated, self._imported_globals | self._globals
)
self._globals = self._globals.update_defs(other_defs)

Expand Down
7 changes: 7 additions & 0 deletions tests/error/struct_errors/constructor_arg_mismatch.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Guppy compilation failed. Error in file $FILE:15

13: @guppy(module)
14: def main() -> None:
15: MyStruct(0)
^
GuppyTypeError: Expected argument of type `(int, int)`, got `int`
18 changes: 18 additions & 0 deletions tests/error/struct_errors/constructor_arg_mismatch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from guppylang.decorator import guppy
from guppylang.module import GuppyModule


module = GuppyModule("test")


@guppy.struct(module)
class MyStruct:
x: tuple[int, int]


@guppy(module)
def main() -> None:
MyStruct(0)


module.compile()
7 changes: 7 additions & 0 deletions tests/error/struct_errors/constructor_arg_mismatch_poly.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Guppy compilation failed. Error in file $FILE:19

17: @guppy(module)
18: def main() -> None:
19: MyStruct(0, False)
^^^^^
GuppyTypeError: Expected argument of type `int`, got `bool`
22 changes: 22 additions & 0 deletions tests/error/struct_errors/constructor_arg_mismatch_poly.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from typing import Generic

from guppylang.decorator import guppy
from guppylang.module import GuppyModule


module = GuppyModule("test")
T = guppy.type_var(module, "T")


@guppy.struct(module)
class MyStruct(Generic[T]):
x: T
y: T


@guppy(module)
def main() -> None:
MyStruct(0, False)


module.compile()
7 changes: 7 additions & 0 deletions tests/error/struct_errors/constructor_missing_arg.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Guppy compilation failed. Error in file $FILE:15

13: @guppy(module)
14: def main() -> None:
15: MyStruct()
^^^^^^^^^^
GuppyTypeError: Not enough arguments passed (expected 1, got 0)
18 changes: 18 additions & 0 deletions tests/error/struct_errors/constructor_missing_arg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from guppylang.decorator import guppy
from guppylang.module import GuppyModule


module = GuppyModule("test")


@guppy.struct(module)
class MyStruct:
x: int


@guppy(module)
def main() -> None:
MyStruct()


module.compile()
7 changes: 7 additions & 0 deletions tests/error/struct_errors/constructor_too_many_args.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Guppy compilation failed. Error in file $FILE:15

13: @guppy(module)
14: def main() -> None:
15: MyStruct(1, 2, 3)
^
GuppyTypeError: Unexpected argument
18 changes: 18 additions & 0 deletions tests/error/struct_errors/constructor_too_many_args.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from guppylang.decorator import guppy
from guppylang.module import GuppyModule


module = GuppyModule("test")


@guppy.struct(module)
class MyStruct:
x: int


@guppy(module)
def main() -> None:
MyStruct(1, 2, 3)


module.compile()
41 changes: 35 additions & 6 deletions tests/integration/test_struct.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from typing import Generic
from typing import Generic, TYPE_CHECKING

from guppylang.decorator import guppy
from guppylang.module import GuppyModule

if TYPE_CHECKING:
from collections.abc import Callable


def test_basic_defs(validate):
module = GuppyModule("module")
Expand Down Expand Up @@ -30,7 +33,10 @@ class DocstringStruct:
def main(
a: EmptyStruct, b: OneMemberStruct, c: TwoMemberStruct, d: DocstringStruct
) -> None:
pass
EmptyStruct()
OneMemberStruct(42)
TwoMemberStruct((True, 0), 1.0)
DocstringStruct(-1)

validate(module.compile())

Expand All @@ -48,7 +54,7 @@ class StructB:

@guppy(module)
def main(a: StructA, b: StructB) -> None:
pass
StructB(a)

validate(module.compile())

Expand All @@ -66,7 +72,7 @@ class StructB:

@guppy(module)
def main(a: StructA, b: StructB) -> None:
pass
StructA(b)

validate(module.compile())

Expand All @@ -92,7 +98,30 @@ class StructB(Generic[S, T]):
y: StructA[T]

@guppy(module)
def main(a: StructA[StructA[float]], b: StructB[int, bool], c: StructC) -> None:
pass
def main(a: StructA[StructA[float]], b: StructB[bool, int], c: StructC) -> None:
x = StructA((0, False))
y = StructA((0, -5))
StructA((0, x))
StructB(x, a)
StructC(y, StructA((0, [])), StructB(42.0, StructA((4, b))))

validate(module.compile())


def test_higher_order(validate):
module = GuppyModule("module")
T = guppy.type_var(module, "T")

@guppy.struct(module)
class Struct(Generic[T]):
x: T

@guppy(module)
def factory(mk_struct: "Callable[[int], Struct[int]]", x: int) -> Struct[int]:
return mk_struct(x)

@guppy(module)
def main() -> None:
factory(Struct, 42)

validate(module.compile())
Loading