Skip to content

Commit

Permalink
feat: Add result function (#271)
Browse files Browse the repository at this point in the history
Closes #270
  • Loading branch information
mark-koch authored Jun 26, 2024
1 parent f68d0af commit 792fb87
Show file tree
Hide file tree
Showing 9 changed files with 158 additions and 2 deletions.
51 changes: 49 additions & 2 deletions guppylang/prelude/_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
ExprSynthesizer,
check_call,
check_num_args,
check_type_against,
synthesize_call,
)
from guppylang.definition.custom import (
Expand All @@ -21,11 +22,11 @@
from guppylang.error import GuppyError, GuppyTypeError, InternalGuppyError
from guppylang.hugr_builder.hugr import UNDEFINED, OutPortV
from guppylang.nodes import GlobalCall
from guppylang.tys.arg import ConstArg
from guppylang.tys.arg import ConstArg, TypeArg
from guppylang.tys.builtin import bool_type, int_type, list_type
from guppylang.tys.const import ConstValue
from guppylang.tys.subst import Inst, Subst
from guppylang.tys.ty import FunctionType, NumericType, Type, unify
from guppylang.tys.ty import FunctionType, NoneType, NumericType, Type, unify


class ConstInt(BaseModel):
Expand Down Expand Up @@ -270,6 +271,52 @@ def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]:
return self._get_const_len(inst), subst


class ResultChecker(CustomCallChecker):
"""Call checker for the `result` function.
This is a temporary hack until we have implemented the proper results mechanism.
"""

def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
check_num_args(2, len(args), self.node)
[tag, value] = args
if not isinstance(tag, ast.Constant) or not isinstance(tag.value, int):
raise GuppyTypeError("Expected an int literal", tag)
value, ty = ExprSynthesizer(self.ctx).synthesize(value)
if ty.linear:
raise GuppyTypeError(
f"Cannot use value with linear type `{ty}` as a result", value
)
type_args = [
TypeArg(ty),
ConstArg(ConstValue(value=tag.value, ty=NumericType(NumericType.Kind.Nat))),
]
call = GlobalCall(def_id=self.func.id, args=[value], type_args=type_args)
return with_loc(self.node, call), NoneType()

def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]:
expr, res_ty = self.synthesize(args)
subst, _ = check_type_against(res_ty, ty, self.node)
return expr, subst


class ResultCompiler(CustomCallCompiler):
"""Call compiler for the `result` function.
This is a temporary hack until we have implemented the proper results mechanism.
"""

def compile(self, args: list[OutPortV]) -> list[OutPortV]:
op = ops.CustomOp(
extension="Results",
op_name="Result",
args=[arg.to_hugr() for arg in self.type_args],
parent=UNDEFINED,
)
self.graph.add_node(ops.OpType(op), inputs=args)
return []


class NatTruedivCompiler(CustomCallCompiler):
"""Compiler for the `nat.__truediv__` method."""

Expand Down
7 changes: 7 additions & 0 deletions guppylang/prelude/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
FloatModCompiler,
IntTruedivCompiler,
NatTruedivCompiler,
ResultChecker,
ResultCompiler,
ReversingChecker,
UnsupportedChecker,
float_op,
Expand Down Expand Up @@ -611,6 +613,11 @@ def __getitem__(self: array[T, n], idx: int) -> T: ...
def __len__(self: array[T, n]) -> int: ...


# TODO: This is a temporary hack until we have implemented the proper results mechanism.
@guppy.custom(builtins, ResultCompiler(), ResultChecker(), higher_order_value=False)
def result(tag, value): ...


@guppy.custom(builtins, checker=DunderChecker("__abs__"), higher_order_value=False)
def abs(x): ...

Expand Down
7 changes: 7 additions & 0 deletions tests/error/misc_errors/result_tag_not_int.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Guppy compilation failed. Error in file $FILE:7

5: @compile_guppy
6: def foo(x: int) -> None:
7: result((), x)
^^
GuppyTypeError: Expected an int literal
7 changes: 7 additions & 0 deletions tests/error/misc_errors/result_tag_not_int.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from guppylang.prelude.builtins import result
from tests.util import compile_guppy


@compile_guppy
def foo(x: int) -> None:
result((), x)
7 changes: 7 additions & 0 deletions tests/error/misc_errors/result_tag_not_static.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Guppy compilation failed. Error in file $FILE:7

5: @compile_guppy
6: def foo(x: int, y: bool) -> None:
7: result(x, y)
^
GuppyTypeError: Expected an int literal
7 changes: 7 additions & 0 deletions tests/error/misc_errors/result_tag_not_static.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from guppylang.prelude.builtins import result
from tests.util import compile_guppy


@compile_guppy
def foo(x: int, y: bool) -> None:
result(x, y)
7 changes: 7 additions & 0 deletions tests/error/misc_errors/result_value_linear.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Guppy compilation failed. Error in file $FILE:14

12: @guppy(module)
13: def foo(q: qubit) -> None:
14: result(0, q)
^
GuppyTypeError: Cannot use value with linear type `qubit` as a result
17 changes: 17 additions & 0 deletions tests/error/misc_errors/result_value_linear.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import guppylang.prelude.quantum as quantum
from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.prelude.builtins import result
from guppylang.prelude.quantum import qubit


module = GuppyModule("test")
module.load(quantum)


@guppy(module)
def foo(q: qubit) -> None:
result(0, q)


module.compile()
50 changes: 50 additions & 0 deletions tests/integration/test_result.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from hugr.serialization import ops

from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.prelude.builtins import result
from tests.util import compile_guppy


def test_single(validate):
@compile_guppy
def main(x: int) -> None:
result(0, x)

validate(main)


def test_value(validate):
@compile_guppy
def main(x: int) -> None:
return result(0, x)

validate(main)


def test_nested(validate):
@compile_guppy
def main(x: int, y: float, z: bool) -> None:
result(42, (x, (y, z)))

validate(main)


def test_multi(validate):
@compile_guppy
def main(x: int, y: float, z: bool) -> None:
result(0, x)
result(1, y)
result(2, z)

validate(main)


def test_same_tag(validate):
@compile_guppy
def main(x: int, y: float, z: bool) -> None:
result(0, x)
result(0, y)
result(0, z)

validate(main)

0 comments on commit 792fb87

Please sign in to comment.