diff --git a/tests/compiler/asm/test_asm_optimizer.py b/tests/compiler/asm/test_asm_optimizer.py index b82d568ff8..f4a245e168 100644 --- a/tests/compiler/asm/test_asm_optimizer.py +++ b/tests/compiler/asm/test_asm_optimizer.py @@ -1,49 +1,102 @@ -from vyper.compiler.phases import CompilerData +import pytest +from vyper.compiler.phases import CompilerData -def test_dead_code_eliminator(): - code = """ +codes = [ + """ s: uint256 @internal -def foo(): +def ctor_only(): self.s = 1 @internal -def qux(): +def runtime_only(): self.s = 2 +@external +def bar(): + self.runtime_only() + +@external +def __init__(): + self.ctor_only() + """, + # code with nested function in it + """ +s: uint256 + +@internal +def runtime_only(): + self.s = 1 + +@internal +def foo(): + self.runtime_only() + +@internal +def ctor_only(): + self.s += 1 + @external def bar(): self.foo() @external def __init__(): - self.qux() + self.ctor_only() + """, + # code with loop in it, these are harder for dead code eliminator """ +s: uint256 + +@internal +def ctor_only(): + self.s = 1 + +@internal +def runtime_only(): + for i in range(10): + self.s += 1 +@external +def bar(): + self.runtime_only() + +@external +def __init__(): + self.ctor_only() + """, +] + + +@pytest.mark.parametrize("code", codes) +def test_dead_code_eliminator(code): c = CompilerData(code, no_optimize=True) initcode_asm = [i for i in c.assembly if not isinstance(i, list)] runtime_asm = c.assembly_runtime - foo_label = "_sym_internal_foo___" - qux_label = "_sym_internal_qux___" + ctor_only_label = "_sym_internal_ctor_only___" + runtime_only_label = "_sym_internal_runtime_only___" + + # qux reachable from unoptimized initcode, foo not reachable. + assert ctor_only_label + "_deploy" in initcode_asm + assert runtime_only_label + "_deploy" not in initcode_asm - # all the labels should be in all the unoptimized asms - for s in (foo_label, qux_label): - assert s + "_deploy" in initcode_asm + # all labels should be in unoptimized runtime asm + for s in (ctor_only_label, runtime_only_label): assert s + "_runtime" in runtime_asm c = CompilerData(code, no_optimize=False) initcode_asm = [i for i in c.assembly if not isinstance(i, list)] runtime_asm = c.assembly_runtime - # qux should not be in runtime code + # ctor only label should not be in runtime code for instr in runtime_asm: if isinstance(instr, str): - assert not instr.startswith(qux_label), instr + assert not instr.startswith(ctor_only_label), instr - # foo should not be in initcode asm + # runtime only label should not be in initcode asm for instr in initcode_asm: if isinstance(instr, str): - assert not instr.startswith(foo_label), instr + assert not instr.startswith(runtime_only_label), instr diff --git a/tests/parser/functions/test_create_functions.py b/tests/parser/functions/test_create_functions.py index 857173df7e..64e0823146 100644 --- a/tests/parser/functions/test_create_functions.py +++ b/tests/parser/functions/test_create_functions.py @@ -3,6 +3,8 @@ from eth.codecs import abi from hexbytes import HexBytes +import vyper.ir.compile_ir as compile_ir +from vyper.codegen.ir_node import IRnode from vyper.utils import EIP_170_LIMIT, checksum_encode, keccak256 @@ -224,15 +226,23 @@ def test(code_ofst: uint256) -> address: return create_from_blueprint(BLUEPRINT, code_offset=code_ofst) """ - # use a bunch of JUMPDEST + STOP instructions as blueprint code - # (as any STOP instruction returns valid code, split up with - # jumpdests as optimization fence) initcode_len = 100 - f = get_contract_from_ir(["deploy", 0, ["seq"] + ["jumpdest", "stop"] * (initcode_len // 2), 0]) - blueprint_code = w3.eth.get_code(f.address) - print(blueprint_code) - d = get_contract(deployer_code, f.address) + # deploy a blueprint contract whose contained initcode contains only + # zeroes (so no matter which offset, create_from_blueprint will + # return empty code) + ir = IRnode.from_list(["deploy", 0, ["seq"] + ["stop"] * initcode_len, 0]) + bytecode, _ = compile_ir.assembly_to_evm(compile_ir.compile_to_assembly(ir, no_optimize=True)) + # manually deploy the bytecode + c = w3.eth.contract(abi=[], bytecode=bytecode) + deploy_transaction = c.constructor() + tx_info = {"from": w3.eth.accounts[0], "value": 0, "gasPrice": 0} + tx_hash = deploy_transaction.transact(tx_info) + blueprint_address = w3.eth.get_transaction_receipt(tx_hash)["contractAddress"] + blueprint_code = w3.eth.get_code(blueprint_address) + print("BLUEPRINT CODE:", blueprint_code) + + d = get_contract(deployer_code, blueprint_address) # deploy with code_ofst=0 fine d.test(0) diff --git a/vyper/codegen/module.py b/vyper/codegen/module.py index 9bc589d82f..2fece47a9e 100644 --- a/vyper/codegen/module.py +++ b/vyper/codegen/module.py @@ -123,7 +123,6 @@ def _runtime_ir(runtime_functions, global_ctx): ["label", "fallback", ["var_list"], fallback_ir], ] - # note: dead code eliminator will clean dead functions runtime.extend(internal_functions_ir) return runtime @@ -178,10 +177,14 @@ def generate_ir_for_module(global_ctx: GlobalContext) -> tuple[IRnode, IRnode]: # internal functions come after everything else internal_functions = [f for f in runtime_functions if _is_internal(f)] for f in internal_functions: + init_func_t = init_function._metadata["type"] + if f.name not in init_func_t.recursive_calls: + # unreachable + continue + func_ir = generate_ir_for_function( f, global_ctx, skip_nonpayable_check=False, is_ctor_context=True ) - # note: we depend on dead code eliminator to clean dead function defs deploy_code.append(func_ir) else: diff --git a/vyper/ir/compile_ir.py b/vyper/ir/compile_ir.py index 57ea4ca7e7..b2a58fa8c9 100644 --- a/vyper/ir/compile_ir.py +++ b/vyper/ir/compile_ir.py @@ -758,6 +758,9 @@ def note_breakpoint(line_number_map, item, pos): line_number_map["breakpoints"].add(item.lineno + 1) +_TERMINAL_OPS = ("JUMP", "RETURN", "REVERT", "STOP", "INVALID") + + def _prune_unreachable_code(assembly): # In converting IR to assembly we sometimes end up with unreachable # instructions - POPing to clear the stack or STOPing execution at the @@ -766,9 +769,13 @@ def _prune_unreachable_code(assembly): # to avoid unnecessary bytecode bloat. changed = False i = 0 - while i < len(assembly) - 1: - if assembly[i] in ("JUMP", "RETURN", "REVERT", "STOP") and not ( - is_symbol(assembly[i + 1]) or assembly[i + 1] == "JUMPDEST" + while i < len(assembly) - 2: + instr = assembly[i] + if isinstance(instr, list): + instr = assembly[i][-1] + + if assembly[i] in _TERMINAL_OPS and not ( + is_symbol(assembly[i + 1]) and assembly[i + 2] in ("JUMPDEST", "BLANK") ): changed = True del assembly[i + 1]