Skip to content

Commit

Permalink
transient storage keyword (#3373)
Browse files Browse the repository at this point in the history
experimentally add support for transient storage via a new `transient`
keyword, which works like `immutable` or `constant`, ex.:

```vyper
my_transient_variable: transient(uint256)
```

this feature is considered experimental until py-evm adds support
(giving us the ability to actually test it). so this commit leaves the
default evm version as "shanghai" for now. it blocks the feature on
pre-cancun EVM versions, so users can't use it by accident - the only
way to use it is to explicitly enable it via `--evm-version=cancun`.
  • Loading branch information
charles-cooper authored May 18, 2023
1 parent 5b9bca2 commit ed0a654
Show file tree
Hide file tree
Showing 16 changed files with 118 additions and 16 deletions.
7 changes: 5 additions & 2 deletions tests/compiler/test_opcodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,14 @@ def test_version_check(evm_version):

def test_get_opcodes(evm_version):
ops = opcodes.get_opcodes()
if evm_version in ("paris", "berlin", "shanghai"):
if evm_version in ("paris", "berlin", "shanghai", "cancun"):
assert "CHAINID" in ops
assert ops["SLOAD"][-1] == 2100
if evm_version in ("shanghai",):
if evm_version in ("shanghai", "cancun"):
assert "PUSH0" in ops
if evm_version in ("cancun",):
assert "TLOAD" in ops
assert "TSTORE" in ops
elif evm_version == "istanbul":
assert "CHAINID" in ops
assert ops["SLOAD"][-1] == 800
Expand Down
1 change: 1 addition & 0 deletions tests/parser/ast_utils/test_ast_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def test_basic_ast():
"is_constant": False,
"is_immutable": False,
"is_public": False,
"is_transient": False,
}


Expand Down
2 changes: 2 additions & 0 deletions tests/parser/features/decorators/test_nonreentrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from vyper.exceptions import FunctionDeclarationException


# TODO test functions in this module across all evm versions
# once we have cancun support.
def test_nonreentrant_decorator(get_contract, assert_tx_failed):
calling_contract_code = """
interface SpecialContract:
Expand Down
61 changes: 61 additions & 0 deletions tests/parser/features/test_transient.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import pytest

from vyper.compiler import compile_code
from vyper.evm.opcodes import EVM_VERSIONS
from vyper.exceptions import StructureException

post_cancun = {k: v for k, v in EVM_VERSIONS.items() if v >= EVM_VERSIONS["cancun"]}


@pytest.mark.parametrize("evm_version", list(EVM_VERSIONS.keys()))
def test_transient_blocked(evm_version):
# test transient is blocked on pre-cancun and compiles post-cancun
code = """
my_map: transient(HashMap[address, uint256])
"""
if EVM_VERSIONS[evm_version] >= EVM_VERSIONS["cancun"]:
assert compile_code(code, evm_version=evm_version) is not None
else:
with pytest.raises(StructureException):
compile_code(code, evm_version=evm_version)


@pytest.mark.parametrize("evm_version", list(post_cancun.keys()))
def test_transient_compiles(evm_version):
# test transient keyword at least generates TLOAD/TSTORE opcodes
getter_code = """
my_map: public(transient(HashMap[address, uint256]))
"""
t = compile_code(getter_code, evm_version=evm_version, output_formats=["opcodes_runtime"])
t = t["opcodes_runtime"].split(" ")

assert "TLOAD" in t
assert "TSTORE" not in t

setter_code = """
my_map: transient(HashMap[address, uint256])
@external
def setter(k: address, v: uint256):
self.my_map[k] = v
"""
t = compile_code(setter_code, evm_version=evm_version, output_formats=["opcodes_runtime"])
t = t["opcodes_runtime"].split(" ")

assert "TLOAD" not in t
assert "TSTORE" in t

getter_setter_code = """
my_map: public(transient(HashMap[address, uint256]))
@external
def setter(k: address, v: uint256):
self.my_map[k] = v
"""
t = compile_code(
getter_setter_code, evm_version=evm_version, output_formats=["opcodes_runtime"]
)
t = t["opcodes_runtime"].split(" ")

assert "TLOAD" in t
assert "TSTORE" in t
18 changes: 14 additions & 4 deletions vyper/ast/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1344,14 +1344,23 @@ class VariableDecl(VyperNode):
If true, indicates that the variable is an immutable variable.
"""

