Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Size limit optimizations #2175

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions tests/parser/features/test_clampers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import pytest
from eth_utils import keccak

from vyper.opcodes import EVM_VERSIONS


def _make_tx(w3, address, signature, values):
# helper function to broadcast transactions that fail clamping check
Expand Down Expand Up @@ -61,51 +63,55 @@ def foo(s: int128) -> int128:
assert_tx_failed(lambda: _make_tx(w3, c.address, "foo(int128)", [value]))


@pytest.mark.parametrize("evm_version", list(EVM_VERSIONS))
@pytest.mark.parametrize("value", [0, 1])
def test_bool_clamper_passing(w3, get_contract, value):
def test_bool_clamper_passing(w3, get_contract, value, evm_version):
code = """
@external
def foo(s: bool) -> bool:
return s
"""

c = get_contract(code)
c = get_contract(code, evm_version=evm_version)
_make_tx(w3, c.address, "foo(bool)", [value])


@pytest.mark.parametrize("evm_version", list(EVM_VERSIONS))
@pytest.mark.parametrize("value", [2, 3, 4, 8, 16, 2 ** 256 - 1])
def test_bool_clamper_failing(w3, assert_tx_failed, get_contract, value):
def test_bool_clamper_failing(w3, assert_tx_failed, get_contract, value, evm_version):
code = """
@external
def foo(s: bool) -> bool:
return s
"""

c = get_contract(code)
c = get_contract(code, evm_version=evm_version)
assert_tx_failed(lambda: _make_tx(w3, c.address, "foo(bool)", [value]))


@pytest.mark.parametrize("evm_version", list(EVM_VERSIONS))
@pytest.mark.parametrize("value", [0, 1, 2 ** 160 - 1])
def test_address_clamper_passing(w3, get_contract, value):
def test_address_clamper_passing(w3, get_contract, value, evm_version):
code = """
@external
def foo(s: address) -> address:
return s
"""

c = get_contract(code)
c = get_contract(code, evm_version=evm_version)
_make_tx(w3, c.address, "foo(address)", [value])


@pytest.mark.parametrize("evm_version", list(EVM_VERSIONS))
@pytest.mark.parametrize("value", [2 ** 160, 2 ** 256 - 1])
def test_address_clamper_failing(w3, assert_tx_failed, get_contract, value):
def test_address_clamper_failing(w3, assert_tx_failed, get_contract, value, evm_version):
code = """
@external
def foo(s: address) -> address:
return s
"""

c = get_contract(code)
c = get_contract(code, evm_version=evm_version)
assert_tx_failed(lambda: _make_tx(w3, c.address, "foo(address)", [value]))


Expand Down
10 changes: 6 additions & 4 deletions tests/parser/syntax/test_self_balance.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest

from vyper import compiler
from vyper.opcodes import EVM_VERSIONS


Expand All @@ -17,12 +18,13 @@ def get_balance() -> uint256:
def __default__():
pass
"""
c = get_contract_with_gas_estimation(code, evm_version=evm_version)

opcodes = compiler.compile_code(code, ["opcodes"], evm_version=evm_version)["opcodes"]
if evm_version == "istanbul":
assert 0x47 in c._classic_contract.bytecode
assert "SELFBALANCE" in opcodes
else:
assert 0x47 not in c._classic_contract.bytecode
assert "SELFBALANCE" not in opcodes

c = get_contract_with_gas_estimation(code, evm_version=evm_version)
w3.eth.sendTransaction({"to": c.address, "value": 1337})

assert c.get_balance() == 1337
45 changes: 43 additions & 2 deletions vyper/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,16 @@ def has_cond_arg(node):
return node.value in ["if", "if_unchecked", "assert", "assert_reason"]


def optimize(node: LLLnode) -> LLLnode:
argz = [optimize(arg) for arg in node.args]
def optimize(lll_node: LLLnode) -> LLLnode:
lll_node = apply_general_optimizations(lll_node)
lll_node = filter_unused_sizelimits(lll_node)

return lll_node


def apply_general_optimizations(node: LLLnode) -> LLLnode:
# TODO refactor this into several functions
argz = [apply_general_optimizations(arg) for arg in node.args]
if node.value in arith and int_at(argz, 0) and int_at(argz, 1):
left, right = get_int_at(argz, 0), get_int_at(argz, 1)
# `node.value in arith` implies that `node.value` is a `str`
Expand Down Expand Up @@ -224,3 +232,36 @@ def optimize(node: LLLnode) -> LLLnode:
add_gas_estimate=node.add_gas_estimate,
valency=node.valency,
)


def filter_unused_sizelimits(lll_node: LLLnode) -> LLLnode:
# recursively search the LLL for mloads of the size limits, and then remove
# the initial mstore operations for size limits that are never referenced
expected_offsets = set(LOADED_LIMITS)
seen_offsets = _find_mload_offsets(lll_node, expected_offsets, set())
if expected_offsets == seen_offsets:
return lll_node

unseen_offsets = expected_offsets.difference(seen_offsets)
_remove_mstore(lll_node, unseen_offsets)

return lll_node


def _find_mload_offsets(lll_node: LLLnode, expected_offsets: set, seen_offsets: set) -> set:
for node in lll_node.args:
if node.value == "mload" and node.args[0].value in expected_offsets:
location = next(i for i in expected_offsets if i == node.args[0].value)
seen_offsets.add(location)
else:
seen_offsets.update(_find_mload_offsets(node, expected_offsets, seen_offsets))

return seen_offsets


def _remove_mstore(lll_node: LLLnode, offsets: set) -> None:
for node in lll_node.args.copy():
if node.value == "mstore" and node.args[0].value in offsets:
lll_node.args.remove(node)
else:
_remove_mstore(node, offsets)
19 changes: 11 additions & 8 deletions vyper/parser/arg_clamps.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import functools
import uuid

from vyper.opcodes import version_check
from vyper.parser.lll_node import LLLnode
from vyper.types.types import (
ByteArrayLike,
Expand Down Expand Up @@ -60,16 +61,18 @@ def make_arg_clamper(datapos, mempos, typ, is_init=False):
)
# Booleans: make sure they're zero or one
elif is_base_type(typ, "bool"):
return LLLnode.from_list(
["uclamplt", data_decl, 2], typ=typ, annotation="checking bool input",
)
if version_check(begin="constantinople"):
lll = ["assert", ["iszero", ["shr", 1, data_decl]]]
else:
lll = ["uclamplt", data_decl, 2]
return LLLnode.from_list(lll, typ=typ, annotation="checking bool input")
# Addresses: make sure they're in range
elif is_base_type(typ, "address"):
return LLLnode.from_list(
["uclamplt", data_decl, ["mload", MemoryPositions.ADDRSIZE]],
typ=typ,
annotation="checking address input",
)
if version_check(begin="constantinople"):
lll = ["assert", ["iszero", ["shr", 160, data_decl]]]
else:
lll = ["uclamplt", data_decl, ["mload", MemoryPositions.ADDRSIZE]]
return LLLnode.from_list(lll, typ=typ, annotation="checking address input")
# Bytes: make sure they have the right size
elif isinstance(typ, ByteArrayLike):
return LLLnode.from_list(
Expand Down