-
Notifications
You must be signed in to change notification settings - Fork 3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Generate constructor methods for structs #262
Changes from 3 commits
8b4ad23
7ef3416
e940574
0d10df1
d47b03e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,7 @@ | |
import textwrap | ||
from collections.abc import Sequence | ||
from dataclasses import dataclass | ||
from functools import cached_property | ||
from typing import Any | ||
|
||
from guppylang.ast_util import AstNode, annotate_location | ||
|
@@ -14,13 +15,19 @@ | |
Definition, | ||
ParsableDef, | ||
) | ||
from guppylang.definition.custom import ( | ||
CustomCallCompiler, | ||
CustomFunctionDef, | ||
DefaultCallChecker, | ||
) | ||
from guppylang.definition.parameter import ParamDef | ||
from guppylang.definition.ty import TypeDef | ||
from guppylang.error import GuppyError, InternalGuppyError | ||
from guppylang.hugr_builder.hugr import OutPortV | ||
from guppylang.tys.arg import Argument | ||
from guppylang.tys.param import Parameter, check_all_args | ||
from guppylang.tys.parsing import type_from_ast | ||
from guppylang.tys.ty import StructType, Type | ||
from guppylang.tys.ty import FunctionType, StructType, Type | ||
|
||
|
||
@dataclass(frozen=True) | ||
|
@@ -186,6 +193,33 @@ def check_instantiate( | |
check_all_args(self.params, args, self.name, loc) | ||
return StructType(args, self) | ||
|
||
@cached_property | ||
def generated_methods(self) -> list[CustomFunctionDef]: | ||
"""Auto-generated methods for this struct.""" | ||
|
||
class ConstructorCompiler(CustomCallCompiler): | ||
"""Compiler for the `__new__` constructor method of a struct.""" | ||
|
||
def compile(self, args: list[OutPortV]) -> list[OutPortV]: | ||
return [self.graph.add_make_tuple(args).out_port(0)] | ||
|
||
constructor_sig = FunctionType( | ||
inputs=[f.ty for f in self.fields], | ||
output=StructType([p.to_bound(i) for i, p in enumerate(self.params)], self), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How does this work, I assume the list is the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, exactly. I could also use keyword args if it would be clearer? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes please! |
||
input_names=[f.name for f in self.fields], | ||
params=self.params, | ||
) | ||
constructor_def = CustomFunctionDef( | ||
id=DefId.fresh(self.id.module), | ||
name="__new__", | ||
defined_at=self.defined_at, | ||
ty=constructor_sig, | ||
call_checker=DefaultCallChecker(), | ||
call_compiler=ConstructorCompiler(), | ||
higher_order_value=True, | ||
) | ||
return [constructor_def] | ||
|
||
|
||
def parse_py_class(cls: type) -> ast.ClassDef: | ||
"""Parses a Python class object into an AST.""" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,6 +20,7 @@ | |
from guppylang.definition.declaration import RawFunctionDecl | ||
from guppylang.definition.function import RawFunctionDef | ||
from guppylang.definition.parameter import ParamDef | ||
from guppylang.definition.struct import CheckedStructDef | ||
from guppylang.definition.ty import TypeDef | ||
from guppylang.error import GuppyError, pretty_errors | ||
from guppylang.hugr_builder.hugr import Hugr | ||
|
@@ -180,9 +181,18 @@ def compile(self) -> Hugr: | |
) | ||
self._globals = self._globals.update_defs(type_defs) | ||
|
||
# Collect auto-generated methods | ||
generated: dict[DefId, RawDef] = {} | ||
for defn in type_defs.values(): | ||
if isinstance(defn, CheckedStructDef): | ||
for method_def in defn.generated_methods: | ||
generated[method_def.id] = method_def | ||
self._globals.impls.setdefault(defn.id, {}) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this wants to happen in the outer loop, so that the impls aren't set to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
https://docs.python.org/3/library/stdtypes.html#dict.setdefault There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Okay, I think this still wants to be in the outer loop, but is ultimately harmless then |
||
self._globals.impls[defn.id][method_def.name] = method_def.id | ||
|
||
# Now, we can check all other definitions | ||
other_defs = self._check_defs( | ||
self._raw_defs, self._imported_globals | self._globals | ||
self._raw_defs | generated, self._imported_globals | self._globals | ||
) | ||
self._globals = self._globals.update_defs(other_defs) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
Guppy compilation failed. Error in file $FILE:15 | ||
|
||
13: @guppy(module) | ||
14: def main() -> None: | ||
15: MyStruct(0) | ||
^ | ||
GuppyTypeError: Expected argument of type `(int, int)`, got `int` |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
from guppylang.decorator import guppy | ||
from guppylang.module import GuppyModule | ||
|
||
|
||
module = GuppyModule("test") | ||
|
||
|
||
@guppy.struct(module) | ||
class MyStruct: | ||
x: tuple[int, int] | ||
|
||
|
||
@guppy(module) | ||
def main() -> None: | ||
MyStruct(0) | ||
|
||
|
||
module.compile() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
Guppy compilation failed. Error in file $FILE:19 | ||
|
||
17: @guppy(module) | ||
18: def main() -> None: | ||
19: MyStruct(0, False) | ||
^^^^^ | ||
GuppyTypeError: Expected argument of type `int`, got `bool` |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
from typing import Generic | ||
|
||
from guppylang.decorator import guppy | ||
from guppylang.module import GuppyModule | ||
|
||
|
||
module = GuppyModule("test") | ||
T = guppy.type_var(module, "T") | ||
|
||
|
||
@guppy.struct(module) | ||
class MyStruct(Generic[T]): | ||
x: T | ||
y: T | ||
|
||
|
||
@guppy(module) | ||
def main() -> None: | ||
MyStruct(0, False) | ||
|
||
|
||
module.compile() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
Guppy compilation failed. Error in file $FILE:15 | ||
|
||
13: @guppy(module) | ||
14: def main() -> None: | ||
15: MyStruct() | ||
^^^^^^^^^^ | ||
GuppyTypeError: Not enough arguments passed (expected 1, got 0) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
from guppylang.decorator import guppy | ||
from guppylang.module import GuppyModule | ||
|
||
|
||
module = GuppyModule("test") | ||
|
||
|
||
@guppy.struct(module) | ||
class MyStruct: | ||
x: int | ||
|
||
|
||
@guppy(module) | ||
def main() -> None: | ||
MyStruct() | ||
|
||
|
||
module.compile() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
Guppy compilation failed. Error in file $FILE:15 | ||
|
||
13: @guppy(module) | ||
14: def main() -> None: | ||
15: MyStruct(1, 2, 3) | ||
^ | ||
GuppyTypeError: Unexpected argument |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
from guppylang.decorator import guppy | ||
from guppylang.module import GuppyModule | ||
|
||
|
||
module = GuppyModule("test") | ||
|
||
|
||
@guppy.struct(module) | ||
class MyStruct: | ||
x: int | ||
|
||
|
||
@guppy(module) | ||
def main() -> None: | ||
MyStruct(1, 2, 3) | ||
|
||
|
||
module.compile() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When is this not true?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When the type of the function cannot be expressed in Guppy's type system.
For example, the
len(x)
function is aCustomFunction
that checks ifx
implements__len__
. We can't write down a type forlen
, so we can't use it as a value