From 510125e0fce389fcc2b9993691696eb0836345b6 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 23 May 2023 17:53:31 -0400 Subject: [PATCH] fix: initcode codesize regression (#3450) this commit fixes a regression in c202c4e3ec8. the commit message states that we rely on the dead code eliminator to prune unused internal functions in the initcode, but the dead code eliminator does not prune dead code in all cases, including nested internal functions and loops. this commit reintroduces the reachability analysis in `vyper/codegen/module.py` as a stopgap until the dead code eliminator is more robust. --- tests/compiler/asm/test_asm_optimizer.py | 83 +++++++++++++++---- .../parser/functions/test_create_functions.py | 24 ++++-- vyper/codegen/module.py | 7 +- vyper/ir/compile_ir.py | 13 ++- 4 files changed, 100 insertions(+), 27 deletions(-) diff --git a/tests/compiler/asm/test_asm_optimizer.py b/tests/compiler/asm/test_asm_optimizer.py index b82d568ff8..f4a245e168 100644 --- a/tests/compiler/asm/test_asm_optimizer.py +++ b/tests/compiler/asm/test_asm_optimizer.py @@ -1,49 +1,102 @@ -from vyper.compiler.phases import CompilerData +import pytest +from vyper.compiler.phases import CompilerData -def test_dead_code_eliminator(): - code = """ +codes = [ + """ s: uint256 @internal -def foo(): +def ctor_only(): self.s = 1 @internal -def qux(): +def runtime_only(): self.s = 2 +@external +def bar(): + self.runtime_only() + +@external +def __init__(): + self.ctor_only() + """, + # code with nested function in it + """ +s: uint256 + +@internal +def runtime_only(): + self.s = 1 + +@internal +def foo(): + self.runtime_only() + +@internal +def ctor_only(): + self.s += 1 + @external def bar(): self.foo() @external def __init__(): - self.qux() + self.ctor_only() + """, + # code with loop in it, these are harder for dead code eliminator """ +s: uint256 + +@internal +def ctor_only(): + self.s = 1 + +@internal +def runtime_only(): + for i in range(10): + self.s += 1 +@external +def 
bar(): + self.runtime_only() + +@external +def __init__(): + self.ctor_only() + """, +] + + +@pytest.mark.parametrize("code", codes) +def test_dead_code_eliminator(code): c = CompilerData(code, no_optimize=True) initcode_asm = [i for i in c.assembly if not isinstance(i, list)] runtime_asm = c.assembly_runtime - foo_label = "_sym_internal_foo___" - qux_label = "_sym_internal_qux___" + ctor_only_label = "_sym_internal_ctor_only___" + runtime_only_label = "_sym_internal_runtime_only___" + + # ctor_only reachable from unoptimized initcode, runtime_only not reachable. + assert ctor_only_label + "_deploy" in initcode_asm + assert runtime_only_label + "_deploy" not in initcode_asm - # all the labels should be in all the unoptimized asms - for s in (foo_label, qux_label): - assert s + "_deploy" in initcode_asm + # all labels should be in unoptimized runtime asm + for s in (ctor_only_label, runtime_only_label): assert s + "_runtime" in runtime_asm c = CompilerData(code, no_optimize=False) initcode_asm = [i for i in c.assembly if not isinstance(i, list)] runtime_asm = c.assembly_runtime - # qux should not be in runtime code + # ctor only label should not be in runtime code for instr in runtime_asm: if isinstance(instr, str): - assert not instr.startswith(qux_label), instr + assert not instr.startswith(ctor_only_label), instr - # foo should not be in initcode asm + # runtime only label should not be in initcode asm for instr in initcode_asm: if isinstance(instr, str): - assert not instr.startswith(foo_label), instr + assert not instr.startswith(runtime_only_label), instr diff --git a/tests/parser/functions/test_create_functions.py b/tests/parser/functions/test_create_functions.py index 857173df7e..64e0823146 100644 --- a/tests/parser/functions/test_create_functions.py +++ b/tests/parser/functions/test_create_functions.py @@ -3,6 +3,8 @@ from eth.codecs import abi from hexbytes import HexBytes +import vyper.ir.compile_ir as compile_ir +from vyper.codegen.ir_node import IRnode from
vyper.utils import EIP_170_LIMIT, checksum_encode, keccak256 @@ -224,15 +226,23 @@ def test(code_ofst: uint256) -> address: return create_from_blueprint(BLUEPRINT, code_offset=code_ofst) """ - # use a bunch of JUMPDEST + STOP instructions as blueprint code - # (as any STOP instruction returns valid code, split up with - # jumpdests as optimization fence) initcode_len = 100 - f = get_contract_from_ir(["deploy", 0, ["seq"] + ["jumpdest", "stop"] * (initcode_len // 2), 0]) - blueprint_code = w3.eth.get_code(f.address) - print(blueprint_code) - d = get_contract(deployer_code, f.address) + # deploy a blueprint contract whose contained initcode contains only + # zeroes (so no matter which offset, create_from_blueprint will + # return empty code) + ir = IRnode.from_list(["deploy", 0, ["seq"] + ["stop"] * initcode_len, 0]) + bytecode, _ = compile_ir.assembly_to_evm(compile_ir.compile_to_assembly(ir, no_optimize=True)) + # manually deploy the bytecode + c = w3.eth.contract(abi=[], bytecode=bytecode) + deploy_transaction = c.constructor() + tx_info = {"from": w3.eth.accounts[0], "value": 0, "gasPrice": 0} + tx_hash = deploy_transaction.transact(tx_info) + blueprint_address = w3.eth.get_transaction_receipt(tx_hash)["contractAddress"] + blueprint_code = w3.eth.get_code(blueprint_address) + print("BLUEPRINT CODE:", blueprint_code) + + d = get_contract(deployer_code, blueprint_address) # deploy with code_ofst=0 fine d.test(0) diff --git a/vyper/codegen/module.py b/vyper/codegen/module.py index 9bc589d82f..2fece47a9e 100644 --- a/vyper/codegen/module.py +++ b/vyper/codegen/module.py @@ -123,7 +123,6 @@ def _runtime_ir(runtime_functions, global_ctx): ["label", "fallback", ["var_list"], fallback_ir], ] - # note: dead code eliminator will clean dead functions runtime.extend(internal_functions_ir) return runtime @@ -178,10 +177,14 @@ def generate_ir_for_module(global_ctx: GlobalContext) -> tuple[IRnode, IRnode]: # internal functions come after everything else internal_functions = [f 
for f in runtime_functions if _is_internal(f)] for f in internal_functions: + init_func_t = init_function._metadata["type"] + if f.name not in init_func_t.recursive_calls: + # unreachable + continue + func_ir = generate_ir_for_function( f, global_ctx, skip_nonpayable_check=False, is_ctor_context=True ) - # note: we depend on dead code eliminator to clean dead function defs deploy_code.append(func_ir) else: diff --git a/vyper/ir/compile_ir.py b/vyper/ir/compile_ir.py index 57ea4ca7e7..b2a58fa8c9 100644 --- a/vyper/ir/compile_ir.py +++ b/vyper/ir/compile_ir.py @@ -758,6 +758,9 @@ def note_breakpoint(line_number_map, item, pos): line_number_map["breakpoints"].add(item.lineno + 1) +_TERMINAL_OPS = ("JUMP", "RETURN", "REVERT", "STOP", "INVALID") + + def _prune_unreachable_code(assembly): # In converting IR to assembly we sometimes end up with unreachable # instructions - POPing to clear the stack or STOPing execution at the @@ -766,9 +769,13 @@ def _prune_unreachable_code(assembly): # to avoid unnecessary bytecode bloat. changed = False i = 0 - while i < len(assembly) - 1: - if assembly[i] in ("JUMP", "RETURN", "REVERT", "STOP") and not ( - is_symbol(assembly[i + 1]) or assembly[i + 1] == "JUMPDEST" + while i < len(assembly) - 2: + instr = assembly[i] + if isinstance(instr, list): + instr = assembly[i][-1] + + if assembly[i] in _TERMINAL_OPS and not ( + is_symbol(assembly[i + 1]) and assembly[i + 2] in ("JUMPDEST", "BLANK") ): changed = True del assembly[i + 1]