diff --git a/tests/parser/features/test_assignment.py b/tests/parser/features/test_assignment.py index 0dd63a0d09..29ec820484 100644 --- a/tests/parser/features/test_assignment.py +++ b/tests/parser/features/test_assignment.py @@ -39,7 +39,118 @@ def augmod(x: int128, y: int128) -> int128: print("Passed aug-assignment test") -def test_invalid_assign(assert_compile_failed, get_contract_with_gas_estimation): +@pytest.mark.parametrize( + "typ,in_val,out_val", + [ + ("uint256", 77, 123), + ("uint256[3]", [1, 2, 3], [4, 5, 6]), + ("DynArray[uint256, 3]", [1, 2, 3], [4, 5, 6]), + ("Bytes[5]", b"vyper", b"conda"), + ], +) +def test_internal_assign(get_contract_with_gas_estimation, typ, in_val, out_val): + code = f""" +@internal +def foo(x: {typ}) -> {typ}: + x = {out_val} + return x + +@external +def bar(x: {typ}) -> {typ}: + return self.foo(x) + """ + c = get_contract_with_gas_estimation(code) + + assert c.bar(in_val) == out_val + + +def test_internal_assign_struct(get_contract_with_gas_estimation): + code = """ +enum Bar: + BAD + BAK + BAZ + +struct Foo: + a: uint256 + b: DynArray[Bar, 3] + c: String[5] + +@internal +def foo(x: Foo) -> Foo: + x = Foo({a: 789, b: [Bar.BAZ, Bar.BAK, Bar.BAD], c: \"conda\"}) + return x + +@external +def bar(x: Foo) -> Foo: + return self.foo(x) + """ + c = get_contract_with_gas_estimation(code) + + assert c.bar((123, [1, 2, 4], "vyper")) == (789, [4, 2, 1], "conda") + + +def test_internal_assign_struct_member(get_contract_with_gas_estimation): + code = """ +enum Bar: + BAD + BAK + BAZ + +struct Foo: + a: uint256 + b: DynArray[Bar, 3] + c: String[5] + +@internal +def foo(x: Foo) -> Foo: + x.a = 789 + x.b.pop() + return x + +@external +def bar(x: Foo) -> Foo: + return self.foo(x) + """ + c = get_contract_with_gas_estimation(code) + + assert c.bar((123, [1, 2, 4], "vyper")) == (789, [1, 2], "vyper") + + +def test_internal_augassign(get_contract_with_gas_estimation): + code = """ +@internal +def foo(x: int128) -> int128: + x += 77 + return x + +@external +def bar(x: int128) -> int128: + return self.foo(x) + """ + c = get_contract_with_gas_estimation(code) + + assert c.bar(123) == 200 + + +@pytest.mark.parametrize("typ", ["DynArray[uint256, 3]", "uint256[3]"]) +def test_internal_augassign_arrays(get_contract_with_gas_estimation, typ): + code = f""" +@internal +def foo(x: {typ}) -> {typ}: + x[1] += 77 + return x + +@external +def bar(x: {typ}) -> {typ}: + return self.foo(x) + """ + c = get_contract_with_gas_estimation(code) + + assert c.bar([1, 2, 3]) == [1, 79, 3] + + +def test_invalid_external_assign(assert_compile_failed, get_contract_with_gas_estimation): code = """ @external def foo(x: int128): @@ -48,7 +159,7 @@ def foo(x: int128): assert_compile_failed(lambda: get_contract_with_gas_estimation(code), ImmutableViolation) -def test_invalid_augassign(assert_compile_failed, get_contract_with_gas_estimation): +def test_invalid_external_augassign(assert_compile_failed, get_contract_with_gas_estimation): code = """ @external def foo(x: int128): diff --git a/vyper/codegen/function_definitions/internal_function.py b/vyper/codegen/function_definitions/internal_function.py index 17479c4c07..228191e3ca 100644 --- a/vyper/codegen/function_definitions/internal_function.py +++ b/vyper/codegen/function_definitions/internal_function.py @@ -41,8 +41,8 @@ def generate_ir_for_internal_function( for arg in func_t.arguments: # allocate a variable for every arg, setting mutability - # to False to comply with vyper semantics, function arguments are immutable - context.new_variable(arg.name, arg.typ, is_mutable=False) + # to True to allow internal function arguments to be mutable + context.new_variable(arg.name, arg.typ, is_mutable=True) nonreentrant_pre, nonreentrant_post = get_nonreentrant_lock(func_t) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 790cee52d6..c99b582ad3 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -173,10 +173,13 @@ def __init__( self.func = fn_node._metadata["type"] self.annotation_visitor = StatementAnnotationVisitor(fn_node, namespace) self.expr_visitor = _LocalExpressionVisitor() + + # allow internal function params to be mutable + location, is_immutable = ( + (DataLocation.MEMORY, False) if self.func.is_internal else (DataLocation.CALLDATA, True) + ) for arg in self.func.arguments: - namespace[arg.name] = VarInfo( - arg.typ, location=DataLocation.CALLDATA, is_immutable=True - ) + namespace[arg.name] = VarInfo(arg.typ, location=location, is_immutable=is_immutable) for node in fn_node.body: self.visit(node)