Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize nonpayable checks #2172

Merged
merged 3 commits into from
Sep 30, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 148 additions & 0 deletions tests/parser/features/decorators/test_payable.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,151 @@ def test_payable_compile_fail(source, get_contract, assert_compile_failed):
assert_compile_failed(
lambda: get_contract(source), CallViolation,
)


nonpayable_code = [
"""
# single function, nonpayable
@external
def foo() -> bool:
return True
""",
"""
# multiple functions, one is payable
@external
def foo() -> bool:
return True

@payable
@external
def bar() -> bool:
return True
""",
"""
# multiple functions, nonpayable
@external
def foo() -> bool:
return True

@external
def bar() -> bool:
return True
""",
"""
# multiple functions, nonpayable (view)
@external
def foo() -> bool:
return True

@view
@external
def bar() -> bool:
return True
""",
"""
# payable init function
@external
@payable
def __init__():
a: int128 = 1

@external
def foo() -> bool:
return True
""",
"""
# payable default function
@external
@payable
def __default__():
a: int128 = 1

@external
def foo() -> bool:
return True
""",
]


@pytest.mark.parametrize("code", nonpayable_code)
def test_nonpayable_runtime_assertion(assert_tx_failed, get_contract, code):
c = get_contract(code)

c.foo(transact={"value": 0})
assert_tx_failed(lambda: c.foo(transact={"value": 10 ** 18}))


payable_code = [
"""
# single function, payable
@payable
@external
def foo() -> bool:
return True
""",
"""
# multiple functions, one is payable
@payable
@external
def foo() -> bool:
return True

@external
def bar() -> bool:
return True
""",
"""
# multiple functions, payable
@payable
@external
def foo() -> bool:
return True

@payable
@external
def bar() -> bool:
return True
""",
"""
# multiple functions, one nonpayable (view)
@payable
@external
def foo() -> bool:
return True

@view
@external
def bar() -> bool:
return True
""",
"""
# init function
@external
def __init__():
a: int128 = 1

@payable
@external
def foo() -> bool:
return True
""",
"""
# default function
@external
def __default__():
a: int128 = 1

@external
@payable
def foo() -> bool:
return True
""",
]


@pytest.mark.parametrize("code", payable_code)
def test_payable_runtime_assertion(get_contract, code):
c = get_contract(code)

c.foo(transact={"value": 10 ** 18})
c.foo(transact={"value": 0})
20 changes: 0 additions & 20 deletions tests/parser/features/test_gas.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
from vyper.parser import parser_utils
from vyper.parser.parser import parse_to_lll


def test_gas_call(get_contract_with_gas_estimation):
gas_call = """
@external
Expand All @@ -13,19 +9,3 @@ def foo() -> uint256:

assert c.foo(call={"gas": 50000}) < 50000
assert c.foo(call={"gas": 50000}) > 25000

print("Passed gas test")


def test_gas_estimate_repr():
code = """
x: int128

