Skip to content

Commit

Permalink
fix: initcode codesize regression (#3450)
Browse files Browse the repository at this point in the history
This commit fixes a regression introduced in c202c4e. That commit's
message states that we rely on the dead code eliminator to prune unused
internal functions in the initcode, but the dead code eliminator does
not prune dead code in all cases; in particular, it misses nested
internal function calls and loops. This commit reintroduces the
reachability analysis in `vyper/codegen/module.py` as a stopgap until
the dead code eliminator is more robust.
  • Loading branch information
charles-cooper authored May 23, 2023
1 parent 71c8e55 commit 510125e
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 27 deletions.
83 changes: 68 additions & 15 deletions tests/compiler/asm/test_asm_optimizer.py
Original file line number Diff line number Diff line change
@@ -1,49 +1,102 @@
from vyper.compiler.phases import CompilerData
import pytest

from vyper.compiler.phases import CompilerData

def test_dead_code_eliminator():
code = """
codes = [
"""
s: uint256
@internal
def foo():
def ctor_only():
self.s = 1
@internal
def qux():
def runtime_only():
self.s = 2
@external
def bar():
self.runtime_only()
@external
def __init__():
self.ctor_only()
""",
# code with nested function in it
"""
s: uint256
@internal
def runtime_only():
self.s = 1
@internal
def foo():
self.runtime_only()
@internal
def ctor_only():
self.s += 1
@external
def bar():
self.foo()
@external
def __init__():
self.qux()
self.ctor_only()
""",
# code with loop in it, these are harder for dead code eliminator
"""
s: uint256
@internal
def ctor_only():
self.s = 1
@internal
def runtime_only():
for i in range(10):
self.s += 1
@external
def bar():
self.runtime_only()
@external
def __init__():
self.ctor_only()
""",
]


@pytest.mark.parametrize("code", codes)
def test_dead_code_eliminator(code):
c = CompilerData(code, no_optimize=True)
initcode_asm = [i for i in c.assembly if not isinstance(i, list)]
runtime_asm = c.assembly_runtime

foo_label = "_sym_internal_foo___"
qux_label = "_sym_internal_qux___"
ctor_only_label = "_sym_internal_ctor_only___"
runtime_only_label = "_sym_internal_runtime_only___"

# ctor_only reachable from unoptimized initcode, runtime_only not reachable.
assert ctor_only_label + "_deploy" in initcode_asm
assert runtime_only_label + "_deploy" not in initcode_asm

# all the labels should be in all the unoptimized asms
for s in (foo_label, qux_label):
assert s + "_deploy" in initcode_asm
# all labels should be in unoptimized runtime asm
for s in (ctor_only_label, runtime_only_label):
assert s + "_runtime" in runtime_asm

c = CompilerData(code, no_optimize=False)
initcode_asm = [i for i in c.assembly if not isinstance(i, list)]
runtime_asm = c.assembly_runtime

# qux should not be in runtime code
# ctor only label should not be in runtime code
for instr in runtime_asm:
if isinstance(instr, str):
assert not instr.startswith(qux_label), instr
assert not instr.startswith(ctor_only_label), instr

# foo should not be in initcode asm
# runtime only label should not be in initcode asm
for instr in initcode_asm:
if isinstance(instr, str):
assert not instr.startswith(foo_label), instr
assert not instr.startswith(runtime_only_label), instr
24 changes: 17 additions & 7 deletions tests/parser/functions/test_create_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from eth.codecs import abi
from hexbytes import HexBytes

import vyper.ir.compile_ir as compile_ir
from vyper.codegen.ir_node import IRnode
from vyper.utils import EIP_170_LIMIT, checksum_encode, keccak256


Expand Down Expand Up @@ -224,15 +226,23 @@ def test(code_ofst: uint256) -> address:
return create_from_blueprint(BLUEPRINT, code_offset=code_ofst)
"""

# use a bunch of JUMPDEST + STOP instructions as blueprint code
# (as any STOP instruction returns valid code, split up with
# jumpdests as optimization fence)
initcode_len = 100
f = get_contract_from_ir(["deploy", 0, ["seq"] + ["jumpdest", "stop"] * (initcode_len // 2), 0])
blueprint_code = w3.eth.get_code(f.address)
print(blueprint_code)

d = get_contract(deployer_code, f.address)
# deploy a blueprint contract whose contained initcode contains only
# zeroes (so no matter which offset, create_from_blueprint will
# return empty code)
ir = IRnode.from_list(["deploy", 0, ["seq"] + ["stop"] * initcode_len, 0])
bytecode, _ = compile_ir.assembly_to_evm(compile_ir.compile_to_assembly(ir, no_optimize=True))
# manually deploy the bytecode
c = w3.eth.contract(abi=[], bytecode=bytecode)
deploy_transaction = c.constructor()
tx_info = {"from": w3.eth.accounts[0], "value": 0, "gasPrice": 0}
tx_hash = deploy_transaction.transact(tx_info)
blueprint_address = w3.eth.get_transaction_receipt(tx_hash)["contractAddress"]
blueprint_code = w3.eth.get_code(blueprint_address)
print("BLUEPRINT CODE:", blueprint_code)

d = get_contract(deployer_code, blueprint_address)

# deploy with code_ofst=0 fine
d.test(0)
Expand Down
7 changes: 5 additions & 2 deletions vyper/codegen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,6 @@ def _runtime_ir(runtime_functions, global_ctx):
["label", "fallback", ["var_list"], fallback_ir],
]

# note: dead code eliminator will clean dead functions
runtime.extend(internal_functions_ir)

return runtime
Expand Down Expand Up @@ -178,10 +177,14 @@ def generate_ir_for_module(global_ctx: GlobalContext) -> tuple[IRnode, IRnode]:
# internal functions come after everything else
internal_functions = [f for f in runtime_functions if _is_internal(f)]
for f in internal_functions:
init_func_t = init_function._metadata["type"]
if f.name not in init_func_t.recursive_calls:
# unreachable
continue

func_ir = generate_ir_for_function(
f, global_ctx, skip_nonpayable_check=False, is_ctor_context=True
)
# note: we depend on dead code eliminator to clean dead function defs
deploy_code.append(func_ir)

else:
Expand Down
13 changes: 10 additions & 3 deletions vyper/ir/compile_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,6 +758,9 @@ def note_breakpoint(line_number_map, item, pos):
line_number_map["breakpoints"].add(item.lineno + 1)


_TERMINAL_OPS = ("JUMP", "RETURN", "REVERT", "STOP", "INVALID")


def _prune_unreachable_code(assembly):
# In converting IR to assembly we sometimes end up with unreachable
# instructions - POPing to clear the stack or STOPing execution at the
Expand All @@ -766,9 +769,13 @@ def _prune_unreachable_code(assembly):
# to avoid unnecessary bytecode bloat.
changed = False
i = 0
while i < len(assembly) - 1:
if assembly[i] in ("JUMP", "RETURN", "REVERT", "STOP") and not (
is_symbol(assembly[i + 1]) or assembly[i + 1] == "JUMPDEST"
while i < len(assembly) - 2:
instr = assembly[i]
if isinstance(instr, list):
instr = assembly[i][-1]

if assembly[i] in _TERMINAL_OPS and not (
is_symbol(assembly[i + 1]) and assembly[i + 2] in ("JUMPDEST", "BLANK")
):
changed = True
del assembly[i + 1]
Expand Down

0 comments on commit 510125e

Please sign in to comment.