Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: let params of internal functions be mutable #3473

Merged
merged 7 commits into from
Jun 14, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 34 additions & 2 deletions tests/parser/features/test_assignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,39 @@ def augmod(x: int128, y: int128) -> int128:
print("Passed aug-assignment test")


def test_invalid_assign(assert_compile_failed, get_contract_with_gas_estimation):
def test_internal_assign(get_contract_with_gas_estimation):
code = """
@internal
def foo(x: int128) -> int128:
x = 77
return x

@external
def bar(x: int128) -> int128:
return self.foo(x)
"""
c = get_contract_with_gas_estimation(code)

assert c.bar(123) == 77


def test_internal_augassign(get_contract_with_gas_estimation):
code = """
@internal
def foo(x: int128) -> int128:
x += 77
return x

@external
def bar(x: int128) -> int128:
return self.foo(x)
"""
c = get_contract_with_gas_estimation(code)

assert c.bar(123) == 200


def test_invalid_external_assign(assert_compile_failed, get_contract_with_gas_estimation):
code = """
@external
def foo(x: int128):
Expand All @@ -48,7 +80,7 @@ def foo(x: int128):
assert_compile_failed(lambda: get_contract_with_gas_estimation(code), ImmutableViolation)


def test_invalid_augassign(assert_compile_failed, get_contract_with_gas_estimation):
def test_invalid_external_augassign(assert_compile_failed, get_contract_with_gas_estimation):
code = """
@external
def foo(x: int128):
Expand Down
4 changes: 2 additions & 2 deletions vyper/codegen/function_definitions/internal_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ def generate_ir_for_internal_function(

for arg in func_t.arguments:
# allocate a variable for every arg, setting mutability
# to False to comply with vyper semantics, function arguments are immutable
context.new_variable(arg.name, arg.typ, is_mutable=False)
# to True to allow internal function arguments to be mutable
charles-cooper marked this conversation as resolved.
Show resolved Hide resolved
context.new_variable(arg.name, arg.typ, is_mutable=True)

nonreentrant_pre, nonreentrant_post = get_nonreentrant_lock(func_t)

Expand Down
6 changes: 5 additions & 1 deletion vyper/semantics/analysis/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,13 @@ def __init__(
self.func = fn_node._metadata["type"]
self.annotation_visitor = StatementAnnotationVisitor(fn_node, namespace)
self.expr_visitor = _LocalExpressionVisitor()

# allow internal function params to be mutable
location = DataLocation.MEMORY if self.func.is_internal else DataLocation.CALLDATA
is_immutable = False if self.func.is_internal else True
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i don't really like repeating branches in general. here i think we can get rid of it with either squashing all into a single ternary like so

Suggested change
# allow internal function params to be mutable
location = DataLocation.MEMORY if self.func.is_internal else DataLocation.CALLDATA
is_immutable = False if self.func.is_internal else True
# allow internal function params to be mutable
location, is_immutable = (DataLocation.MEMORY, False) if self.func.is_internal else (DataLocation.CALLDATA, True)

or two assignments per branch

Suggested change
# allow internal function params to be mutable
location = DataLocation.MEMORY if self.func.is_internal else DataLocation.CALLDATA
is_immutable = False if self.func.is_internal else True
if self.func.is_internal:
location = DataLocation.MEMORY
is_immutable = False
else:
location = DataLocation.CALLDATA
is_immutable = True

for arg in self.func.arguments:
namespace[arg.name] = VarInfo(
arg.typ, location=DataLocation.CALLDATA, is_immutable=True
arg.typ, location=location, is_immutable=is_immutable
)

for node in fn_node.body:
Expand Down