Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: let params of internal functions be mutable #3473

Merged
merged 7 commits into from
Jun 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 113 additions & 2 deletions tests/parser/features/test_assignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,118 @@ def augmod(x: int128, y: int128) -> int128:
print("Passed aug-assignment test")


def test_invalid_assign(assert_compile_failed, get_contract_with_gas_estimation):
@pytest.mark.parametrize(
"typ,in_val,out_val",
[
("uint256", 77, 123),
("uint256[3]", [1, 2, 3], [4, 5, 6]),
("DynArray[uint256, 3]", [1, 2, 3], [4, 5, 6]),
("Bytes[5]", b"vyper", b"conda"),
],
)
def test_internal_assign(get_contract_with_gas_estimation, typ, in_val, out_val):
code = f"""
@internal
def foo(x: {typ}) -> {typ}:
x = {out_val}
return x

@external
def bar(x: {typ}) -> {typ}:
return self.foo(x)
"""
c = get_contract_with_gas_estimation(code)

assert c.bar(in_val) == out_val


def test_internal_assign_struct(get_contract_with_gas_estimation):
code = """
enum Bar:
BAD
BAK
BAZ

struct Foo:
a: uint256
b: DynArray[Bar, 3]
c: String[5]

@internal
def foo(x: Foo) -> Foo:
x = Foo({a: 789, b: [Bar.BAZ, Bar.BAK, Bar.BAD], c: \"conda\"})
return x

@external
def bar(x: Foo) -> Foo:
return self.foo(x)
"""
c = get_contract_with_gas_estimation(code)

assert c.bar((123, [1, 2, 4], "vyper")) == (789, [4, 2, 1], "conda")


def test_internal_assign_struct_member(get_contract_with_gas_estimation):
code = """
enum Bar:
BAD
BAK
BAZ

struct Foo:
a: uint256
b: DynArray[Bar, 3]
c: String[5]

@internal
def foo(x: Foo) -> Foo:
x.a = 789
x.b.pop()
return x

@external
def bar(x: Foo) -> Foo:
return self.foo(x)
"""
c = get_contract_with_gas_estimation(code)

assert c.bar((123, [1, 2, 4], "vyper")) == (789, [1, 2], "vyper")


def test_internal_augassign(get_contract_with_gas_estimation):
code = """
@internal
def foo(x: int128) -> int128:
x += 77
return x

@external
def bar(x: int128) -> int128:
return self.foo(x)
"""
c = get_contract_with_gas_estimation(code)

assert c.bar(123) == 200


@pytest.mark.parametrize("typ", ["DynArray[uint256, 3]", "uint256[3]"])
def test_internal_augassign_arrays(get_contract_with_gas_estimation, typ):
code = f"""
@internal
def foo(x: {typ}) -> {typ}:
x[1] += 77
return x

@external
def bar(x: {typ}) -> {typ}:
return self.foo(x)
"""
c = get_contract_with_gas_estimation(code)

assert c.bar([1, 2, 3]) == [1, 79, 3]


def test_invalid_external_assign(assert_compile_failed, get_contract_with_gas_estimation):
code = """
@external
def foo(x: int128):
Expand All @@ -48,7 +159,7 @@ def foo(x: int128):
assert_compile_failed(lambda: get_contract_with_gas_estimation(code), ImmutableViolation)


def test_invalid_augassign(assert_compile_failed, get_contract_with_gas_estimation):
def test_invalid_external_augassign(assert_compile_failed, get_contract_with_gas_estimation):
code = """
@external
def foo(x: int128):
Expand Down
4 changes: 2 additions & 2 deletions vyper/codegen/function_definitions/internal_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ def generate_ir_for_internal_function(

for arg in func_t.arguments:
# allocate a variable for every arg, setting mutability
# to False to comply with vyper semantics, function arguments are immutable
context.new_variable(arg.name, arg.typ, is_mutable=False)
# to True to allow internal function arguments to be mutable
charles-cooper marked this conversation as resolved.
Show resolved Hide resolved
context.new_variable(arg.name, arg.typ, is_mutable=True)

nonreentrant_pre, nonreentrant_post = get_nonreentrant_lock(func_t)

Expand Down
9 changes: 6 additions & 3 deletions vyper/semantics/analysis/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,10 +173,13 @@ def __init__(
self.func = fn_node._metadata["type"]
self.annotation_visitor = StatementAnnotationVisitor(fn_node, namespace)
self.expr_visitor = _LocalExpressionVisitor()

# allow internal function params to be mutable
location, is_immutable = (
(DataLocation.MEMORY, False) if self.func.is_internal else (DataLocation.CALLDATA, True)
)
for arg in self.func.arguments:
namespace[arg.name] = VarInfo(
arg.typ, location=DataLocation.CALLDATA, is_immutable=True
)
namespace[arg.name] = VarInfo(arg.typ, location=location, is_immutable=is_immutable)

for node in fn_node.body:
self.visit(node)
Expand Down