__slots__ = ("target", "annotation", "value", "is_constant", "is_public", "is_immutable")
__slots__ = (
"target",
"annotation",
"value",
"is_constant",
"is_public",
"is_immutable",
"is_transient",
)

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

self.is_constant = False
self.is_public = False
self.is_immutable = False
self.is_transient = False

def _check_args(annotation, call_name):
# do the same thing as `validate_call_args`
Expand All @@ -1369,9 +1378,10 @@ def _check_args(annotation, call_name):
# unwrap one layer
self.annotation = self.annotation.args[0]

if self.annotation.get("func.id") in ("immutable", "constant"):
_check_args(self.annotation, self.annotation.func.id)
setattr(self, f"is_{self.annotation.func.id}", True)
func_id = self.annotation.get("func.id")
if func_id in ("immutable", "constant", "transient"):
_check_args(self.annotation, func_id)
setattr(self, f"is_{func_id}", True)
# unwrap one layer
self.annotation = self.annotation.args[0]

Expand Down
3 changes: 2 additions & 1 deletion vyper/cli/vyper_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ def _parse_args(argv):
)
parser.add_argument(
"--evm-version",
help=f"Select desired EVM version (default {DEFAULT_EVM_VERSION})",
help=f"Select desired EVM version (default {DEFAULT_EVM_VERSION}). "
" note: cancun support is EXPERIMENTAL",
choices=list(EVM_VERSIONS),
default=DEFAULT_EVM_VERSION,
dest="evm_version",
Expand Down
1 change: 1 addition & 0 deletions vyper/codegen/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class VariableRecord:
defined_at: Any = None
is_internal: bool = False
is_immutable: bool = False
is_transient: bool = False
data_offset: Optional[int] = None

def __hash__(self):
Expand Down
8 changes: 4 additions & 4 deletions vyper/codegen/core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from vyper import ast as vy_ast
from vyper.codegen.ir_node import Encoding, IRnode
from vyper.evm.address_space import CALLDATA, DATA, IMMUTABLES, MEMORY, STORAGE
from vyper.evm.address_space import CALLDATA, DATA, IMMUTABLES, MEMORY, STORAGE, TRANSIENT
from vyper.evm.opcodes import version_check
from vyper.exceptions import CompilerPanic, StructureException, TypeCheckFailure, TypeMismatch
from vyper.semantics.types import (
Expand Down Expand Up @@ -562,10 +562,10 @@ def _get_element_ptr_mapping(parent, key):
key = unwrap_location(key)

# TODO when is key None?
if key is None or parent.location != STORAGE:
raise TypeCheckFailure(f"bad dereference on mapping {parent}[{key}]")
if key is None or parent.location not in (STORAGE, TRANSIENT):
raise TypeCheckFailure("bad dereference on mapping {parent}[{key}]")

return IRnode.from_list(["sha3_64", parent, key], typ=subtype, location=STORAGE)
return IRnode.from_list(["sha3_64", parent, key], typ=subtype, location=parent.location)


# Take a value representing a memory or storage location, and descend down to
Expand Down
6 changes: 4 additions & 2 deletions vyper/codegen/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
)
from vyper.codegen.ir_node import IRnode
from vyper.codegen.keccak256_helper import keccak256_helper
from vyper.evm.address_space import DATA, IMMUTABLES, MEMORY, STORAGE
from vyper.evm.address_space import DATA, IMMUTABLES, MEMORY, STORAGE, TRANSIENT
from vyper.evm.opcodes import version_check
from vyper.exceptions import (
CompilerPanic,
Expand Down Expand Up @@ -259,10 +259,12 @@ def parse_Attribute(self):
# self.x: global attribute
elif isinstance(self.expr.value, vy_ast.Name) and self.expr.value.id == "self":
varinfo = self.context.globals[self.expr.attr]
location = TRANSIENT if varinfo.is_transient else STORAGE

ret = IRnode.from_list(
varinfo.position.position,
typ=varinfo.typ,
location=STORAGE,
location=location,
annotation="self." + self.expr.attr,
)
ret._referenced_variables = {varinfo}
Expand Down
10 changes: 7 additions & 3 deletions vyper/codegen/function_definitions/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ def get_nonreentrant_lock(func_type):

nkey = func_type.reentrancy_key_position.position

LOAD, STORE = "sload", "sstore"
if version_check(begin="cancun"):
LOAD, STORE = "tload", "tstore"

if version_check(begin="berlin"):
# any nonzero values would work here (see pricing as of net gas
# metering); these values are chosen so that downgrading to the
Expand All @@ -16,12 +20,12 @@ def get_nonreentrant_lock(func_type):
else:
final_value, temp_value = 0, 1

check_notset = ["assert", ["ne", temp_value, ["sload", nkey]]]
check_notset = ["assert", ["ne", temp_value, [LOAD, nkey]]]

if func_type.mutability == StateMutability.VIEW:
return [check_notset], [["seq"]]

else:
pre = ["seq", check_notset, ["sstore", nkey, temp_value]]
post = ["sstore", nkey, final_value]
pre = ["seq", check_notset, [STORE, nkey, temp_value]]
post = [STORE, nkey, final_value]
return [pre], [post]
1 change: 1 addition & 0 deletions vyper/evm/address_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def byte_addressable(self) -> bool:

MEMORY = AddrSpace("memory", 32, "mload", "mstore")
STORAGE = AddrSpace("storage", 1, "sload", "sstore")
TRANSIENT = AddrSpace("transient", 1, "tload", "tstore")
CALLDATA = AddrSpace("calldata", 32, "calldataload")
# immutables address space: "immutables" section of memory
# which is read-write in deploy code but then gets turned into
Expand Down
3 changes: 3 additions & 0 deletions vyper/evm/opcodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
"berlin": 3,
"paris": 4,
"shanghai": 5,
"cancun": 6,
# ETC Forks
"atlantis": 0,
"agharta": 1,
Expand Down Expand Up @@ -184,6 +185,8 @@
"INVALID": (0xFE, 0, 0, 0),
"DEBUG": (0xA5, 1, 0, 0),
"BREAKPOINT": (0xA6, 0, 0, 0),
"TLOAD": (0xB3, 1, 1, 100),
"TSTORE": (0xB4, 2, 0, 100),
}

PSEUDO_OPCODES: OpcodeMap = {
Expand Down
1 change: 1 addition & 0 deletions vyper/semantics/analysis/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ class VarInfo:
is_constant: bool = False
is_public: bool = False
is_immutable: bool = False
is_transient: bool = False
is_local_var: bool = False
decl_node: Optional[vy_ast.VyperNode] = None

Expand Down
9 changes: 9 additions & 0 deletions vyper/semantics/analysis/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import vyper.builtins.interfaces
from vyper import ast as vy_ast
from vyper.evm.opcodes import version_check
from vyper.exceptions import (
CallViolation,
CompilerPanic,
Expand Down Expand Up @@ -189,17 +190,25 @@ def visit_VariableDecl(self, node):
if node.is_immutable
else DataLocation.UNSET
if node.is_constant
# XXX: needed if we want separate transient allocator
# else DataLocation.TRANSIENT
# if node.is_transient
else DataLocation.STORAGE
)

type_ = type_from_annotation(node.annotation, data_loc)

if node.is_transient and not version_check(begin="cancun"):
raise StructureException("`transient` is not available pre-cancun", node.annotation)

var_info = VarInfo(
type_,
decl_node=node,
location=data_loc,
is_constant=node.is_constant,
is_public=node.is_public,
is_immutable=node.is_immutable,
is_transient=node.is_transient,
)
node.target._metadata["varinfo"] = var_info # TODO maybe put this in the global namespace
node._metadata["type"] = type_
Expand Down
2 changes: 2 additions & 0 deletions vyper/semantics/data_locations.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,5 @@ class DataLocation(enum.Enum):
STORAGE = 2
CALLDATA = 3
CODE = 4
# XXX: needed for separate transient storage allocator
# TRANSIENT = 5
1 change: 1 addition & 0 deletions vyper/semantics/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ def validate_identifier(attr):
"nonpayable",
"constant",
"immutable",
"transient",
"internal",
"payable",
"nonreentrant",
Expand Down

0 comments on commit ed0a654

Please sign in to comment.