Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix[lang]: use folded node for typechecking #4365

Merged
merged 12 commits into from
Nov 23, 2024
16 changes: 16 additions & 0 deletions tests/functional/codegen/features/test_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -1254,6 +1254,22 @@ def foo():
assert log.topics == [event_id, topic1, topic2, topic3]


valid_list = [
"""
tserg marked this conversation as resolved.
Show resolved Hide resolved
topic: constant(bytes32) = 0x1212121212121210212801291212121212121210121212121212121212121212

@external
def foo():
raw_log([[topic]][0], b'')
"""
]


@pytest.mark.parametrize("code", valid_list)
def test_raw_log_pass(code):
assert compile_code(code) is not None


fail_list = [
(
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,13 @@ def foo():
def foo():
a: bytes32 = keccak256("ѓtest")
""",
"""
tserg marked this conversation as resolved.
Show resolved Hide resolved
BAR: constant(uint16) = 256

@external
def foo():
a: uint8 = convert(BAR, uint8)
charles-cooper marked this conversation as resolved.
Show resolved Hide resolved
""",
]


Expand Down
15 changes: 15 additions & 0 deletions tests/functional/syntax/test_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,21 @@ def foo(inp: Bytes[10]) -> Bytes[4]:
def foo() -> Bytes[10]:
return slice(b"badmintonzzz", 1, 10)
""",
"""
tserg marked this conversation as resolved.
Show resolved Hide resolved
@external
def foo():
x: Bytes[32] = slice(msg.data, 0, 31 + 1)
""",
"""
@external
def foo(a: address):
x: Bytes[32] = slice(a.code, 0, 31 + 1)
""",
"""
@external
def foo(inp: Bytes[5], start: uint256) -> Bytes[3]:
return slice(inp, 0, 1 + 1)
""",
]


Expand Down
2 changes: 1 addition & 1 deletion vyper/builtins/_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ def to_flag(expr, arg, out_typ):
def convert(expr, context):
assert len(expr.args) == 2, "bad typecheck: convert"

arg_ast = expr.args[0]
arg_ast = expr.args[0].reduced()
charles-cooper marked this conversation as resolved.
Show resolved Hide resolved
arg = Expr(arg_ast, context).ir_node
original_arg = arg

Expand Down
7 changes: 4 additions & 3 deletions vyper/builtins/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ def fetch_call_return(self, node):

arg = node.args[0]
start_expr = node.args[1]
length_expr = node.args[2]
length_expr = node.args[2].reduced()
charles-cooper marked this conversation as resolved.
Show resolved Hide resolved

# CMC 2022-03-22 NOTE slight code duplication with semantics/analysis/local
is_adhoc_slice = arg.get("attr") == "code" or (
Expand Down Expand Up @@ -1257,7 +1257,8 @@ def fetch_call_return(self, node):
def infer_arg_types(self, node, expected_return_typ=None):
self._validate_arg_types(node)

if not isinstance(node.args[0], vy_ast.List) or len(node.args[0].elements) > 4:
arg = node.args[0].reduced()
if not isinstance(arg, vy_ast.List) or len(arg.elements) > 4:
raise InvalidType("Expecting a list of 0-4 topics as first argument", node.args[0])

# return a concrete type for `data`
Expand All @@ -1269,7 +1270,7 @@ def infer_arg_types(self, node, expected_return_typ=None):
def build_IR(self, expr, args, kwargs, context):
context.check_is_not_constant(f"use {self._id}", expr)

topics_length = len(expr.args[0].elements)
topics_length = len(expr.args[0].reduced().elements)
topics = args[0].args
topics = [unwrap_location(topic) for topic in topics]

Expand Down
4 changes: 2 additions & 2 deletions vyper/semantics/analysis/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def _validate_address_code(node: vy_ast.Attribute, value_type: VyperType) -> Non
parent = node.get_ancestor()
if isinstance(parent, vy_ast.Call):
ok_func = isinstance(parent.func, vy_ast.Name) and parent.func.id == "slice"
ok_args = len(parent.args) == 3 and isinstance(parent.args[2], vy_ast.Int)
ok_args = len(parent.args) == 3 and isinstance(parent.args[2].reduced(), vy_ast.Int)
if ok_func and ok_args:
return

Expand All @@ -154,7 +154,7 @@ def _validate_msg_data_attribute(node: vy_ast.Attribute) -> None:
"msg.data is only allowed inside of the slice, len or raw_call functions", node
)
if parent.get("func.id") == "slice":
ok_args = len(parent.args) == 3 and isinstance(parent.args[2], vy_ast.Int)
ok_args = len(parent.args) == 3 and isinstance(parent.args[2].reduced(), vy_ast.Int)
if not ok_args:
raise StructureException(
"slice(msg.data) must use a compile-time constant for length argument", parent
Expand Down
Loading