@external
def __init__():
self.x = 1
"""
parser_utils.LLLnode.repr_show_gas = True
out = parse_to_lll(code)
assert "35261" in str(out)[:28]
parser_utils.LLLnode.repr_show_gas = False
11 changes: 0 additions & 11 deletions tests/parser/functions/test_send.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,6 @@ def fop():
assert_tx_failed(lambda: c.fop(transact={}))


def test_payable_tx_fail(assert_tx_failed, get_contract, w3):
code = """
@external
def pay_me() -> bool:
return True
"""
c = get_contract(code)

assert_tx_failed(lambda: c.pay_me(transact={"value": w3.toWei(0.1, "ether")}))


def test_default_gas(get_contract, w3):
"""
Tests to verify that send to default function will send limited gas (2300),
Expand Down
8 changes: 5 additions & 3 deletions vyper/parser/function_definitions/parse_external_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,14 @@ def validate_external_function(


def parse_external_function(
code: ast.FunctionDef, sig: FunctionSignature, context: Context
code: ast.FunctionDef, sig: FunctionSignature, context: Context, is_contract_payable: bool
) -> LLLnode:
"""
Parse a external function (FuncDef), and produce full function body.

:param sig: the FuntionSignature
:param code: ast of function
:param is_contract_payable: bool - does this contract contain payable functions?
:return: full sig compare & function body
"""

Expand All @@ -81,8 +82,9 @@ def parse_external_function(
context.memory_allocator.increase_memory(sig.max_copy_size)
clampers.append(copier)

# Add asserts for payable and internal
if sig.mutability != "payable":
if is_contract_payable and sig.mutability != "payable":
# if the contract contains payable functions, but this is not one of them
# add an assertion that the value of the call is zero
clampers.append(["assert", ["iszero", "callvalue"]])

# Fill variable positions
Expand Down
6 changes: 4 additions & 2 deletions vyper/parser/function_definitions/parse_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def is_default_func(code):
return code.name == "__default__"


def parse_function(code, sigs, origcode, global_ctx, _vars=None):
def parse_function(code, sigs, origcode, global_ctx, is_contract_payable, _vars=None):
"""
Parses a function and produces LLL code for the function, includes:
- Signature method if statement
Expand Down Expand Up @@ -54,7 +54,9 @@ def parse_function(code, sigs, origcode, global_ctx, _vars=None):
if sig.internal:
o = parse_internal_function(code=code, sig=sig, context=context,)
else:
o = parse_external_function(code=code, sig=sig, context=context,)
o = parse_external_function(
code=code, sig=sig, context=context, is_contract_payable=is_contract_payable
)

o.context = context
o.total_gas = o.gas + calc_mem_gas(o.context.memory_allocator.get_next_memory_position())
Expand Down
57 changes: 52 additions & 5 deletions vyper/parser/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,14 +95,27 @@ def parse_external_interfaces(external_interfaces, global_ctx):


def parse_other_functions(
o, otherfuncs, sigs, external_interfaces, origcode, global_ctx, default_function
o,
otherfuncs,
sigs,
external_interfaces,
origcode,
global_ctx,
default_function,
is_contract_payable,
):
sub = ["seq", func_init_lll()]
add_gas = func_init_lll().gas

for _def in otherfuncs:
sub.append(
parse_function(_def, {**{"self": sigs}, **external_interfaces}, origcode, global_ctx)
parse_function(
_def,
{**{"self": sigs}, **external_interfaces},
origcode,
global_ctx,
is_contract_payable,
)
)
sub[-1].total_gas += add_gas
add_gas += 30
Expand All @@ -113,7 +126,11 @@ def parse_other_functions(
# Add fallback function
if default_function:
default_func = parse_function(
default_function[0], {**{"self": sigs}, **external_interfaces}, origcode, global_ctx,
default_function[0],
{**{"self": sigs}, **external_interfaces},
origcode,
global_ctx,
is_contract_payable,
)
fallback = default_func
else:
Expand Down Expand Up @@ -147,6 +164,20 @@ def parse_tree_to_lll(source_code: str, global_ctx: GlobalContext) -> Tuple[LLLn
otherfuncs = [
_def for _def in global_ctx._defs if not is_initializer(_def) and not is_default_func(_def)
]

# check if any functions in the contract are payable - if not, we do a single
# ASSERT CALLVALUE ISZERO at the start of the bytecode rather than at the start
# of each function
is_contract_payable = next(
(
True
for i in global_ctx._defs
if FunctionSignature.from_definition(i, custom_structs=global_ctx._structs).mutability
== "payable"
),
False,
)

sigs: dict = {}
external_interfaces: dict = {}
# Create the main statement
Expand All @@ -160,18 +191,34 @@ def parse_tree_to_lll(source_code: str, global_ctx: GlobalContext) -> Tuple[LLLn
o.append(init_func_init_lll())
o.append(
parse_function(
initfunc[0], {**{"self": sigs}, **external_interfaces}, source_code, global_ctx,
initfunc[0],
{**{"self": sigs}, **external_interfaces},
source_code,
global_ctx,
False,
)
)

# If there are regular functions...
if otherfuncs or defaultfunc:
o, runtime = parse_other_functions(
o, otherfuncs, sigs, external_interfaces, source_code, global_ctx, defaultfunc
o,
otherfuncs,
sigs,
external_interfaces,
source_code,
global_ctx,
defaultfunc,
is_contract_payable,
)
else:
runtime = o.copy()

if not is_contract_payable:
# if no functions in the contract are payable, assert that callvalue is
# zero at the beginning of the bytecode
runtime.insert(1, ["assert", ["iszero", "callvalue"]])

# Check if interface of contract is correct.
check_valid_contract_interface(global_ctx, sigs)

Expand Down