Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[lang]: remove named reentrancy locks #3769

Merged
12 changes: 8 additions & 4 deletions docs/control-structures.rst
charles-cooper marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -100,29 +100,33 @@ Functions marked with ``@pure`` cannot call non-``pure`` functions.
Re-entrancy Locks
-----------------

The ``@nonreentrant(<key>)`` decorator places a lock on a function, and all functions with the same ``<key>`` value. An attempt by an external contract to call back into any of these functions causes the transaction to revert.
The ``@nonreentrant`` decorator places a global nonreentrancy lock on a function. An attempt by an external contract to call back into any other ``@nonreentrant`` function causes the transaction to revert.

.. code-block:: vyper

@external
@nonreentrant("lock")
@nonreentrant
def make_a_call(_addr: address):
# this function is protected from re-entrancy
...

You can put the ``@nonreentrant(<key>)`` decorator on a ``__default__`` function but we recommend against it because in most circumstances it will not work in a meaningful way.
You can put the ``@nonreentrant`` decorator on a ``__default__`` function but we recommend against it because in most circumstances it will not work in a meaningful way.
charles-cooper marked this conversation as resolved.
Show resolved Hide resolved

Nonreentrancy locks work by setting a specially allocated storage slot to a ``<locked>`` value on function entrance, and setting it to an ``<unlocked>`` value on function exit. On function entrance, if the storage slot is detected to be the ``<locked>`` value, execution reverts.

You cannot put the ``@nonreentrant`` decorator on a ``pure`` function. You can put it on a ``view`` function, but it only checks that the function is not in a callback (the storage slot is not in the ``<locked>`` state), as ``view`` functions can only read the state, not change it.

You can view where the nonreentrant key is physically laid out in storage by using ``vyper`` with the ``-f layout`` option (e.g., ``vyper -f layout foo.vy``). Unless it is overriden, the compiler will allocate it at slot ``0``.

.. note::
A mutable function can protect a ``view`` function from being called back into (which is useful for instance, if a ``view`` function would return inconsistent state during a mutable function), but a ``view`` function cannot protect itself from being called back into. Note that mutable functions can never be called from a ``view`` function because all external calls out from a ``view`` function are protected by the use of the ``STATICCALL`` opcode.

.. note::

A nonreentrant lock has an ``<unlocked>`` value of 3, and a ``<locked>`` value of 2. Nonzero values are used to take advantage of net gas metering - as of the Berlin hard fork, the net cost for utilizing a nonreentrant lock is 2300 gas. Prior to v0.3.4, the ``<unlocked>`` and ``<locked>`` values were 0 and 1, respectively.

.. note::
Prior to 0.4.0, nonreentrancy keys took a "key" argument for fine-grained nonreentrancy control. As of 0.4.0, only a global nonreentrancy lock is available.

The ``__default__`` Function
----------------------------
Expand Down Expand Up @@ -194,7 +198,7 @@ Decorator Description
``@pure`` Function does not read contract state or environment variables
``@view`` Function does not alter contract state
``@payable`` Function is able to receive Ether
``@nonreentrant(<unique_key>)`` Function cannot be called back into during an external call
``@nonreentrant`` Function cannot be called back into during an external call
=============================== ===========================================================

``if`` statements
Expand Down
139 changes: 109 additions & 30 deletions tests/functional/codegen/features/decorators/test_nonreentrant.py
charles-cooper marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,103 @@

from vyper.exceptions import FunctionDeclarationException


# TODO test functions in this module across all evm versions
# once we have cancun support.


def test_nonreentrant_decorator(get_contract, tx_failed):
calling_contract_code = """
interface SpecialContract:
malicious_code = """
interface ProtectedContract:
def protected_function(callback_address: address): nonpayable

@external
def do_callback():
ProtectedContract(msg.sender).protected_function(self)
"""

protected_code = """
interface Callbackable:
def do_callback(): nonpayable

@external
@nonreentrant
def protected_function(c: Callbackable):
c.do_callback()

# add a default function so we know the callback didn't fail for any reason
# besides nonreentrancy
@external
def __default__():
pass
"""
contract = get_contract(protected_code)
malicious = get_contract(malicious_code)

with tx_failed():
contract.protected_function(malicious.address)


def test_nonreentrant_view_function(get_contract, tx_failed):
malicious_code = """
interface ProtectedContract:
def protected_function(): nonpayable
def protected_view_fn() -> uint256: view

@external
def do_callback() -> uint256:
return ProtectedContract(msg.sender).protected_view_fn()
"""

protected_code = """
interface Callbackable:
def do_callback(): nonpayable

@external
@nonreentrant
def protected_function(c: Callbackable):
c.do_callback()

@external
@nonreentrant
@view
def protected_view_fn() -> uint256:
return 10

# add a default function so we know the callback didn't fail for any reason
# besides nonreentrancy
@external
def __default__():
pass
"""
contract = get_contract(protected_code)
malicious = get_contract(malicious_code)

with tx_failed():
contract.protected_function(malicious.address)


def test_multi_function_nonreentrant(get_contract, tx_failed):
malicious_code = """
interface ProtectedContract:
def unprotected_function(val: String[100], do_callback: bool): nonpayable
def protected_function(val: String[100], do_callback: bool): nonpayable
def special_value() -> String[100]: nonpayable

@external
def updated():
SpecialContract(msg.sender).unprotected_function('surprise!', False)
ProtectedContract(msg.sender).unprotected_function('surprise!', False)

