Skip to content

Commit

Permalink
feat: add convert for enum (#2977)
Browse files Browse the repository at this point in the history
add conversions between uint256 and enum types

(other int types were not added to make testing more straightforward)
  • Loading branch information
tserg authored Jul 26, 2022
1 parent 4b3b636 commit 7b8b082
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/types.rst
Original file line number Diff line number Diff line change
Expand Up @@ -622,5 +622,6 @@ All type conversions in Vyper must be made explicitly using the built-in ``conve
* Narrowing conversions (e.g., ``int256 -> int128``) check that the input is in bounds for the output type.
* Converting between bytes and int types results in sign-extension if the output type is signed. For instance, converting ``0xff`` (``bytes1``) to ``int8`` returns ``-1``.
* Converting between bytes and int types which have different sizes follows the rule of going through the closest integer type, first. For instance, ``bytes1 -> int16`` is like ``bytes1 -> int8 -> int16`` (signextend, then widen). ``uint8 -> bytes20`` is like ``uint8 -> uint160 -> bytes20`` (rotate left 12 bytes).
* Enums can be converted to and from ``uint256`` only.

A small Python reference implementation is maintained as part of Vyper's test suite, it can be found `here <https://github.com/vyperlang/vyper/blob/c4c6afd07801a0cc0038cdd4007cc43860c54193/tests/parser/functions/test_convert.py#L318>`_. The motivation and more detailed discussion of the rules can be found `here <https://github.com/vyperlang/vyper/issues/2507>`_.
51 changes: 51 additions & 0 deletions tests/parser/functions/test_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
DECIMAL_DIVISOR,
SizeLimits,
checksum_encode,
int_bounds,
is_checksum_encoded,
round_towards_zero,
unsigned_to_signed,
Expand Down Expand Up @@ -513,6 +514,56 @@ def test_memory_variable_convert(x: {i_typ}) -> {o_typ}:
assert c4.test_memory_variable_convert(val) == expected_val


@pytest.mark.parametrize("typ", ["uint8", "int128", "int256", "uint256"])
@pytest.mark.parametrize("val", [1, 2, 2 ** 128, 2 ** 256 - 1, 2 ** 256 - 2])
def test_enum_conversion(get_contract_with_gas_estimation, assert_compile_failed, val, typ):
roles = "\n ".join([f"ROLE_{i}" for i in range(256)])
contract = f"""
enum Roles:
{roles}
@external
def foo(a: Roles) -> {typ}:
return convert(a, {typ})
@external
def bar(a: uint256) -> Roles:
return convert(a, Roles)
"""
if typ == "uint256":
c = get_contract_with_gas_estimation(contract)
assert c.foo(val) == val
assert c.bar(val) == val
else:
assert_compile_failed(lambda: get_contract_with_gas_estimation(contract), TypeMismatch)


@pytest.mark.parametrize("typ", ["uint8", "int128", "int256", "uint256"])
@pytest.mark.parametrize("val", [1, 2, 3, 4, 2 ** 128, 2 ** 256 - 1, 2 ** 256 - 2])
def test_enum_conversion_2(
get_contract_with_gas_estimation, assert_compile_failed, assert_tx_failed, val, typ
):
contract = f"""
enum Status:
STARTED
PAUSED
STOPPED
@external
def foo(a: {typ}) -> Status:
return convert(a, Status)
"""
if typ == "uint256":
c = get_contract_with_gas_estimation(contract)
lo, hi = int_bounds(signed=False, bits=3)
if lo <= val <= hi:
assert c.foo(val) == val
else:
assert_tx_failed(lambda: c.foo(val))
else:
assert_compile_failed(lambda: get_contract_with_gas_estimation(contract), TypeMismatch)


# TODO CMC 2022-04-06 I think this test is somewhat unnecessary.
@pytest.mark.parametrize(
"builtin_constant,out_type,out_value",
Expand Down
20 changes: 20 additions & 0 deletions vyper/builtin_functions/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,12 @@
BaseType,
ByteArrayLike,
ByteArrayType,
EnumType,
StringType,
is_base_type,
is_bytes_m_type,
is_decimal_type,
is_enum_type,
is_integer_type,
)
from vyper.exceptions import (
Expand Down Expand Up @@ -326,6 +328,11 @@ def to_int(expr, arg, out_typ):
elif is_decimal_type(arg.typ):
arg = _fixed_to_int(arg, out_typ)

elif is_enum_type(arg.typ):
if not is_base_type(out_typ, "uint256"):
_FAIL(arg.typ, out_typ, expr)
arg = _int_to_int(arg, out_typ)

elif is_integer_type(arg.typ):
arg = _int_to_int(arg, out_typ)

Expand Down Expand Up @@ -455,6 +462,17 @@ def to_bytes(expr, arg, out_typ):
return IRnode.from_list(arg, typ=out_typ)


@_input_types("int")
def to_enum(expr, arg, out_typ):
if not is_base_type(arg.typ, "uint256"):
_FAIL(arg.typ, out_typ, expr)

if len(out_typ.members) < 256:
arg = int_clamp(arg, bits=len(out_typ.members), signed=False)

return IRnode.from_list(arg, typ=out_typ)


def convert(expr, context):
if len(expr.args) != 2:
raise StructureException("The convert function expects two parameters.", expr)
Expand All @@ -471,6 +489,8 @@ def convert(expr, context):
ret = to_bool(arg_ast, arg, out_typ)
elif is_base_type(out_typ, "address"):
ret = to_address(arg_ast, arg, out_typ)
elif isinstance(out_typ, EnumType):
ret = to_enum(arg_ast, arg, out_typ)
elif is_integer_type(out_typ):
ret = to_int(arg_ast, arg, out_typ)
elif is_bytes_m_type(out_typ):
Expand Down
6 changes: 6 additions & 0 deletions vyper/codegen/types/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,10 @@ def parse_decimal_info(typename: str) -> DecimalTypeInfo:
return DecimalTypeInfo(bits=168, decimals=10, is_signed=True)


def is_enum_type(t: "NodeType") -> bool:
return isinstance(t, EnumType)


def _basetype_to_abi_type(t: "BaseType") -> ABIType:
if is_integer_type(t):
info = t._int_info
Expand Down Expand Up @@ -237,6 +241,8 @@ def __repr__(self):
return f"enum {self.name}"

def __eq__(self, other):
if type(self) is not type(other):
return False
return self.name == other.name and self.members == other.members

@property
Expand Down

0 comments on commit 7b8b082

Please sign in to comment.