diff --git a/tests/parser/exceptions/test_invalid_type_exception.py b/tests/parser/exceptions/test_invalid_type_exception.py index 8cf7f9d6ff..450e6f4ae7 100644 --- a/tests/parser/exceptions/test_invalid_type_exception.py +++ b/tests/parser/exceptions/test_invalid_type_exception.py @@ -1,7 +1,5 @@ import pytest -from pytest import raises -from vyper import compiler from vyper.exceptions import InvalidType, UnknownType fail_list = [ @@ -25,9 +23,8 @@ @pytest.mark.parametrize("bad_code", fail_list) -def test_unknown_type_exception(bad_code): - with raises(UnknownType): - compiler.compile_code(bad_code) +def test_unknown_type_exception(bad_code, get_contract, assert_compile_failed): + assert_compile_failed(lambda: get_contract(bad_code), UnknownType) invalid_list = [ @@ -68,10 +65,13 @@ def mint(_to: address, _value: uint256): """ b: HashMap[(int128, decimal), int128] """, + # Address literal must be checksummed + """ +a: constant(address) = 0x3cd751e6b0078be393132286c442345e5dc49699 + """, ] @pytest.mark.parametrize("bad_code", invalid_list) -def test_invalid_type_exception(bad_code): - with raises(InvalidType): - compiler.compile_code(bad_code) +def test_invalid_type_exception(bad_code, get_contract, assert_compile_failed): + assert_compile_failed(lambda: get_contract(bad_code), InvalidType) diff --git a/vyper/semantics/validation/utils.py b/vyper/semantics/validation/utils.py index b140dcb1de..fe4c027db9 100644 --- a/vyper/semantics/validation/utils.py +++ b/vyper/semantics/validation/utils.py @@ -25,8 +25,11 @@ DynamicArrayDefinition, TupleDefinition, ) +from vyper.semantics.types.value.address import AddressDefinition from vyper.semantics.types.value.boolean import BoolDefinition +from vyper.semantics.types.value.bytes_fixed import Bytes20Definition # type: ignore from vyper.semantics.validation.levenshtein_utils import get_levenshtein_error_suggestions +from vyper.utils import checksum_encode def _validate_op(node, types_list, validation_fn_name): @@ -456,9 +459,16 @@ def validate_expected_type(node, expected_type): types_str = sorted(str(i) for i in given_types) given_str = f"{', '.join(types_str[:1])} or {types_str[-1]}" + suggestion_str = "" + if isinstance(expected_type[0], AddressDefinition) and isinstance( + given_types[0], Bytes20Definition + ): + suggestion_str = f" Did you mean {checksum_encode(node.value)}?" + # CMC 2022-02-14 maybe TypeMismatch would make more sense here raise InvalidType( - f"Expected {expected_str} but literal can only be cast as {given_str}", node + f"Expected {expected_str} but literal can only be cast as {given_str}.{suggestion_str}", + node, )