@external
def updated_protected():
# This should fail.
SpecialContract(msg.sender).protected_function('surprise protected!', False)
ProtectedContract(msg.sender).protected_function('surprise protected!', False)
"""

reentrant_code = """
protected_code = """
interface Callback:
def updated(): nonpayable
def updated_protected(): nonpayable

interface Self:
def protected_function(val: String[100], do_callback: bool) -> uint256: nonpayable
def protected_function2(val: String[100], do_callback: bool) -> uint256: nonpayable
Expand All @@ -39,7 +112,7 @@ def set_callback(c: address):
self.callback = Callback(c)

@external
@nonreentrant('protect_special_value')
@nonreentrant
def protected_function(val: String[100], do_callback: bool) -> uint256:
self.special_value = val

Expand All @@ -50,7 +123,7 @@ def protected_function(val: String[100], do_callback: bool) -> uint256:
return 2

@external
@nonreentrant('protect_special_value')
@nonreentrant
def protected_function2(val: String[100], do_callback: bool) -> uint256:
self.special_value = val
if do_callback:
Expand All @@ -60,7 +133,7 @@ def protected_function2(val: String[100], do_callback: bool) -> uint256:
return 2

@external
@nonreentrant('protect_special_value')
@nonreentrant
def protected_function3(val: String[100], do_callback: bool) -> uint256:
charles-cooper marked this conversation as resolved.
Show resolved Hide resolved
self.special_value = val
if do_callback:
Expand All @@ -71,7 +144,8 @@ def protected_function3(val: String[100], do_callback: bool) -> uint256:


@external
@nonreentrant('protect_special_value')
@nonreentrant
@view
def protected_view_fn() -> String[100]:
return self.special_value

Expand All @@ -81,37 +155,42 @@ def unprotected_function(val: String[100], do_callback: bool):

if do_callback:
self.callback.updated()
"""

reentrant_contract = get_contract(reentrant_code)
calling_contract = get_contract(calling_contract_code)
# add a default function so we know the callback didn't fail for any reason
# besides nonreentrancy
@external
def __default__():
pass
"""
contract = get_contract(protected_code)
malicious = get_contract(malicious_code)

reentrant_contract.set_callback(calling_contract.address, transact={})
assert reentrant_contract.callback() == calling_contract.address
contract.set_callback(malicious.address, transact={})
assert contract.callback() == malicious.address

# Test unprotected function.
reentrant_contract.unprotected_function("some value", True, transact={})
assert reentrant_contract.special_value() == "surprise!"
contract.unprotected_function("some value", True, transact={})
assert contract.special_value() == "surprise!"

# Test protected function.
reentrant_contract.protected_function("some value", False, transact={})
assert reentrant_contract.special_value() == "some value"
assert reentrant_contract.protected_view_fn() == "some value"
contract.protected_function("some value", False, transact={})
assert contract.special_value() == "some value"
assert contract.protected_view_fn() == "some value"

with tx_failed():
reentrant_contract.protected_function("zzz value", True, transact={})
contract.protected_function("zzz value", True, transact={})

reentrant_contract.protected_function2("another value", False, transact={})
assert reentrant_contract.special_value() == "another value"
contract.protected_function2("another value", False, transact={})
assert contract.special_value() == "another value"

with tx_failed():
reentrant_contract.protected_function2("zzz value", True, transact={})
contract.protected_function2("zzz value", True, transact={})

reentrant_contract.protected_function3("another value", False, transact={})
assert reentrant_contract.special_value() == "another value"
contract.protected_function3("another value", False, transact={})
assert contract.special_value() == "another value"

with tx_failed():
reentrant_contract.protected_function3("zzz value", True, transact={})
contract.protected_function3("zzz value", True, transact={})


def test_nonreentrant_decorator_for_default(w3, get_contract, tx_failed):
Expand Down Expand Up @@ -145,7 +224,7 @@ def set_callback(c: address):

@external
@payable
@nonreentrant("lock")
@nonreentrant
def protected_function(val: String[100], do_callback: bool) -> uint256:
self.special_value = val
_amount: uint256 = msg.value
Expand All @@ -169,7 +248,7 @@ def unprotected_function(val: String[100], do_callback: bool):

@external
@payable
@nonreentrant("lock")
@nonreentrant
def __default__():
pass
"""
Expand Down Expand Up @@ -209,7 +288,7 @@ def test_disallow_on_init_function(get_contract):
code = """

@external
@nonreentrant("lock")
@nonreentrant
def __init__():
foo: uint256 = 0
"""
Expand Down
31 changes: 0 additions & 31 deletions tests/functional/syntax/exceptions/test_structure_exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,42 +44,11 @@ def foo() -> int128:
return x.codesize()
""",
"""
@external
@nonreentrant("B")
@nonreentrant("C")
def double_nonreentrant():
pass
""",
"""
struct X:
int128[5]: int128[7]
""",
"""
@external
@nonreentrant(" ")
def invalid_nonreentrant_key():
pass
""",
"""
@external
@nonreentrant("")
def invalid_nonreentrant_key():
pass
""",
"""
@external
@nonreentrant("123")
def invalid_nonreentrant_key():
pass
""",
"""
@external
@nonreentrant("!123abcd")
def invalid_nonreentrant_key():
pass
""",
"""
@external
def foo():
true: int128 = 3
""",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,23 @@
"""
@external
@pure
@nonreentrant('lock')
@nonreentrant
def nonreentrant_foo() -> uint256:
return 1
""",
"""
@external
@nonreentrant
@nonreentrant
def nonreentrant_foo() -> uint256:
return 1
""",
"""
@external
@nonreentrant("foo")
def nonreentrant_foo() -> uint256:
return 1
""",
]


Expand Down
Loading
Loading