Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add result function #271

Merged
merged 2 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 49 additions & 2 deletions guppylang/prelude/_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
ExprSynthesizer,
check_call,
check_num_args,
check_type_against,
synthesize_call,
)
from guppylang.definition.custom import (
Expand All @@ -21,11 +22,11 @@
from guppylang.error import GuppyError, GuppyTypeError, InternalGuppyError
from guppylang.hugr_builder.hugr import UNDEFINED, OutPortV
from guppylang.nodes import GlobalCall
from guppylang.tys.arg import ConstArg
from guppylang.tys.arg import ConstArg, TypeArg
from guppylang.tys.builtin import bool_type, int_type, list_type
from guppylang.tys.const import ConstValue
from guppylang.tys.subst import Inst, Subst
from guppylang.tys.ty import FunctionType, NumericType, Type, unify
from guppylang.tys.ty import FunctionType, NoneType, NumericType, Type, unify


class ConstInt(BaseModel):
Expand Down Expand Up @@ -270,6 +271,52 @@ def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]:
return self._get_const_len(inst), subst


class ResultChecker(CustomCallChecker):
"""Call checker for the `result` function.

This is a temporary hack until we have implemented the proper results mechanism.
"""

def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
check_num_args(2, len(args), self.node)
[tag, value] = args
if not isinstance(tag, ast.Constant) or not isinstance(tag.value, int):
raise GuppyTypeError("Expected an int literal", tag)
value, ty = ExprSynthesizer(self.ctx).synthesize(value)
if ty.linear:
raise GuppyTypeError(
f"Cannot use value with linear type `{ty}` as a result", value
)
type_args = [
TypeArg(ty),
ConstArg(ConstValue(value=tag.value, ty=NumericType(NumericType.Kind.Nat))),
]
call = GlobalCall(def_id=self.func.id, args=[value], type_args=type_args)
return with_loc(self.node, call), NoneType()

def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]:
expr, res_ty = self.synthesize(args)
subst, _ = check_type_against(res_ty, ty, self.node)
return expr, subst


class ResultCompiler(CustomCallCompiler):
"""Call compiler for the `result` function.

This is a temporary hack until we have implemented the proper results mechanism.
"""

def compile(self, args: list[OutPortV]) -> list[OutPortV]:
op = ops.CustomOp(
extension="Results",
op_name="Result",
args=[arg.to_hugr() for arg in self.type_args],
parent=UNDEFINED,
)
self.graph.add_node(ops.OpType(op), inputs=args)
return []


class NatTruedivCompiler(CustomCallCompiler):
"""Compiler for the `nat.__truediv__` method."""

Expand Down
7 changes: 7 additions & 0 deletions guppylang/prelude/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
FloatModCompiler,
IntTruedivCompiler,
NatTruedivCompiler,
ResultChecker,
ResultCompiler,
ReversingChecker,
UnsupportedChecker,
float_op,
Expand Down Expand Up @@ -611,6 +613,11 @@ def __getitem__(self: array[T, n], idx: int) -> T: ...
def __len__(self: array[T, n]) -> int: ...


# TODO: This is a temporary hack until we have implemented the proper results mechanism.
@guppy.custom(builtins, ResultCompiler(), ResultChecker(), higher_order_value=False)
def result(tag, value): ...


@guppy.custom(builtins, checker=DunderChecker("__abs__"), higher_order_value=False)
def abs(x): ...

Expand Down
7 changes: 7 additions & 0 deletions tests/error/misc_errors/result_tag_not_int.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Guppy compilation failed. Error in file $FILE:7

5: @compile_guppy
6: def foo(x: int) -> None:
7: result((), x)
^^
GuppyTypeError: Expected an int literal
7 changes: 7 additions & 0 deletions tests/error/misc_errors/result_tag_not_int.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from guppylang.prelude.builtins import result
from tests.util import compile_guppy


@compile_guppy
def foo(x: int) -> None:
result((), x)
7 changes: 7 additions & 0 deletions tests/error/misc_errors/result_tag_not_static.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Guppy compilation failed. Error in file $FILE:7

5: @compile_guppy
6: def foo(x: int, y: bool) -> None:
7: result(x, y)
^
GuppyTypeError: Expected an int literal
7 changes: 7 additions & 0 deletions tests/error/misc_errors/result_tag_not_static.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from guppylang.prelude.builtins import result
from tests.util import compile_guppy


@compile_guppy
def foo(x: int, y: bool) -> None:
result(x, y)
7 changes: 7 additions & 0 deletions tests/error/misc_errors/result_value_linear.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Guppy compilation failed. Error in file $FILE:14

12: @guppy(module)
13: def foo(q: qubit) -> None:
14: result(0, q)
^
GuppyTypeError: Cannot use value with linear type `qubit` as a result
17 changes: 17 additions & 0 deletions tests/error/misc_errors/result_value_linear.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import guppylang.prelude.quantum as quantum
from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.prelude.builtins import result
from guppylang.prelude.quantum import qubit


module = GuppyModule("test")
module.load(quantum)


@guppy(module)
def foo(q: qubit) -> None:
result(0, q)


module.compile()
50 changes: 50 additions & 0 deletions tests/integration/test_result.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from hugr.serialization import ops

from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.prelude.builtins import result
from tests.util import compile_guppy


def test_single(validate):
@compile_guppy
def main(x: int) -> None:
result(0, x)

validate(main)


def test_value(validate):
@compile_guppy
def main(x: int) -> None:
return result(0, x)

validate(main)


def test_nested(validate):
@compile_guppy
def main(x: int, y: float, z: bool) -> None:
result(42, (x, (y, z)))

validate(main)


def test_multi(validate):
@compile_guppy
def main(x: int, y: float, z: bool) -> None:
result(0, x)
result(1, y)
result(2, z)

validate(main)


def test_same_tag(validate):
@compile_guppy
def main(x: int, y: float, z: bool) -> None:
result(0, x)
result(0, y)
result(0, z)

validate(main)
Loading