Skip to content

Commit

Permalink
fix: constructor context for internal functions (vyperlang#3388)
Browse files Browse the repository at this point in the history
this commit fixes two related issues with initcode generation:

- nested internal functions called from the constructor would cause a
  compiler panic
- internal functions called from the constructor would not read/write
  from the correct immutables space

the relevant examples reproducing each issue are in the tests. this
commit fixes the issue by

- not trying to traverse the call graph to figure out which internal
  functions to include in the initcode. instead, all internal functions
  are included, and we rely on the dead code eliminator to remove unused
  functions
- adding a "constructor" flag to the codegen, so we can distinguish
  between internal calls which are being generated to include in
  initcode or runtime code.
  • Loading branch information
charles-cooper authored May 15, 2023
1 parent 1c8349e commit c202c4e
Show file tree
Hide file tree
Showing 9 changed files with 203 additions and 19 deletions.
49 changes: 49 additions & 0 deletions tests/compiler/asm/test_asm_optimizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from vyper.compiler.phases import CompilerData


def test_dead_code_eliminator():
code = """
s: uint256
@internal
def foo():
self.s = 1
@internal
def qux():
self.s = 2
@external
def bar():
self.foo()
@external
def __init__():
self.qux()
"""

c = CompilerData(code, no_optimize=True)
initcode_asm = [i for i in c.assembly if not isinstance(i, list)]
runtime_asm = c.assembly_runtime

foo_label = "_sym_internal_foo___"
qux_label = "_sym_internal_qux___"

# all the labels should be in all the unoptimized asms
for s in (foo_label, qux_label):
assert s in initcode_asm
assert s in runtime_asm

c = CompilerData(code, no_optimize=False)
initcode_asm = [i for i in c.assembly if not isinstance(i, list)]
runtime_asm = c.assembly_runtime

# qux should not be in runtime code
for instr in runtime_asm:
if isinstance(instr, str):
assert not instr.startswith(qux_label), instr

# foo should not be in initcode asm
for instr in initcode_asm:
if isinstance(instr, str):
assert not instr.startswith(foo_label), instr
8 changes: 4 additions & 4 deletions tests/functional/semantics/analysis/test_for_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,30 +108,30 @@ def main():
for j in range(3):
x: uint256 = j
y: uint16 = j
""", # issue 3212
""", # GH issue 3212
"""
@external
def foo():
for i in [1]:
a:uint256 = i
b:uint16 = i
""", # issue 3374
""", # GH issue 3374
"""
@external
def foo():
for i in [1]:
for j in [1]:
a:uint256 = i
b:uint16 = i
""", # issue 3374
""", # GH issue 3374
"""
@external
def foo():
for i in [1,2,3]:
for j in [1,2,3]:
b:uint256 = j + i
c:uint16 = i
""", # issue 3374
""", # GH issue 3374
]


Expand Down
2 changes: 1 addition & 1 deletion tests/parser/features/test_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


def test_3034_verbatim(get_contract):
# test issue #3034 exactly
# test GH issue 3034 exactly
code = """
@view
@external
Expand Down
87 changes: 87 additions & 0 deletions tests/parser/features/test_immutable.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,3 +239,90 @@ def get_immutable() -> uint256:

c = get_contract(code, n)
assert c.get_immutable() == n + 2


# GH issue 3292
def test_internal_functions_called_by_ctor_location(get_contract):
code = """
d: uint256
x: immutable(uint256)
@external
def __init__():
self.d = 1
x = 2
self.a()
@external
def test() -> uint256:
return self.d
@internal
def a():
self.d = x
"""
c = get_contract(code)
assert c.test() == 2


# GH issue 3292, extended to nested internal functions
def test_nested_internal_function_immutables(get_contract):
code = """
d: public(uint256)
x: public(immutable(uint256))
@external
def __init__():
self.d = 1
x = 2
self.a()
@internal
def a():
self.b()
@internal
def b():
self.d = x
"""
c = get_contract(code)
assert c.x() == 2
assert c.d() == 2


# GH issue 3292, test immutable read from both ctor and runtime
def test_immutable_read_ctor_and_runtime(get_contract):
code = """
d: public(uint256)
x: public(immutable(uint256))
@external
def __init__():
self.d = 1
x = 2
self.a()
@internal
def a():
self.d = x
@external
def thrash():
self.d += 5
@external
def fix():
self.a()
"""
c = get_contract(code)
assert c.x() == 2
assert c.d() == 2

c.thrash(transact={})

assert c.x() == 2
assert c.d() == 2 + 5

c.fix(transact={})
assert c.x() == 2
assert c.d() == 2
26 changes: 26 additions & 0 deletions tests/parser/features/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,29 @@ def baz() -> uint8:

n = 256
assert_compile_failed(lambda: get_contract(code, n))


# GH issue 3206
def test_nested_internal_call_from_ctor(get_contract):
code = """
x: uint256
@external
def __init__():
self.a()
@internal
def a():
self.x += 1
self.b()
@internal
def b():
self.x += 2
@external
def test() -> uint256:
return self.x
"""
c = get_contract(code)
assert c.test() == 3
4 changes: 4 additions & 0 deletions vyper/codegen/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def __init__(
forvars=None,
constancy=Constancy.Mutable,
sig=None,
is_ctor_context=False,
):
# In-memory variables, in the form (name, memory location, type)
self.vars = vars_ or {}
Expand Down Expand Up @@ -92,6 +93,9 @@ def __init__(
self._internal_var_iter = 0
self._scope_id_iter = 0

# either the constructor, or called from the constructor
self.is_ctor_context = is_ctor_context

def is_constant(self):
return self.constancy is Constancy.Constant or self.in_assertion or self.in_range_expr

Expand Down
2 changes: 1 addition & 1 deletion vyper/codegen/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def parse_Name(self):

ofst = varinfo.position.offset

if self.context.sig.is_init_func:
if self.context.is_ctor_context:
mutable = True
location = IMMUTABLES
else:
Expand Down
10 changes: 9 additions & 1 deletion vyper/codegen/function_definitions/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def generate_ir_for_function(
sigs: Dict[str, Dict[str, FunctionSignature]], # all signatures in all namespaces
global_ctx: GlobalContext,
skip_nonpayable_check: bool,
is_ctor_context: bool = False,
) -> IRnode:
"""
Parse a function and produce IR code for the function, includes:
Expand Down Expand Up @@ -51,6 +52,7 @@ def generate_ir_for_function(
memory_allocator=memory_allocator,
constancy=Constancy.Constant if sig.mutability in ("view", "pure") else Constancy.Mutable,
sig=sig,
is_ctor_context=is_ctor_context,
)

if sig.internal:
Expand All @@ -65,13 +67,19 @@ def generate_ir_for_function(

frame_size = context.memory_allocator.size_of_mem - MemoryPositions.RESERVED_MEMORY

sig.set_frame_info(FrameInfo(allocate_start, frame_size, context.vars))
frame_info = FrameInfo(allocate_start, frame_size, context.vars)

if sig.frame_info is None:
sig.set_frame_info(frame_info)
else:
assert frame_info == sig.frame_info

if not sig.internal:
# adjust gas estimate to include cost of mem expansion
# frame_size of external function includes all private functions called
# (note: internal functions do not need to adjust gas estimate since
# it is already accounted for by the caller.)
assert sig.frame_info is not None # mypy hint
o.add_gas_estimate += calc_mem_gas(sig.frame_info.mem_used)

sig.gas_estimate = o.gas
Expand Down
34 changes: 22 additions & 12 deletions vyper/codegen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,19 +67,21 @@ def _runtime_ir(runtime_functions, all_sigs, global_ctx):

# create a map of the IR functions since they might live in both
# runtime and deploy code (if init function calls them)
internal_functions_map: Dict[str, IRnode] = {}
internal_functions_ir: list[IRnode] = []

for func_ast in internal_functions:
func_ir = generate_ir_for_function(func_ast, all_sigs, global_ctx, False)
internal_functions_map[func_ast.name] = func_ir
internal_functions_ir.append(func_ir)

# for some reason, somebody may want to deploy a contract with no
# external functions, or more likely, a "pure data" contract which
# contains immutables
if len(external_functions) == 0:
# TODO: prune internal functions in this case?
runtime = ["seq"] + list(internal_functions_map.values())
return runtime, internal_functions_map
# TODO: prune internal functions in this case? dead code eliminator
# might not eliminate them, since internal function jumpdest is at the
# first instruction in the contract.
runtime = ["seq"] + internal_functions_ir
return runtime

# note: if the user does not provide one, the default fallback function
# reverts anyway. so it does not hurt to batch the payable check.
Expand Down Expand Up @@ -125,10 +127,10 @@ def _runtime_ir(runtime_functions, all_sigs, global_ctx):
["label", "fallback", ["var_list"], fallback_ir],
]

# TODO: prune unreachable functions?
runtime.extend(internal_functions_map.values())
# note: dead code eliminator will clean dead functions
runtime.extend(internal_functions_ir)

return runtime, internal_functions_map
return runtime


# take a GlobalContext, which is basically
Expand Down Expand Up @@ -159,12 +161,15 @@ def generate_ir_for_module(global_ctx: GlobalContext) -> Tuple[IRnode, IRnode, F
runtime_functions = [f for f in function_defs if not _is_init_func(f)]
init_function = next((f for f in function_defs if _is_init_func(f)), None)

runtime, internal_functions = _runtime_ir(runtime_functions, all_sigs, global_ctx)
runtime = _runtime_ir(runtime_functions, all_sigs, global_ctx)

deploy_code: List[Any] = ["seq"]
immutables_len = global_ctx.immutable_section_bytes
if init_function:
init_func_ir = generate_ir_for_function(init_function, all_sigs, global_ctx, False)
# TODO might be cleaner to separate this into an _init_ir helper func
init_func_ir = generate_ir_for_function(
init_function, all_sigs, global_ctx, skip_nonpayable_check=False, is_ctor_context=True
)
deploy_code.append(init_func_ir)

# pass the amount of memory allocated for the init function
Expand All @@ -174,8 +179,13 @@ def generate_ir_for_module(global_ctx: GlobalContext) -> Tuple[IRnode, IRnode, F
deploy_code.append(["deploy", init_mem_used, runtime, immutables_len])

# internal functions come after everything else
for f in init_function._metadata["type"].called_functions:
deploy_code.append(internal_functions[f.name])
internal_functions = [f for f in runtime_functions if _is_internal(f)]
for f in internal_functions:
func_ir = generate_ir_for_function(
f, all_sigs, global_ctx, skip_nonpayable_check=False, is_ctor_context=True
)
# note: we depend on dead code eliminator to clean dead function defs
deploy_code.append(func_ir)

else:
if immutables_len != 0:
Expand Down

0 comments on commit c202c4e

Please sign in to comment.