From 8ccacb3f47f864ec2ff64d5f7ca65625e9df6b2f Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sat, 10 Feb 2024 08:39:51 -0800 Subject: [PATCH 01/12] feat[lang]: singleton modules with ownership hierarchy (#3729) this commit implements "singleton modules with ownership hierarchy" as described in https://github.com/vyperlang/vyper/issues/3722. to accomplish this, two new language constructs are added: `UsesDecl` and `InitializesDecl`. these are exposed to the user as `uses:` and `initializes:`. they are also accompanied by new `AnalysisResult` data structures: `UsesInfo` and `InitializesInfo`. `uses` and `initializes` can be thought of as a constraint system on the module system. a `uses: my-module` annotation is required if `my_module`'s state is accessed (read or written), and `initializes: my_module` is required to call `my_module.__init__()`. a module can be `use`d any number of times; it can only be `initialize`d once. a module which has been used (directly, or transitively) by the compilation target (main entry point module), must be `initialize`d exactly once. `initializes:` is also required to declare which modules it has been `initialize`d with. for example, if `mod1` declares it `uses: mod2`, then any `initializes: mod1` statement must declare *which* instance of `mod2` it has been initialized with. although there is only ever a single instance of `mod2`, this user-facing requirement improves readability by forcing the user to be aware of what the state access dependencies are for a given, `initialize`d module. the `NamedExpr` node ("walrus operator") has been added to the AST to support the initializer syntax. (note: the walrus operator is used, because the originally proposed syntax, `mod1[mod2 = mod2]` is rejected by the python parser). a new compiler pass, `vyper/semantics/analysis/global.py` has been added to implement the global initializer constraint, as it cannot be defined recursively (without a global context). since `__init__()` functions can now be called from other `__init__()` functions (which is not allowed for normal `@external` functions!), a new `@deploy` visibility has been added to vyper's visibility system. `@deploy` functions can be called from other `@deploy` functions, and never from `@external` or `@internal` functions. they also have special treatment in the ABI relative to other `@external` functions. `initializes:` is useful since it also serves the purpose of being a storage allocator directive. wherever `initializes:` is placed, is where the module will be placed in storage (and code, transient storage, or any other future storage locations). this commit refactors the storage allocator so that it recurses into child modules whenever it sees an `initializes:` statement. it refactors several data structures surrounding the storage allocator, including removing inheritance on the `DataPosition` data structure (which has also been renamed to `VarOffset`). some utility functions have been added for calculating the size of a given variable, which also get used in codegen (`get_element_ptr()`). additional work/refactoring in this commit: - new analysis machinery for detecting reads/writes for all `ExprInfo`s - dynamic programming on the `get_expr_info()` routine - refactoring of `visit_Expr`, which fixes call mutability analysis - move `StringEnum` back to vyper/utils.py - remove the "TYPE_DEFINITION" kludge in certain builtins, replace with usage of `TYPE_T` - improve `tag_exceptions()` formatting - remove `Context.globals`, as we rely on the results of the front-end analyser now. - remove dead variable: `Context.in_assertion` - refactor `generate_ir_for_function` into `generate_ir_for_external_function` and `generate_ir_for_internal_function` - move `get_nonreentrant_lock` to `function_definitions/common.py` - simplify layout allocation across locations into single function - add `VyperType.get_size_in()` and `VarInfo.get_size()` helper functions so we don't need to do as much switch/case in implementation functions - refactor `codegen/core.py` functions to use `VyperType.get_size()` - fix interfaces access from `.vyi` files --- examples/auctions/blind_auction.vy | 4 +- examples/auctions/simple_open_auction.vy | 4 +- examples/crowdfund.vy | 4 +- examples/factory/Exchange.vy | 4 +- examples/factory/Factory.vy | 4 +- .../market_maker/on_chain_market_maker.vy | 2 + examples/name_registry/name_registry.vy | 1 + .../safe_remote_purchase.vy | 4 +- examples/stock/company.vy | 4 +- examples/storage/advanced_storage.vy | 4 +- examples/storage/storage.vy | 6 +- examples/tokens/ERC1155ownable.vy | 5 +- examples/tokens/ERC20.vy | 4 +- examples/tokens/ERC4626.vy | 4 +- examples/tokens/ERC721.vy | 4 +- examples/voting/ballot.vy | 4 +- examples/wallet/wallet.vy | 4 +- tests/functional/builtins/codegen/test_abi.py | 4 +- .../builtins/codegen/test_abi_decode.py | 2 +- .../builtins/codegen/test_abi_encode.py | 2 +- .../functional/builtins/codegen/test_ceil.py | 4 +- .../builtins/codegen/test_concat.py | 4 +- .../builtins/codegen/test_create_functions.py | 10 +- .../builtins/codegen/test_ecrecover.py | 2 +- .../functional/builtins/codegen/test_floor.py | 4 +- .../builtins/codegen/test_raw_call.py | 2 +- .../functional/builtins/codegen/test_slice.py | 10 +- .../test_default_function.py | 2 +- .../calling_convention/test_erc20_abi.py | 2 +- .../test_external_contract_calls.py | 31 +- ...test_modifiable_external_contract_calls.py | 8 +- .../calling_convention/test_return_tuple.py | 2 +- .../features/decorators/test_payable.py | 4 +- .../features/decorators/test_private.py | 4 +- .../features/iteration/test_range_in.py | 2 +- .../codegen/features/test_bytes_map_keys.py | 12 +- .../codegen/features/test_clampers.py | 2 +- .../codegen/features/test_constructor.py | 22 +- .../codegen/features/test_immutable.py | 51 +- .../functional/codegen/features/test_init.py | 8 +- .../codegen/features/test_logging.py | 4 +- .../codegen/features/test_ternary.py | 2 +- .../codegen/integration/test_crowdfund.py | 4 +- .../codegen/integration/test_escrow.py | 2 +- .../codegen/modules/test_module_constants.py | 20 + .../codegen/modules/test_module_variables.py | 318 +++++ .../codegen/storage_variables/test_getters.py | 4 +- .../test_storage_variable.py | 2 +- tests/functional/codegen/test_interfaces.py | 12 +- tests/functional/codegen/types/test_bytes.py | 2 +- .../codegen/types/test_dynamic_array.py | 4 +- tests/functional/codegen/types/test_flag.py | 2 +- tests/functional/codegen/types/test_string.py | 2 +- .../test_safe_remote_purchase.py | 2 +- .../syntax/exceptions/test_call_violation.py | 9 + .../exceptions/test_constancy_exception.py | 59 +- .../test_function_declaration_exception.py | 10 +- .../test_instantiation_exception.py | 2 +- .../exceptions/test_invalid_reference.py | 2 +- .../exceptions/test_structure_exception.py | 6 +- .../exceptions/test_vyper_exception_pos.py | 2 +- .../syntax/modules/test_deploy_visibility.py | 27 + .../syntax/modules/test_implements.py | 51 + .../syntax/modules/test_initializers.py | 1139 +++++++++++++++++ tests/functional/syntax/test_address_code.py | 4 +- tests/functional/syntax/test_codehash.py | 2 +- tests/functional/syntax/test_constants.py | 4 +- tests/functional/syntax/test_immutables.py | 22 +- tests/functional/syntax/test_init.py | 64 + tests/functional/syntax/test_interfaces.py | 4 +- tests/functional/syntax/test_public.py | 2 +- tests/functional/syntax/test_tuple_assign.py | 2 +- tests/unit/ast/test_ast_dict.py | 10 - .../cli/storage_layout/test_storage_layout.py | 250 +++- tests/unit/compiler/asm/test_asm_optimizer.py | 22 +- tests/unit/compiler/test_bytecode_runtime.py | 2 +- tests/unit/semantics/test_storage_slots.py | 4 +- vyper/ast/__init__.py | 3 +- vyper/ast/grammar.lark | 14 +- vyper/ast/nodes.py | 105 +- vyper/ast/nodes.pyi | 35 +- vyper/ast/parse.py | 4 +- vyper/builtins/_signatures.py | 13 +- vyper/builtins/_utils.py | 6 +- vyper/builtins/functions.py | 18 +- vyper/codegen/context.py | 19 +- vyper/codegen/core.py | 61 +- vyper/codegen/expr.py | 37 +- .../codegen/function_definitions/__init__.py | 5 +- vyper/codegen/function_definitions/common.py | 120 +- .../function_definitions/external_function.py | 49 +- .../function_definitions/internal_function.py | 34 +- vyper/codegen/function_definitions/utils.py | 31 - vyper/codegen/module.py | 31 +- vyper/codegen/stmt.py | 2 +- vyper/compiler/phases.py | 27 +- vyper/evm/address_space.py | 8 - vyper/exceptions.py | 25 +- vyper/semantics/analysis/__init__.py | 2 +- vyper/semantics/analysis/base.py | 286 ++--- vyper/semantics/analysis/constant_folding.py | 2 +- vyper/semantics/analysis/data_positions.py | 221 ++-- vyper/semantics/analysis/global_.py | 80 ++ vyper/semantics/analysis/local.py | 228 +++- vyper/semantics/analysis/module.py | 265 +++- vyper/semantics/analysis/utils.py | 45 +- vyper/semantics/data_locations.py | 16 +- vyper/semantics/types/base.py | 23 +- vyper/semantics/types/function.py | 91 +- vyper/semantics/types/module.py | 94 +- vyper/semantics/types/utils.py | 16 +- vyper/utils.py | 56 +- 112 files changed, 3566 insertions(+), 845 deletions(-) create mode 100644 tests/functional/codegen/modules/test_module_variables.py create mode 100644 tests/functional/syntax/modules/test_deploy_visibility.py create mode 100644 tests/functional/syntax/modules/test_implements.py create mode 100644 tests/functional/syntax/modules/test_initializers.py create mode 100644 tests/functional/syntax/test_init.py delete mode 100644 vyper/codegen/function_definitions/utils.py create mode 100644 vyper/semantics/analysis/global_.py diff --git a/examples/auctions/blind_auction.vy b/examples/auctions/blind_auction.vy index 597aed57c7..966565138f 100644 --- a/examples/auctions/blind_auction.vy +++ b/examples/auctions/blind_auction.vy @@ -1,3 +1,5 @@ +#pragma version >0.3.10 + # Blind Auction. Adapted to Vyper from [Solidity by Example](https://github.com/ethereum/solidity/blob/develop/docs/solidity-by-example.rst#blind-auction-1) struct Bid: @@ -36,7 +38,7 @@ pendingReturns: HashMap[address, uint256] # Create a blinded auction with `_biddingTime` seconds bidding time and # `_revealTime` seconds reveal time on behalf of the beneficiary address # `_beneficiary`. -@external +@deploy def __init__(_beneficiary: address, _biddingTime: uint256, _revealTime: uint256): self.beneficiary = _beneficiary self.biddingEnd = block.timestamp + _biddingTime diff --git a/examples/auctions/simple_open_auction.vy b/examples/auctions/simple_open_auction.vy index 6d5ce06f17..499e12af16 100644 --- a/examples/auctions/simple_open_auction.vy +++ b/examples/auctions/simple_open_auction.vy @@ -1,3 +1,5 @@ +#pragma version >0.3.10 + # Open Auction # Auction params @@ -19,7 +21,7 @@ pendingReturns: public(HashMap[address, uint256]) # Create a simple auction with `_auction_start` and # `_bidding_time` seconds bidding time on behalf of the # beneficiary address `_beneficiary`. -@external +@deploy def __init__(_beneficiary: address, _auction_start: uint256, _bidding_time: uint256): self.beneficiary = _beneficiary self.auctionStart = _auction_start # auction start time can be in the past, present or future diff --git a/examples/crowdfund.vy b/examples/crowdfund.vy index 6d07e15bc4..50ec005924 100644 --- a/examples/crowdfund.vy +++ b/examples/crowdfund.vy @@ -1,3 +1,5 @@ +#pragma version >0.3.10 + ########################################################################### ## THIS IS EXAMPLE CODE, NOT MEANT TO BE USED IN PRODUCTION! CAVEAT EMPTOR! ########################################################################### @@ -11,7 +13,7 @@ goal: public(uint256) timelimit: public(uint256) # Setup global variables -@external +@deploy def __init__(_beneficiary: address, _goal: uint256, _timelimit: uint256): self.beneficiary = _beneficiary self.deadline = block.timestamp + _timelimit diff --git a/examples/factory/Exchange.vy b/examples/factory/Exchange.vy index 77f47984bc..e66c60743a 100644 --- a/examples/factory/Exchange.vy +++ b/examples/factory/Exchange.vy @@ -1,3 +1,5 @@ +#pragma version >0.3.10 + from ethereum.ercs import ERC20 @@ -9,7 +11,7 @@ token: public(ERC20) factory: Factory -@external +@deploy def __init__(_token: ERC20, _factory: Factory): self.token = _token self.factory = _factory diff --git a/examples/factory/Factory.vy b/examples/factory/Factory.vy index 50e7a81bf6..4fec723197 100644 --- a/examples/factory/Factory.vy +++ b/examples/factory/Factory.vy @@ -1,3 +1,5 @@ +#pragma version >0.3.10 + from ethereum.ercs import ERC20 interface Exchange: @@ -11,7 +13,7 @@ exchange_codehash: public(bytes32) exchanges: public(HashMap[ERC20, Exchange]) -@external +@deploy def __init__(_exchange_codehash: bytes32): # Register the exchange code hash during deployment of the factory self.exchange_codehash = _exchange_codehash diff --git a/examples/market_maker/on_chain_market_maker.vy b/examples/market_maker/on_chain_market_maker.vy index 4f9859584c..74b1307dc1 100644 --- a/examples/market_maker/on_chain_market_maker.vy +++ b/examples/market_maker/on_chain_market_maker.vy @@ -1,3 +1,5 @@ +#pragma version >0.3.10 + from ethereum.ercs import ERC20 diff --git a/examples/name_registry/name_registry.vy b/examples/name_registry/name_registry.vy index 7152851dac..937b41856b 100644 --- a/examples/name_registry/name_registry.vy +++ b/examples/name_registry/name_registry.vy @@ -1,3 +1,4 @@ +#pragma version >0.3.10 registry: HashMap[Bytes[100], address] diff --git a/examples/safe_remote_purchase/safe_remote_purchase.vy b/examples/safe_remote_purchase/safe_remote_purchase.vy index edc2163b85..91f0159a2d 100644 --- a/examples/safe_remote_purchase/safe_remote_purchase.vy +++ b/examples/safe_remote_purchase/safe_remote_purchase.vy @@ -1,3 +1,5 @@ +#pragma version >0.3.10 + # Safe Remote Purchase # Originally from # https://github.com/ethereum/solidity/blob/develop/docs/solidity-by-example.rst @@ -19,7 +21,7 @@ buyer: public(address) unlocked: public(bool) ended: public(bool) -@external +@deploy @payable def __init__(): assert (msg.value % 2) == 0 diff --git a/examples/stock/company.vy b/examples/stock/company.vy index 6293e6eea4..355432830d 100644 --- a/examples/stock/company.vy +++ b/examples/stock/company.vy @@ -1,3 +1,5 @@ +#pragma version >0.3.10 + # Financial events the contract logs event Transfer: @@ -27,7 +29,7 @@ price: public(uint256) holdings: HashMap[address, uint256] # Set up the company. -@external +@deploy def __init__(_company: address, _total_shares: uint256, initial_price: uint256): assert _total_shares > 0 assert initial_price > 0 diff --git a/examples/storage/advanced_storage.vy b/examples/storage/advanced_storage.vy index 2ba50280d7..42a455cbf1 100644 --- a/examples/storage/advanced_storage.vy +++ b/examples/storage/advanced_storage.vy @@ -1,10 +1,12 @@ +#pragma version >0.3.10 + event DataChange: setter: indexed(address) value: int128 storedData: public(int128) -@external +@deploy def __init__(_x: int128): self.storedData = _x diff --git a/examples/storage/storage.vy b/examples/storage/storage.vy index 7d05e4708c..30f570f212 100644 --- a/examples/storage/storage.vy +++ b/examples/storage/storage.vy @@ -1,9 +1,11 @@ +#pragma version >0.3.10 + storedData: public(int128) -@external +@deploy def __init__(_x: int128): self.storedData = _x @external def set(_x: int128): - self.storedData = _x \ No newline at end of file + self.storedData = _x diff --git a/examples/tokens/ERC1155ownable.vy b/examples/tokens/ERC1155ownable.vy index d1e88dcd04..d88d459d64 100644 --- a/examples/tokens/ERC1155ownable.vy +++ b/examples/tokens/ERC1155ownable.vy @@ -1,8 +1,9 @@ +#pragma version >0.3.10 + ########################################################################### ## THIS IS EXAMPLE CODE, NOT MEANT TO BE USED IN PRODUCTION! CAVEAT EMPTOR! ########################################################################### -# @version >=0.3.4 """ @dev example implementation of ERC-1155 non-fungible token standard ownable, with approval, OPENSEA compatible (name, symbol) @author Dr. Pixel (github: @Doc-Pixel) @@ -122,7 +123,7 @@ interface IERC1155MetadataURI: ############### functions ############### -@external +@deploy def __init__(name: String[128], symbol: String[16], uri: String[MAX_URI_LENGTH], contractUri: String[MAX_URI_LENGTH]): """ @dev contract initialization on deployment diff --git a/examples/tokens/ERC20.vy b/examples/tokens/ERC20.vy index 77550c3f5a..0e94b32b9d 100644 --- a/examples/tokens/ERC20.vy +++ b/examples/tokens/ERC20.vy @@ -1,3 +1,5 @@ +#pragma version >0.3.10 + ########################################################################### ## THIS IS EXAMPLE CODE, NOT MEANT TO BE USED IN PRODUCTION! CAVEAT EMPTOR! ########################################################################### @@ -38,7 +40,7 @@ totalSupply: public(uint256) minter: address -@external +@deploy def __init__(_name: String[32], _symbol: String[32], _decimals: uint8, _supply: uint256): init_supply: uint256 = _supply * 10 ** convert(_decimals, uint256) self.name = _name diff --git a/examples/tokens/ERC4626.vy b/examples/tokens/ERC4626.vy index 73721fdb98..699b5edd42 100644 --- a/examples/tokens/ERC4626.vy +++ b/examples/tokens/ERC4626.vy @@ -1,3 +1,5 @@ +#pragma version >0.3.10 + # NOTE: Copied from https://github.com/fubuloubu/ERC4626/blob/1a10b051928b11eeaad15d80397ed36603c2a49b/contracts/VyperVault.vy # example implementation of an ERC4626 vault @@ -50,7 +52,7 @@ event Withdraw: shares: uint256 -@external +@deploy def __init__(asset: ERC20): self.asset = asset diff --git a/examples/tokens/ERC721.vy b/examples/tokens/ERC721.vy index d3a8d1f13d..70dff96051 100644 --- a/examples/tokens/ERC721.vy +++ b/examples/tokens/ERC721.vy @@ -1,3 +1,5 @@ +#pragma version >0.3.10 + ########################################################################### ## THIS IS EXAMPLE CODE, NOT MEANT TO BE USED IN PRODUCTION! CAVEAT EMPTOR! ########################################################################### @@ -82,7 +84,7 @@ SUPPORTED_INTERFACES: constant(bytes4[2]) = [ 0x80ac58cd, ] -@external +@deploy def __init__(): """ @dev Contract constructor. diff --git a/examples/voting/ballot.vy b/examples/voting/ballot.vy index 107716accf..daaf712e0f 100644 --- a/examples/voting/ballot.vy +++ b/examples/voting/ballot.vy @@ -1,3 +1,5 @@ +#pragma version >0.3.10 + # Voting with delegation. # Information about voters @@ -50,7 +52,7 @@ def directlyVoted(addr: address) -> bool: # Setup global variables -@external +@deploy def __init__(_proposalNames: bytes32[2]): self.chairperson = msg.sender self.voterCount = 0 diff --git a/examples/wallet/wallet.vy b/examples/wallet/wallet.vy index 231f538ecf..7e92c7e89c 100644 --- a/examples/wallet/wallet.vy +++ b/examples/wallet/wallet.vy @@ -1,3 +1,5 @@ +#pragma version >0.3.10 + ########################################################################### ## THIS IS EXAMPLE CODE, NOT MEANT TO BE USED IN PRODUCTION! CAVEAT EMPTOR! ########################################################################### @@ -12,7 +14,7 @@ threshold: int128 seq: public(int128) -@external +@deploy def __init__(_owners: address[5], _threshold: int128): for i: uint256 in range(5): if _owners[i] != empty(address): diff --git a/tests/functional/builtins/codegen/test_abi.py b/tests/functional/builtins/codegen/test_abi.py index 4ddfcf50c1..335f728a37 100644 --- a/tests/functional/builtins/codegen/test_abi.py +++ b/tests/functional/builtins/codegen/test_abi.py @@ -8,14 +8,14 @@ """ x: int128 -@external +@deploy def __init__(): self.x = 1 """, """ x: int128 -@external +@deploy def __init__(): pass """, diff --git a/tests/functional/builtins/codegen/test_abi_decode.py b/tests/functional/builtins/codegen/test_abi_decode.py index 69bfef63ea..96cbbe4c2d 100644 --- a/tests/functional/builtins/codegen/test_abi_decode.py +++ b/tests/functional/builtins/codegen/test_abi_decode.py @@ -224,7 +224,7 @@ def test_side_effects_evaluation(get_contract): contract_1 = """ counter: uint256 -@external +@deploy def __init__(): self.counter = 0 diff --git a/tests/functional/builtins/codegen/test_abi_encode.py b/tests/functional/builtins/codegen/test_abi_encode.py index f4b7d57a04..8709e31470 100644 --- a/tests/functional/builtins/codegen/test_abi_encode.py +++ b/tests/functional/builtins/codegen/test_abi_encode.py @@ -263,7 +263,7 @@ def test_side_effects_evaluation(get_contract): contract_1 = """ counter: uint256 -@external +@deploy def __init__(): self.counter = 0 diff --git a/tests/functional/builtins/codegen/test_ceil.py b/tests/functional/builtins/codegen/test_ceil.py index daa9cb7c1b..191e2adfef 100644 --- a/tests/functional/builtins/codegen/test_ceil.py +++ b/tests/functional/builtins/codegen/test_ceil.py @@ -6,7 +6,7 @@ def test_ceil(get_contract_with_gas_estimation): code = """ x: decimal -@external +@deploy def __init__(): self.x = 504.0000000001 @@ -53,7 +53,7 @@ def test_ceil_negative(get_contract_with_gas_estimation): code = """ x: decimal -@external +@deploy def __init__(): self.x = -504.0000000001 diff --git a/tests/functional/builtins/codegen/test_concat.py b/tests/functional/builtins/codegen/test_concat.py index 7354515989..37bdaaaf7b 100644 --- a/tests/functional/builtins/codegen/test_concat.py +++ b/tests/functional/builtins/codegen/test_concat.py @@ -79,7 +79,7 @@ def test_concat_buffer2(get_contract): code = """ i: immutable(int256) -@external +@deploy def __init__(): i = -1 s: String[2] = concat("a", "b") @@ -99,7 +99,7 @@ def test_concat_buffer3(get_contract): s2: String[33] s3: String[34] -@external +@deploy def __init__(): self.s = "a" self.s2 = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" # 33*'a' diff --git a/tests/functional/builtins/codegen/test_create_functions.py b/tests/functional/builtins/codegen/test_create_functions.py index afa729ac8a..0aa718157c 100644 --- a/tests/functional/builtins/codegen/test_create_functions.py +++ b/tests/functional/builtins/codegen/test_create_functions.py @@ -214,7 +214,7 @@ def test_create_from_blueprint_bad_code_offset( deployer_code = """ BLUEPRINT: immutable(address) -@external +@deploy def __init__(blueprint_address: address): BLUEPRINT = blueprint_address @@ -269,7 +269,7 @@ def test_create_from_blueprint_args( FOO: immutable(String[128]) BAR: immutable(Bar) -@external +@deploy def __init__(foo: String[128], bar: Bar): FOO = foo BAR = bar @@ -450,7 +450,7 @@ def test_create_from_blueprint_complex_value( code = """ var: uint256 -@external +@deploy @payable def __init__(x: uint256): self.var = x @@ -507,7 +507,7 @@ def test_create_from_blueprint_complex_salt_raw_args( code = """ var: uint256 -@external +@deploy @payable def __init__(x: uint256): self.var = x @@ -565,7 +565,7 @@ def test_create_from_blueprint_complex_salt_no_constructor_args( code = """ var: uint256 -@external +@deploy @payable def __init__(): self.var = 12 diff --git a/tests/functional/builtins/codegen/test_ecrecover.py b/tests/functional/builtins/codegen/test_ecrecover.py index 8571948c3d..ce24868afe 100644 --- a/tests/functional/builtins/codegen/test_ecrecover.py +++ b/tests/functional/builtins/codegen/test_ecrecover.py @@ -68,7 +68,7 @@ def test_invalid_signature2(get_contract): owner: immutable(address) -@external +@deploy def __init__(): owner = 0x7E5F4552091A69125d5DfCb7b8C2659029395Bdf diff --git a/tests/functional/builtins/codegen/test_floor.py b/tests/functional/builtins/codegen/test_floor.py index d2fd993785..5caffd5551 100644 --- a/tests/functional/builtins/codegen/test_floor.py +++ b/tests/functional/builtins/codegen/test_floor.py @@ -6,7 +6,7 @@ def test_floor(get_contract_with_gas_estimation): code = """ x: decimal -@external +@deploy def __init__(): self.x = 504.0000000001 @@ -55,7 +55,7 @@ def test_floor_negative(get_contract_with_gas_estimation): code = """ x: decimal -@external +@deploy def __init__(): self.x = -504.0000000001 diff --git a/tests/functional/builtins/codegen/test_raw_call.py b/tests/functional/builtins/codegen/test_raw_call.py index b30a94502d..e5201e9bb2 100644 --- a/tests/functional/builtins/codegen/test_raw_call.py +++ b/tests/functional/builtins/codegen/test_raw_call.py @@ -137,7 +137,7 @@ def set_owner(i: int128, o: address): owners: public(address[5]) -@external +@deploy def __init__(_owner_setter: address): self.owner_setter_contract = _owner_setter diff --git a/tests/functional/builtins/codegen/test_slice.py b/tests/functional/builtins/codegen/test_slice.py index 80936bbf82..0c5a8fc485 100644 --- a/tests/functional/builtins/codegen/test_slice.py +++ b/tests/functional/builtins/codegen/test_slice.py @@ -57,7 +57,7 @@ def test_slice_immutable( IMMUTABLE_BYTES: immutable(Bytes[{length_bound}]) IMMUTABLE_SLICE: immutable(Bytes[{length_bound}]) -@external +@deploy def __init__(inp: Bytes[{length_bound}], start: uint256, length: uint256): IMMUTABLE_BYTES = inp IMMUTABLE_SLICE = slice(IMMUTABLE_BYTES, {_start}, {_length}) @@ -119,7 +119,7 @@ def test_slice_bytes_fuzz( elif location == "code": preamble = f""" IMMUTABLE_BYTES: immutable(Bytes[{length_bound}]) -@external +@deploy def __init__(foo: Bytes[{length_bound}]): IMMUTABLE_BYTES = foo """ @@ -230,7 +230,7 @@ def test_slice_immutable_length_arg(get_contract_with_gas_estimation): code = """ LENGTH: immutable(uint256) -@external +@deploy def __init__(): LENGTH = 5 @@ -314,7 +314,7 @@ def f() -> bytes32: """ foo: bytes32 -@external +@deploy def __init__(): self.foo = 0x000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f @@ -325,7 +325,7 @@ def bar() -> Bytes[{length}]: """ foo: bytes32 -@external +@deploy def __init__(): self.foo = 0x000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f diff --git a/tests/functional/codegen/calling_convention/test_default_function.py b/tests/functional/codegen/calling_convention/test_default_function.py index cf55607877..411f38eac9 100644 --- a/tests/functional/codegen/calling_convention/test_default_function.py +++ b/tests/functional/codegen/calling_convention/test_default_function.py @@ -2,7 +2,7 @@ def test_throw_on_sending(w3, tx_failed, get_contract_with_gas_estimation): code = """ x: public(int128) -@external +@deploy def __init__(): self.x = 123 """ diff --git a/tests/functional/codegen/calling_convention/test_erc20_abi.py b/tests/functional/codegen/calling_convention/test_erc20_abi.py index b9dc5c663f..59c4131fb2 100644 --- a/tests/functional/codegen/calling_convention/test_erc20_abi.py +++ b/tests/functional/codegen/calling_convention/test_erc20_abi.py @@ -33,7 +33,7 @@ def allowance(_owner: address, _spender: address) -> uint256: nonpayable token_address: ERC20Contract -@external +@deploy def __init__(token_addr: address): self.token_address = ERC20Contract(token_addr) diff --git a/tests/functional/codegen/calling_convention/test_external_contract_calls.py b/tests/functional/codegen/calling_convention/test_external_contract_calls.py index a7cf4d0ecf..8b3f30b5a5 100644 --- a/tests/functional/codegen/calling_convention/test_external_contract_calls.py +++ b/tests/functional/codegen/calling_convention/test_external_contract_calls.py @@ -41,7 +41,7 @@ def test_complicated_external_contract_calls(get_contract, get_contract_with_gas contract_1 = """ lucky: public(int128) -@external +@deploy def __init__(_lucky: int128): self.lucky = _lucky @@ -898,26 +898,31 @@ def set_lucky(arg1: address, arg2: int128): print("Successfully executed an external contract call state change") -def test_constant_external_contract_call_cannot_change_state( - assert_compile_failed, get_contract_with_gas_estimation -): +def test_constant_external_contract_call_cannot_change_state(): c = """ interface Foo: def set_lucky(_lucky: int128) -> int128: nonpayable @external @view -def set_lucky_expr(arg1: address, arg2: int128): +def set_lucky_stmt(arg1: address, arg2: int128): Foo(arg1).set_lucky(arg2) + """ + with pytest.raises(StateAccessViolation): + compile_code(c) + + c2 = """ +interface Foo: + def set_lucky(_lucky: int128) -> int128: nonpayable @external @view -def set_lucky_stmt(arg1: address, arg2: int128) -> int128: +def set_lucky_expr(arg1: address, arg2: int128) -> int128: return Foo(arg1).set_lucky(arg2) """ - assert_compile_failed(lambda: get_contract_with_gas_estimation(c), StateAccessViolation) - print("Successfully blocked an external contract call from a constant function") + with pytest.raises(StateAccessViolation): + compile_code(c2) def test_external_contract_can_be_changed_based_on_address(get_contract): @@ -968,7 +973,7 @@ def test_external_contract_calls_with_public_globals(get_contract): contract_1 = """ lucky: public(int128) -@external +@deploy def __init__(_lucky: int128): self.lucky = _lucky """ @@ -994,7 +999,7 @@ def test_external_contract_calls_with_multiple_contracts(get_contract): contract_1 = """ lucky: public(int128) -@external +@deploy def __init__(_lucky: int128): self.lucky = _lucky """ @@ -1008,7 +1013,7 @@ def lucky() -> int128: view magic_number: public(int128) -@external +@deploy def __init__(arg1: address): self.magic_number = Foo(arg1).lucky() """ @@ -1020,7 +1025,7 @@ def magic_number() -> int128: view best_number: public(int128) -@external +@deploy def __init__(arg1: address): self.best_number = Bar(arg1).magic_number() """ @@ -1145,7 +1150,7 @@ def test_invalid_contract_reference_declaration(tx_failed, get_contract): best_number: public(int128) -@external +@deploy def __init__(): pass """ diff --git a/tests/functional/codegen/calling_convention/test_modifiable_external_contract_calls.py b/tests/functional/codegen/calling_convention/test_modifiable_external_contract_calls.py index e6b2402016..aa7130fd6a 100644 --- a/tests/functional/codegen/calling_convention/test_modifiable_external_contract_calls.py +++ b/tests/functional/codegen/calling_convention/test_modifiable_external_contract_calls.py @@ -20,7 +20,7 @@ def set_lucky(_lucky: int128): view modifiable_bar_contract: ModBar static_bar_contract: ConstBar -@external +@deploy def __init__(contract_address: address): self.modifiable_bar_contract = ModBar(contract_address) self.static_bar_contract = ConstBar(contract_address) @@ -64,7 +64,7 @@ def set_lucky(_lucky: int128) -> int128: view modifiable_bar_contract: ModBar static_bar_contract: ConstBar -@external +@deploy def __init__(contract_address: address): self.modifiable_bar_contract = ModBar(contract_address) self.static_bar_contract = ConstBar(contract_address) @@ -108,7 +108,7 @@ def set_lucky(_lucky: int128): view modifiable_bar_contract: ModBar static_bar_contract: ConstBar -@external +@deploy def __init__(contract_address: address): self.modifiable_bar_contract = ModBar(contract_address) self.static_bar_contract = ConstBar(contract_address) @@ -134,7 +134,7 @@ def static_set_lucky(_lucky: int128): view modifiable_bar_contract: ModBar static_bar_contract: ConstBar -@external +@deploy def __init__(contract_address: address): self.modifiable_bar_contract = ModBar(contract_address) self.static_bar_contract = ConstBar(contract_address) diff --git a/tests/functional/codegen/calling_convention/test_return_tuple.py b/tests/functional/codegen/calling_convention/test_return_tuple.py index 266555ead6..74929c9496 100644 --- a/tests/functional/codegen/calling_convention/test_return_tuple.py +++ b/tests/functional/codegen/calling_convention/test_return_tuple.py @@ -16,7 +16,7 @@ def test_return_type(get_contract_with_gas_estimation): c: int128 chunk: Chunk -@external +@deploy def __init__(): self.chunk.a = b"hello" self.chunk.b = b"world" diff --git a/tests/functional/codegen/features/decorators/test_payable.py b/tests/functional/codegen/features/decorators/test_payable.py index ced58e1af0..955501a0e6 100644 --- a/tests/functional/codegen/features/decorators/test_payable.py +++ b/tests/functional/codegen/features/decorators/test_payable.py @@ -122,7 +122,7 @@ def bar() -> bool: """, """ # payable init function -@external +@deploy @payable def __init__(): a: int128 = 1 @@ -279,7 +279,7 @@ def baz() -> bool: """, """ # init function -@external +@deploy def __init__(): a: int128 = 1 diff --git a/tests/functional/codegen/features/decorators/test_private.py b/tests/functional/codegen/features/decorators/test_private.py index 39ea1bb9ae..193112f02b 100644 --- a/tests/functional/codegen/features/decorators/test_private.py +++ b/tests/functional/codegen/features/decorators/test_private.py @@ -120,7 +120,7 @@ def test_private_bytes(get_contract_with_gas_estimation): private_test_code = """ greeting: public(Bytes[100]) -@external +@deploy def __init__(): self.greeting = b"Hello " @@ -143,7 +143,7 @@ def test_private_statement(get_contract_with_gas_estimation): private_test_code = """ greeting: public(Bytes[20]) -@external +@deploy def __init__(): self.greeting = b"Hello " diff --git a/tests/functional/codegen/features/iteration/test_range_in.py b/tests/functional/codegen/features/iteration/test_range_in.py index 7540049778..f381f60b35 100644 --- a/tests/functional/codegen/features/iteration/test_range_in.py +++ b/tests/functional/codegen/features/iteration/test_range_in.py @@ -115,7 +115,7 @@ def test_ownership(w3, tx_failed, get_contract_with_gas_estimation): owners: address[2] -@external +@deploy def __init__(): self.owners[0] = msg.sender diff --git a/tests/functional/codegen/features/test_bytes_map_keys.py b/tests/functional/codegen/features/test_bytes_map_keys.py index 4913182d52..22df767f02 100644 --- a/tests/functional/codegen/features/test_bytes_map_keys.py +++ b/tests/functional/codegen/features/test_bytes_map_keys.py @@ -80,7 +80,7 @@ def test_extended_bytes_key_from_storage(get_contract): code = """ a: HashMap[Bytes[100000], int128] -@external +@deploy def __init__(): self.a[b"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"] = 1069 @@ -114,7 +114,7 @@ def test_struct_bytes_key_memory(get_contract): a: HashMap[Bytes[100000], int128] -@external +@deploy def __init__(): self.a[b"hello"] = 1069 self.a[b"potato"] = 31337 @@ -145,7 +145,7 @@ def test_struct_bytes_key_storage(get_contract): a: HashMap[Bytes[100000], int128] b: Foo -@external +@deploy def __init__(): self.a[b"hello"] = 1069 self.a[b"potato"] = 31337 @@ -172,7 +172,7 @@ def test_bytes_key_storage(get_contract): a: HashMap[Bytes[100000], int128] b: Bytes[5] -@external +@deploy def __init__(): self.a[b"hello"] = 1069 self.b = b"hello" @@ -193,7 +193,7 @@ def test_bytes_key_calldata(get_contract): a: HashMap[Bytes[100000], int128] -@external +@deploy def __init__(): self.a[b"hello"] = 1069 @@ -215,7 +215,7 @@ def test_struct_bytes_hashmap_as_key_in_other_hashmap(get_contract): bar: public(HashMap[uint256, Thing]) foo: public(HashMap[Bytes[64], uint256]) -@external +@deploy def __init__(): self.foo[b"hello"] = 31337 self.bar[12] = Thing({name: b"hello"}) diff --git a/tests/functional/codegen/features/test_clampers.py b/tests/functional/codegen/features/test_clampers.py index 6db8570fc7..c028805c6a 100644 --- a/tests/functional/codegen/features/test_clampers.py +++ b/tests/functional/codegen/features/test_clampers.py @@ -67,7 +67,7 @@ def test_bytes_clamper_on_init(tx_failed, get_contract_with_gas_estimation): clamper_test_code = """ foo: Bytes[3] -@external +@deploy def __init__(x: Bytes[3]): self.foo = x diff --git a/tests/functional/codegen/features/test_constructor.py b/tests/functional/codegen/features/test_constructor.py index c9dfcfc5df..9146ace8a6 100644 --- a/tests/functional/codegen/features/test_constructor.py +++ b/tests/functional/codegen/features/test_constructor.py @@ -6,7 +6,7 @@ def test_init_argument_test(get_contract_with_gas_estimation): init_argument_test = """ moose: int128 -@external +@deploy def __init__(_moose: int128): self.moose = _moose @@ -26,7 +26,7 @@ def test_constructor_mapping(get_contract_with_gas_estimation): X: constant(bytes4) = 0x01ffc9a7 -@external +@deploy def __init__(): self.foo[X] = True @@ -44,7 +44,7 @@ def test_constructor_advanced_code(get_contract_with_gas_estimation): constructor_advanced_code = """ twox: int128 -@external +@deploy def __init__(x: int128): self.twox = x * 2 @@ -60,7 +60,7 @@ def test_constructor_advanced_code2(get_contract_with_gas_estimation): constructor_advanced_code2 = """ comb: uint256 -@external +@deploy def __init__(x: uint256[2], y: Bytes[3], z: uint256): self.comb = x[0] * 1000 + x[1] * 100 + len(y) * 10 + z @@ -90,7 +90,7 @@ def foo(x: int128) -> int128: def test_large_input_code_2(w3, get_contract_with_gas_estimation): large_input_code_2 = """ -@external +@deploy def __init__(x: int128): y: int128 = x @@ -113,7 +113,7 @@ def test_initialise_array_with_constant_key(get_contract_with_gas_estimation): foo: int16[X] -@external +@deploy def __init__(): self.foo[X-1] = -2 @@ -133,7 +133,7 @@ def test_initialise_dynarray_with_constant_key(get_contract_with_gas_estimation) foo: DynArray[int16, X] -@external +@deploy def __init__(): self.foo = [X - 3, X - 4, X - 5, X - 6] @@ -151,7 +151,7 @@ def test_nested_dynamic_array_constructor_arg(w3, get_contract_with_gas_estimati code = """ foo: uint256 -@external +@deploy def __init__(x: DynArray[DynArray[uint256, 3], 3]): self.foo = x[0][2] + x[1][1] + x[2][0] @@ -167,7 +167,7 @@ def test_nested_dynamic_array_constructor_arg_2(w3, get_contract_with_gas_estima code = """ foo: int128 -@external +@deploy def __init__(x: DynArray[DynArray[DynArray[int128, 3], 3], 3]): self.foo = x[0][1][2] * x[1][1][1] * x[2][1][0] - x[0][0][0] - x[1][1][1] - x[2][2][2] @@ -192,7 +192,7 @@ def test_initialise_nested_dynamic_array(w3, get_contract_with_gas_estimation): code = """ foo: DynArray[DynArray[uint256, 3], 3] -@external +@deploy def __init__(x: uint256, y: uint256, z: uint256): self.foo = [ [x, y, z], @@ -212,7 +212,7 @@ def test_initialise_nested_dynamic_array_2(w3, get_contract_with_gas_estimation) code = """ foo: DynArray[DynArray[DynArray[int128, 3], 3], 3] -@external +@deploy def __init__(x: int128, y: int128, z: int128): self.foo = [ [[x, y, z], [y, z, x], [z, y, x]], diff --git a/tests/functional/codegen/features/test_immutable.py b/tests/functional/codegen/features/test_immutable.py index 47f7fc748e..d0bc47c238 100644 --- a/tests/functional/codegen/features/test_immutable.py +++ b/tests/functional/codegen/features/test_immutable.py @@ -20,7 +20,7 @@ def test_value_storage_retrieval(typ, value, get_contract): code = f""" VALUE: immutable({typ}) -@external +@deploy def __init__(_value: {typ}): VALUE = _value @@ -41,7 +41,7 @@ def test_usage_in_constructor(get_contract, val): a: public(uint256) -@external +@deploy def __init__(_a: uint256): A = _a self.a = A @@ -63,7 +63,7 @@ def test_multiple_immutable_values(get_contract): b: immutable(address) c: immutable(String[64]) -@external +@deploy def __init__(_a: uint256, _b: address, _c: String[64]): a = _a b = _b @@ -89,7 +89,7 @@ def test_struct_immutable(get_contract): my_struct: immutable(MyStruct) -@external +@deploy def __init__(_a: uint256, _b: uint256, _c: address, _d: int256): my_struct = MyStruct({ a: _a, @@ -108,11 +108,34 @@ def get_my_struct() -> MyStruct: assert c.get_my_struct() == values +def test_complex_immutable_modifiable(get_contract): + code = """ +struct MyStruct: + a: uint256 + +my_struct: immutable(MyStruct) + +@deploy +def __init__(a: uint256): + my_struct = MyStruct({a: a}) + + # struct members are modifiable after initialization + my_struct.a += 1 + +@view +@external +def get_my_struct() -> MyStruct: + return my_struct + """ + c = get_contract(code, 1) + assert c.get_my_struct() == (2,) + + def test_list_immutable(get_contract): code = """ my_list: immutable(uint256[3]) -@external +@deploy def __init__(_a: uint256, _b: uint256, _c: uint256): my_list = [_a, _b, _c] @@ -130,7 +153,7 @@ def test_dynarray_immutable(get_contract): code = """ my_list: immutable(DynArray[uint256, 3]) -@external +@deploy def __init__(_a: uint256, _b: uint256, _c: uint256): my_list = [_a, _b, _c] @@ -154,7 +177,7 @@ def test_nested_dynarray_immutable_2(get_contract): code = """ my_list: immutable(DynArray[DynArray[uint256, 3], 3]) -@external +@deploy def __init__(_a: uint256, _b: uint256, _c: uint256): my_list = [[_a, _b, _c], [_b, _a, _c], [_c, _b, _a]] @@ -179,7 +202,7 @@ def test_nested_dynarray_immutable(get_contract): code = """ my_list: immutable(DynArray[DynArray[DynArray[int128, 3], 3], 3]) -@external +@deploy def __init__(x: int128, y: int128, z: int128): my_list = [ [[x, y, z], [y, z, x], [z, y, x]], @@ -227,7 +250,7 @@ def foo() -> uint256: counter: uint256 VALUE: immutable(uint256) -@external +@deploy def __init__(x: uint256): self.counter = x self.foo() @@ -257,7 +280,7 @@ def foo() -> uint256: b: public(uint256) @payable -@external +@deploy def __init__(to_copy: address): c: address = create_copy_of(to_copy) self.b = a @@ -281,7 +304,7 @@ def test_immutables_initialized2(get_contract, get_contract_from_ir): b: public(uint256) @payable -@external +@deploy def __init__(to_copy: address): c: address = create_copy_of(to_copy) self.b = a @@ -299,7 +322,7 @@ def test_internal_functions_called_by_ctor_location(get_contract): d: uint256 x: immutable(uint256) -@external +@deploy def __init__(): self.d = 1 x = 2 @@ -323,7 +346,7 @@ def test_nested_internal_function_immutables(get_contract): d: public(uint256) x: public(immutable(uint256)) -@external +@deploy def __init__(): self.d = 1 x = 2 @@ -348,7 +371,7 @@ def test_immutable_read_ctor_and_runtime(get_contract): d: public(uint256) x: public(immutable(uint256)) -@external +@deploy def __init__(): self.d = 1 x = 2 diff --git a/tests/functional/codegen/features/test_init.py b/tests/functional/codegen/features/test_init.py index fc765f8ab3..84d224f632 100644 --- a/tests/functional/codegen/features/test_init.py +++ b/tests/functional/codegen/features/test_init.py @@ -5,7 +5,7 @@ def test_basic_init_function(get_contract): code = """ val: public(uint256) -@external +@deploy def __init__(a: uint256): self.val = a """ @@ -27,10 +27,12 @@ def __init__(a: uint256): def test_init_calls_internal(get_contract, assert_compile_failed, tx_failed): code = """ foo: public(uint8) + @internal def bar(x: uint256) -> uint8: return convert(x, uint8) * 7 -@external + +@deploy def __init__(a: uint256): self.foo = self.bar(a) @@ -61,7 +63,7 @@ def test_nested_internal_call_from_ctor(get_contract): code = """ x: uint256 -@external +@deploy def __init__(): self.a() diff --git a/tests/functional/codegen/features/test_logging.py b/tests/functional/codegen/features/test_logging.py index 0cb8ad9abc..8b80811d02 100644 --- a/tests/functional/codegen/features/test_logging.py +++ b/tests/functional/codegen/features/test_logging.py @@ -646,7 +646,7 @@ def test_logging_fails_with_over_three_topics(tx_failed, get_contract_with_gas_e arg3: indexed(int128) arg4: indexed(int128) -@external +@deploy def __init__(): log MyLog(1, 2, 3, 4) """ @@ -1033,7 +1033,7 @@ def test_mixed_var_list_packing(get_logs, get_contract_with_gas_estimation): x: int128[4] y: int128[2] -@external +@deploy def __init__(): self.y = [1024, 2048] diff --git a/tests/functional/codegen/features/test_ternary.py b/tests/functional/codegen/features/test_ternary.py index c5480286c8..661fdc86c9 100644 --- a/tests/functional/codegen/features/test_ternary.py +++ b/tests/functional/codegen/features/test_ternary.py @@ -195,7 +195,7 @@ def test_ternary_tuple(get_contract, code, test): def test_ternary_immutable(get_contract, test): code = """ IMM: public(immutable(uint256)) -@external +@deploy def __init__(test: bool): IMM = 1 if test else 2 """ diff --git a/tests/functional/codegen/integration/test_crowdfund.py b/tests/functional/codegen/integration/test_crowdfund.py index 891ed5aebe..1a8b3f7e9f 100644 --- a/tests/functional/codegen/integration/test_crowdfund.py +++ b/tests/functional/codegen/integration/test_crowdfund.py @@ -13,7 +13,7 @@ def test_crowdfund(w3, tester, get_contract_with_gas_estimation_for_constants): refundIndex: int128 timelimit: public(uint256) -@external +@deploy def __init__(_beneficiary: address, _goal: uint256, _timelimit: uint256): self.beneficiary = _beneficiary self.deadline = block.timestamp + _timelimit @@ -109,7 +109,7 @@ def test_crowdfund2(w3, tester, get_contract_with_gas_estimation_for_constants): refundIndex: int128 timelimit: public(uint256) -@external +@deploy def __init__(_beneficiary: address, _goal: uint256, _timelimit: uint256): self.beneficiary = _beneficiary self.deadline = block.timestamp + _timelimit diff --git a/tests/functional/codegen/integration/test_escrow.py b/tests/functional/codegen/integration/test_escrow.py index 70e7cb4594..f86b4aa516 100644 --- a/tests/functional/codegen/integration/test_escrow.py +++ b/tests/functional/codegen/integration/test_escrow.py @@ -41,7 +41,7 @@ def test_arbitration_code_with_init(w3, tx_failed, get_contract_with_gas_estimat seller: address arbitrator: address -@external +@deploy @payable def __init__(_seller: address, _arbitrator: address): if self.buyer == empty(address): diff --git a/tests/functional/codegen/modules/test_module_constants.py b/tests/functional/codegen/modules/test_module_constants.py index aafbb69252..ebfefb4546 100644 --- a/tests/functional/codegen/modules/test_module_constants.py +++ b/tests/functional/codegen/modules/test_module_constants.py @@ -76,3 +76,23 @@ def foo(ix: uint256) -> uint256: assert c.foo(2) == 3 with tx_failed(): c.foo(3) + + +def test_module_constant_builtin(make_input_bundle, get_contract): + # test empty builtin, which is not (currently) foldable 2024-02-06 + mod1 = """ +X: constant(uint256) = empty(uint256) + """ + contract = """ +import mod1 + +@external +def foo() -> uint256: + return mod1.X + """ + + input_bundle = make_input_bundle({"mod1.vy": mod1}) + + c = get_contract(contract, input_bundle=input_bundle) + + assert c.foo() == 0 diff --git a/tests/functional/codegen/modules/test_module_variables.py b/tests/functional/codegen/modules/test_module_variables.py new file mode 100644 index 0000000000..6bb1f9072c --- /dev/null +++ b/tests/functional/codegen/modules/test_module_variables.py @@ -0,0 +1,318 @@ +def test_simple_import(get_contract, make_input_bundle): + lib1 = """ +counter: uint256 + +@internal +def increment_counter(): + self.counter += 1 + """ + + contract = """ +import lib + +initializes: lib + +@external +def increment_counter(): + lib.increment_counter() + +@external +def get_counter() -> uint256: + return lib.counter + """ + + input_bundle = make_input_bundle({"lib.vy": lib1}) + + c = get_contract(contract, input_bundle=input_bundle) + + assert c.get_counter() == 0 + c.increment_counter(transact={}) + assert c.get_counter() == 1 + + +def test_import_namespace(get_contract, make_input_bundle): + # test what happens when things in current and imported modules share names + lib = """ +counter: uint256 + +@internal +def increment_counter(): + self.counter += 1 + """ + + contract = """ +import library as lib + +counter: uint256 + +initializes: lib + +@external +def increment_counter(): + self.counter += 1 + +@external +def increment_lib_counter(): + lib.increment_counter() + +@external +def increment_lib_counter2(): + # modify lib.counter directly + lib.counter += 5 + +@external +def get_counter() -> uint256: + return self.counter + +@external +def get_lib_counter() -> uint256: + return lib.counter + """ + + input_bundle = make_input_bundle({"library.vy": lib}) + + c = get_contract(contract, input_bundle=input_bundle) + + assert c.get_counter() == c.get_lib_counter() == 0 + + c.increment_counter(transact={}) + assert c.get_counter() == 1 + assert c.get_lib_counter() == 0 + + c.increment_lib_counter(transact={}) + assert c.get_lib_counter() == 1 + assert c.get_counter() == 1 + + c.increment_lib_counter2(transact={}) + assert c.get_lib_counter() == 6 + assert c.get_counter() == 1 + + +def test_init_function_side_effects(get_contract, make_input_bundle): + lib = """ +counter: uint256 + +MY_IMMUTABLE: immutable(uint256) + +@deploy +def __init__(initial_value: uint256): + self.counter = initial_value + MY_IMMUTABLE = initial_value * 2 + +@internal +def increment_counter(): + self.counter += 1 + """ + + contract = """ +import library as lib + +counter: public(uint256) + +MY_IMMUTABLE: public(immutable(uint256)) + +initializes: lib + +@deploy +def __init__(): + self.counter = 1 + MY_IMMUTABLE = 3 + lib.__init__(5) + +@external +def get_lib_counter() -> uint256: + return lib.counter + +@external +def get_lib_immutable() -> uint256: + return lib.MY_IMMUTABLE + """ + + input_bundle = make_input_bundle({"library.vy": lib}) + + c = get_contract(contract, input_bundle=input_bundle) + + assert c.counter() == 1 + assert c.MY_IMMUTABLE() == 3 + assert c.get_lib_counter() == 5 + assert c.get_lib_immutable() == 10 + + +def test_indirect_variable_uses(get_contract, make_input_bundle): + lib1 = """ +counter: uint256 + +MY_IMMUTABLE: immutable(uint256) + +@deploy +def __init__(initial_value: uint256): + self.counter = initial_value + MY_IMMUTABLE = initial_value * 2 + +@internal +def increment_counter(): + self.counter += 1 + """ + lib2 = """ +import lib1 + +uses: lib1 + +@internal +def get_lib1_counter() -> uint256: + return lib1.counter + +@internal +def get_lib1_my_immutable() -> uint256: + return lib1.MY_IMMUTABLE + """ + + contract = """ +import lib1 +import lib2 + +initializes: lib1 +initializes: lib2[lib1 := lib1] + +@deploy +def __init__(): + lib1.__init__(5) + +@external +def get_storage_via_lib1() -> uint256: + return lib1.counter + +@external +def get_immutable_via_lib1() -> uint256: + return lib1.MY_IMMUTABLE + +@external +def get_storage_via_lib2() -> uint256: + return lib2.get_lib1_counter() + +@external +def get_immutable_via_lib2() -> uint256: + return lib2.get_lib1_my_immutable() + """ + + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + c = get_contract(contract, input_bundle=input_bundle) + + assert c.get_storage_via_lib1() == c.get_storage_via_lib2() == 5 + assert c.get_immutable_via_lib1() == c.get_immutable_via_lib2() == 10 + + +def test_uses_already_initialized(get_contract, make_input_bundle): + lib1 = """ +counter: uint256 +MY_IMMUTABLE: immutable(uint256) + +@deploy +def __init__(initial_value: uint256): + self.counter = initial_value * 2 + MY_IMMUTABLE = initial_value * 3 + +@internal +def increment_counter(): + self.counter += 1 + """ + lib2 = """ +import lib1 + +initializes: lib1 + +@deploy +def __init__(): + lib1.__init__(5) + +@internal +def get_lib1_counter() -> uint256: + return lib1.counter + +@internal +def get_lib1_my_immutable() -> uint256: + return lib1.MY_IMMUTABLE + """ + + contract = """ +import lib1 +import lib2 + +uses: lib1 +initializes: lib2 + +@deploy +def __init__(): + lib2.__init__() + +@external +def get_storage_via_lib1() -> uint256: + return lib1.counter + +@external +def get_immutable_via_lib1() -> uint256: + return lib1.MY_IMMUTABLE + +@external +def get_storage_via_lib2() -> uint256: + return lib2.get_lib1_counter() + +@external +def get_immutable_via_lib2() -> uint256: + return lib2.get_lib1_my_immutable() + """ + + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + c = get_contract(contract, input_bundle=input_bundle) + + assert c.get_storage_via_lib1() == c.get_storage_via_lib2() == 10 + assert c.get_immutable_via_lib1() == c.get_immutable_via_lib2() == 15 + + +def test_import_complex_types(get_contract, make_input_bundle): + lib1 = """ +an_array: uint256[3] +a_hashmap: HashMap[address, HashMap[uint256, uint256]] + +@internal +def set_array_value(ix: uint256, new_value: uint256): + self.an_array[ix] = new_value + +@internal +def set_hashmap_value(ix0: address, ix1: uint256, new_value: uint256): + self.a_hashmap[ix0][ix1] = new_value + """ + + contract = """ +import lib + +initializes: lib + +@external +def do_things(): + lib.set_array_value(1, 5) + lib.set_hashmap_value(msg.sender, 6, 100) + +@external +def get_array_value(ix: uint256) -> uint256: + return lib.an_array[ix] + +@external +def get_hashmap_value(ix: uint256) -> uint256: + return lib.a_hashmap[msg.sender][ix] + """ + + input_bundle = make_input_bundle({"lib.vy": lib1}) + + c = get_contract(contract, input_bundle=input_bundle) + + assert c.get_array_value(0) == 0 + assert c.get_hashmap_value(0) == 0 + c.do_things(transact={}) + + assert c.get_array_value(0) == 0 + assert c.get_hashmap_value(0) == 0 + assert c.get_array_value(1) == 5 + assert c.get_hashmap_value(6) == 100 diff --git a/tests/functional/codegen/storage_variables/test_getters.py b/tests/functional/codegen/storage_variables/test_getters.py index a2d9c6d0bb..9e72bed075 100644 --- a/tests/functional/codegen/storage_variables/test_getters.py +++ b/tests/functional/codegen/storage_variables/test_getters.py @@ -41,7 +41,7 @@ def foo(): nonpayable f: public(constant(uint256[2])) = [3, 7] g: public(constant(V)) = V(0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF) -@external +@deploy def __init__(): self.x = as_wei_value(7, "wei") self.y[1] = 9 @@ -87,7 +87,7 @@ def test_getter_mutability(get_contract): nyoro: public(constant(uint256)) = 2 kune: public(immutable(uint256)) -@external +@deploy def __init__(): kune = 2 """ diff --git a/tests/functional/codegen/storage_variables/test_storage_variable.py b/tests/functional/codegen/storage_variables/test_storage_variable.py index 4636fa77e0..7a22d35e4b 100644 --- a/tests/functional/codegen/storage_variables/test_storage_variable.py +++ b/tests/functional/codegen/storage_variables/test_storage_variable.py @@ -10,7 +10,7 @@ def test_permanent_variables_test(get_contract_with_gas_estimation): b: int128 var: Var -@external +@deploy def __init__(a: int128, b: int128): self.var.a = a self.var.b = b diff --git a/tests/functional/codegen/test_interfaces.py b/tests/functional/codegen/test_interfaces.py index 3344ff113b..85efe904a0 100644 --- a/tests/functional/codegen/test_interfaces.py +++ b/tests/functional/codegen/test_interfaces.py @@ -305,7 +305,7 @@ def test() -> uint256: view token_address: IToken -@external +@deploy def __init__(_token_address: address): self.token_address = IToken(_token_address) @@ -388,7 +388,7 @@ def transfer(to: address, amount: uint256) -> bool: token_address: ERC20 -@external +@deploy def __init__(_token_address: address): self.token_address = ERC20(_token_address) @@ -445,7 +445,7 @@ def should_fail() -> {typ}: view foo: BadContract -@external +@deploy def __init__(addr: BadContract): self.foo = addr @@ -501,7 +501,7 @@ def should_fail() -> Bytes[2]: view foo: BadContract -@external +@deploy def __init__(addr: BadContract): self.foo = addr @@ -551,7 +551,7 @@ def foo(x: BadJSONInterface) -> Bytes[2]: foo: BadJSONInterface -@external +@deploy def __init__(addr: BadJSONInterface): self.foo = addr @@ -667,7 +667,7 @@ def foo() -> uint256: view bar_contract: Bar -@external +@deploy def __init__(): self.bar_contract = Bar(self) diff --git a/tests/functional/codegen/types/test_bytes.py b/tests/functional/codegen/types/test_bytes.py index 325f9d7923..99e5835f6e 100644 --- a/tests/functional/codegen/types/test_bytes.py +++ b/tests/functional/codegen/types/test_bytes.py @@ -51,7 +51,7 @@ def test_test_bytes3(get_contract_with_gas_estimation): maa: Bytes[60] y: int128 -@external +@deploy def __init__(): self.x = 27 self.y = 37 diff --git a/tests/functional/codegen/types/test_dynamic_array.py b/tests/functional/codegen/types/test_dynamic_array.py index d3d945740b..fc3223caaf 100644 --- a/tests/functional/codegen/types/test_dynamic_array.py +++ b/tests/functional/codegen/types/test_dynamic_array.py @@ -1665,7 +1665,7 @@ def ix(i: uint256) -> decimal: def test_public_dynarray(get_contract): code = """ my_list: public(DynArray[uint256, 5]) -@external +@deploy def __init__(): self.my_list = [1,2,3] """ @@ -1678,7 +1678,7 @@ def __init__(): def test_nested_public_dynarray(get_contract): code = """ my_list: public(DynArray[DynArray[uint256, 5], 5]) -@external +@deploy def __init__(): self.my_list = [[1,2,3]] """ diff --git a/tests/functional/codegen/types/test_flag.py b/tests/functional/codegen/types/test_flag.py index 5da6d57558..dd9c867a96 100644 --- a/tests/functional/codegen/types/test_flag.py +++ b/tests/functional/codegen/types/test_flag.py @@ -160,7 +160,7 @@ def test_augassign_storage(get_contract, w3, tx_failed): roles: public(HashMap[address, Roles]) -@external +@deploy def __init__(): self.roles[msg.sender] = Roles.ADMIN diff --git a/tests/functional/codegen/types/test_string.py b/tests/functional/codegen/types/test_string.py index 9d50f8df38..9d596eda32 100644 --- a/tests/functional/codegen/types/test_string.py +++ b/tests/functional/codegen/types/test_string.py @@ -90,7 +90,7 @@ def test_private_string(get_contract_with_gas_estimation): private_test_code = """ greeting: public(String[100]) -@external +@deploy def __init__(): self.greeting = "Hello " diff --git a/tests/functional/examples/safe_remote_purchase/test_safe_remote_purchase.py b/tests/functional/examples/safe_remote_purchase/test_safe_remote_purchase.py index e21a113f61..f6eb3966d4 100644 --- a/tests/functional/examples/safe_remote_purchase/test_safe_remote_purchase.py +++ b/tests/functional/examples/safe_remote_purchase/test_safe_remote_purchase.py @@ -118,7 +118,7 @@ def unlocked() -> bool: view purchase_contract: PurchaseContract -@external +@deploy def __init__(_purchase_contract: address): self.purchase_contract = PurchaseContract(_purchase_contract) diff --git a/tests/functional/syntax/exceptions/test_call_violation.py b/tests/functional/syntax/exceptions/test_call_violation.py index d310a2b42a..d96df07e74 100644 --- a/tests/functional/syntax/exceptions/test_call_violation.py +++ b/tests/functional/syntax/exceptions/test_call_violation.py @@ -27,6 +27,15 @@ def goo(): def foo(): self.goo() """, + """ +@deploy +def __init__(): + pass + +@internal +def foo(): + self.__init__() + """, ] diff --git a/tests/functional/syntax/exceptions/test_constancy_exception.py b/tests/functional/syntax/exceptions/test_constancy_exception.py index 7adf9538c7..6bfb8fee57 100644 --- a/tests/functional/syntax/exceptions/test_constancy_exception.py +++ b/tests/functional/syntax/exceptions/test_constancy_exception.py @@ -78,7 +78,7 @@ def foo(): """ f:int128 -@external +@internal def a (x:int128): self.f = 100 @@ -86,6 +86,63 @@ def a (x:int128): @external def b(): self.a(10)""", + """ +interface A: + def bar() -> uint16: view +@external +@pure +def test(to:address): + a:A = A(to) + x:uint16 = a.bar() + """, + """ +interface A: + def bar() -> uint16: view +@external +@pure +def test(to:address): + a:A = A(to) + a.bar() + """, + """ +interface A: + def bar() -> uint16: nonpayable +@external +@view +def test(to:address): + a:A = A(to) + x:uint16 = a.bar() + """, + """ +interface A: + def bar() -> uint16: nonpayable +@external +@view +def test(to:address): + a:A = A(to) + a.bar() + """, + """ +a:DynArray[uint16,3] +@deploy +def __init__(): + self.a = [1,2,3] +@view +@external +def bar()->DynArray[uint16,3]: + x:uint16 = self.a.pop() + return self.a # return [1,2] + """, + """ +from ethereum.ercs import ERC20 + +token: ERC20 + +@external +@view +def topup(amount: uint256): + assert self.token.transferFrom(msg.sender, self, amount) + """, ], ) def test_statefulness_violations(bad_code): diff --git a/tests/functional/syntax/exceptions/test_function_declaration_exception.py b/tests/functional/syntax/exceptions/test_function_declaration_exception.py index 3fe23e0ec7..878c7f3e29 100644 --- a/tests/functional/syntax/exceptions/test_function_declaration_exception.py +++ b/tests/functional/syntax/exceptions/test_function_declaration_exception.py @@ -34,17 +34,17 @@ def test_func() -> int128: return (1, 2) """, """ -@external +@deploy def __init__(a: int128 = 12): pass """, """ -@external +@deploy def __init__() -> uint256: return 1 """, """ -@external +@deploy def __init__() -> bool: pass """, @@ -58,7 +58,7 @@ def __init__(): """ a: immutable(uint256) -@external +@deploy @pure def __init__(): a = 1 @@ -66,7 +66,7 @@ def __init__(): """ a: immutable(uint256) -@external +@deploy @view def __init__(): a = 1 diff --git a/tests/functional/syntax/exceptions/test_instantiation_exception.py b/tests/functional/syntax/exceptions/test_instantiation_exception.py index 0d641f154a..4dd0bf6e02 100644 --- a/tests/functional/syntax/exceptions/test_instantiation_exception.py +++ b/tests/functional/syntax/exceptions/test_instantiation_exception.py @@ -69,7 +69,7 @@ def foo(): """ b: immutable(HashMap[uint256, uint256]) -@external +@deploy def __init__(): b = empty(HashMap[uint256, uint256]) """, diff --git a/tests/functional/syntax/exceptions/test_invalid_reference.py b/tests/functional/syntax/exceptions/test_invalid_reference.py index fe315e5cbf..7519d1406e 100644 --- a/tests/functional/syntax/exceptions/test_invalid_reference.py +++ b/tests/functional/syntax/exceptions/test_invalid_reference.py @@ -47,7 +47,7 @@ def foo(): """ a: public(immutable(uint256)) -@external +@deploy def __init__(): a = 123 diff --git a/tests/functional/syntax/exceptions/test_structure_exception.py b/tests/functional/syntax/exceptions/test_structure_exception.py index c6d733fc90..afc7a35012 100644 --- a/tests/functional/syntax/exceptions/test_structure_exception.py +++ b/tests/functional/syntax/exceptions/test_structure_exception.py @@ -94,7 +94,7 @@ def foo(): a: immutable(uint256) n: public(HashMap[uint256, bool][a]) -@external +@deploy def __init__(): a = 3 """, @@ -105,14 +105,14 @@ def __init__(): m1: HashMap[uint8, uint8] m2: HashMap[uint8, uint8] -@external +@deploy def __init__(): self.m1 = self.m2 """, """ m1: HashMap[uint8, uint8] -@external +@deploy def __init__(): self.m1 = 234 """, diff --git a/tests/functional/syntax/exceptions/test_vyper_exception_pos.py b/tests/functional/syntax/exceptions/test_vyper_exception_pos.py index a261cb0a11..9e0767cb83 100644 --- a/tests/functional/syntax/exceptions/test_vyper_exception_pos.py +++ b/tests/functional/syntax/exceptions/test_vyper_exception_pos.py @@ -22,7 +22,7 @@ def test_multiple_exceptions(get_contract, assert_compile_failed): foo: immutable(uint256) bar: immutable(uint256) -@external +@deploy def __init__(): self.foo = 1 # SyntaxException self.bar = 2 # SyntaxException diff --git a/tests/functional/syntax/modules/test_deploy_visibility.py b/tests/functional/syntax/modules/test_deploy_visibility.py new file mode 100644 index 0000000000..f51bf9575b --- /dev/null +++ b/tests/functional/syntax/modules/test_deploy_visibility.py @@ -0,0 +1,27 @@ +import pytest + +from vyper.compiler import compile_code +from vyper.exceptions import CallViolation + + +def test_call_deploy_from_external(make_input_bundle): + lib1 = """ +@deploy +def __init__(): + pass + """ + + main = """ +import lib1 + +@external +def foo(): + lib1.__init__() + """ + + input_bundle = make_input_bundle({"lib1.vy": lib1}) + + with pytest.raises(CallViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value.message == "Cannot call an @deploy function from an @external function!" diff --git a/tests/functional/syntax/modules/test_implements.py b/tests/functional/syntax/modules/test_implements.py new file mode 100644 index 0000000000..c292e198d9 --- /dev/null +++ b/tests/functional/syntax/modules/test_implements.py @@ -0,0 +1,51 @@ +from vyper.compiler import compile_code + + +def test_implements_from_vyi(make_input_bundle): + vyi = """ +@external +def foo(): + ... + """ + lib1 = """ +import some_interface + """ + main = """ +import lib1 + +implements: lib1.some_interface + +@external +def foo(): # implementation + pass + """ + input_bundle = make_input_bundle({"some_interface.vyi": vyi, "lib1.vy": lib1}) + + assert compile_code(main, input_bundle=input_bundle) is not None + + +def test_implements_from_vyi2(make_input_bundle): + # test implements via nested imported vyi file + vyi = """ +@external +def foo(): + ... + """ + lib1 = """ +import some_interface + """ + lib2 = """ +import lib1 + """ + main = """ +import lib2 + +implements: lib2.lib1.some_interface + +@external +def foo(): # implementation + pass + """ + input_bundle = make_input_bundle({"some_interface.vyi": vyi, "lib1.vy": lib1, "lib2.vy": lib2}) + + assert compile_code(main, input_bundle=input_bundle) is not None diff --git a/tests/functional/syntax/modules/test_initializers.py b/tests/functional/syntax/modules/test_initializers.py new file mode 100644 index 0000000000..a12f5f57ea --- /dev/null +++ b/tests/functional/syntax/modules/test_initializers.py @@ -0,0 +1,1139 @@ +""" +tests for the uses/initializes checker +main properties to test: +- state usage -- if a module uses state, it must `used` or `initialized` +- conversely, if a module does not touch state, it should not be `used` +- global initializer check: each used module is `initialized` exactly once +""" + +import pytest + +from vyper.compiler import compile_code +from vyper.exceptions import ( + BorrowException, + ImmutableViolation, + InitializerException, + StructureException, + UndeclaredDefinition, +) + + +def test_initialize_uses(make_input_bundle): + lib1 = """ +counter: uint256 + +@deploy +def __init__(): + pass + """ + lib2 = """ +import lib1 + +uses: lib1 + +counter: uint256 + +@deploy +def __init__(): + pass + +@internal +def foo(): + lib1.counter += 1 + """ + main = """ +import lib2 +import lib1 + +initializes: lib2[lib1 := lib1] +initializes: lib1 + +@deploy +def __init__(): + lib1.__init__() + lib2.__init__() + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + assert compile_code(main, input_bundle=input_bundle) is not None + + +def test_initialize_multiple_uses(make_input_bundle): + lib1 = """ +counter: uint256 + +@deploy +def __init__(): + pass + """ + lib2 = """ +totalSupply: uint256 + """ + lib3 = """ +import lib1 +import lib2 + +# multiple uses on one line +uses: ( + lib1, + lib2 +) + +counter: uint256 + +@deploy +def __init__(): + pass + +@internal +def foo(): + x: uint256 = lib2.totalSupply + lib1.counter += 1 + """ + main = """ +import lib1 +import lib2 +import lib3 + +initializes: lib1 +initializes: lib2 +initializes: lib3[ + lib1 := lib1, + lib2 := lib2 +] + +@deploy +def __init__(): + lib1.__init__() + lib3.__init__() + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2, "lib3.vy": lib3}) + + assert compile_code(main, input_bundle=input_bundle) is not None + + +def test_initialize_multi_line_uses(make_input_bundle): + lib1 = """ +counter: uint256 + +@deploy +def __init__(): + pass + """ + lib2 = """ +totalSupply: uint256 + """ + lib3 = """ +import lib1 +import lib2 + +uses: lib1 +uses: lib2 + +counter: uint256 + +@deploy +def __init__(): + pass + +@internal +def foo(): + x: uint256 = lib2.totalSupply + lib1.counter += 1 + """ + main = """ +import lib1 +import lib2 +import lib3 + +initializes: lib1 +initializes: lib2 +initializes: lib3[ + lib1 := lib1, + lib2 := lib2 +] + +@deploy +def __init__(): + lib1.__init__() + lib3.__init__() + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2, "lib3.vy": lib3}) + + assert compile_code(main, input_bundle=input_bundle) is not None + + +def test_initialize_uses_attribute(make_input_bundle): + lib1 = """ +counter: uint256 + +@deploy +def __init__(): + pass + """ + lib2 = """ +import lib1 + +uses: lib1 + +counter: uint256 + +@deploy +def __init__(): + pass + +@internal +def foo(): + lib1.counter += 1 + """ + main = """ +import lib1 +import lib2 + +initializes: lib2[lib1 := lib1] +initializes: lib1 + +@deploy +def __init__(): + lib2.__init__() + # demonstrate we can call lib1.__init__ through lib2.lib1 + # (not sure this should be allowed, really. + lib2.lib1.__init__() + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + assert compile_code(main, input_bundle=input_bundle) is not None + + +def test_initializes_without_init_function(make_input_bundle): + lib1 = """ +counter: uint256 + """ + lib2 = """ +import lib1 + +uses: lib1 + +counter: uint256 + +@internal +def foo(): + lib1.counter += 1 + """ + main = """ +import lib1 +import lib2 + +initializes: lib2[lib1 := lib1] +initializes: lib1 + +@deploy +def __init__(): + pass + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + assert compile_code(main, input_bundle=input_bundle) is not None + + +def test_imported_as_different_names(make_input_bundle): + lib1 = """ +counter: uint256 + """ + lib2 = """ +import lib1 as m + +uses: m + +counter: uint256 + +@internal +def foo(): + m.counter += 1 + """ + main = """ +import lib1 as some_module +import lib2 + +initializes: lib2[m := some_module] +initializes: some_module + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + assert compile_code(main, input_bundle=input_bundle) is not None + + +def test_initializer_list_module_mismatch(make_input_bundle): + lib1 = """ +counter: uint256 + """ + lib2 = """ +something: uint256 + """ + lib3 = """ +import lib1 + +uses: lib1 + +@internal +def foo(): + lib1.counter += 1 + """ + main = """ +import lib1 +import lib2 +import lib3 + +initializes: lib1 +initializes: lib3[lib1 := lib2] # typo -- should be [lib1 := lib1] + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2, "lib3.vy": lib3}) + + with pytest.raises(StructureException) as e: + assert compile_code(main, input_bundle=input_bundle) is not None + + assert e.value._message == "lib1 is not lib2!" + + +def test_imported_as_different_names_error(make_input_bundle): + lib1 = """ +counter: uint256 + """ + lib2 = """ +import lib1 as m + +uses: m + +counter: uint256 + +@internal +def foo(): + m.counter += 1 + """ + main = """ +import lib1 +import lib2 + +initializes: lib2[lib1 := lib1] +initializes: lib1 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(UndeclaredDefinition) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "unknown module `lib1`" + assert e.value._hint == "did you mean `m := lib1`?" + + +def test_global_initializer_constraint(make_input_bundle): + lib1 = """ +counter: uint256 + """ + lib2 = """ +import lib1 + +uses: lib1 + +counter: uint256 + +@internal +def foo(): + lib1.counter += 1 + """ + main = """ +import lib1 +import lib2 + +initializes: lib2[lib1 := lib1] +# forgot to initialize lib1! + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(InitializerException) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "module `lib1.vy` is used but never initialized!" + assert e.value._hint == "add `initializes: lib1` to the top level of your main contract" + + +def test_initializer_no_references(make_input_bundle): + lib1 = """ +counter: uint256 + """ + lib2 = """ +import lib1 + +uses: lib1 + +counter: uint256 + +@internal +def foo(): + lib1.counter += 1 + """ + main = """ +import lib1 +import lib2 + +initializes: lib2 +initializes: lib1 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(InitializerException) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "`lib2` uses `lib1`, but it is not initialized with `lib1`" + assert e.value._hint == "add `lib1` to its initializer list" + + +def test_missing_uses(make_input_bundle): + lib1 = """ +counter: uint256 + """ + lib2 = """ +import lib1 + +# forgot `uses: lib1`! + +counter: uint256 + +@internal +def foo(): + lib1.counter += 1 + """ + main = """ +import lib1 +import lib2 + +initializes: lib2 +initializes: lib1 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "Cannot access `lib1` state!" + + expected_hint = "add `uses: lib1` or `initializes: lib1` as a " + expected_hint += "top-level statement to your contract" + assert e.value._hint == expected_hint + + +def test_missing_uses_for_read(make_input_bundle): + lib1 = """ +counter: uint256 + """ + lib2 = """ +import lib1 + +# forgot `uses: lib1`! + +counter: uint256 + +@internal +def foo() -> uint256: + return lib1.counter + """ + main = """ +import lib1 +import lib2 + +initializes: lib1 + +@deploy +def __init__(): + lib1.counter = 100 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "Cannot access `lib1` state!" + + expected_hint = "add `uses: lib1` or `initializes: lib1` as a " + expected_hint += "top-level statement to your contract" + assert e.value._hint == expected_hint + + +def test_missing_uses_for_read_immutable(make_input_bundle): + lib1 = """ +MY_IMMUTABLE: immutable(uint256) + +@deploy +def __init__(): + MY_IMMUTABLE = 7 + """ + lib2 = """ +import lib1 + +# forgot `uses: lib1`! + +counter: uint256 + +@internal +def foo() -> uint256: + return lib1.MY_IMMUTABLE + """ + main = """ +import lib1 +import lib2 + +initializes: lib1 + +@deploy +def __init__(): + lib1.counter = 100 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "Cannot access `lib1` state!" + + expected_hint = "add `uses: lib1` or `initializes: lib1` as a " + expected_hint += "top-level statement to your contract" + assert e.value._hint == expected_hint + + +def test_missing_uses_for_read_inside_call(make_input_bundle): + lib1 = """ +MY_IMMUTABLE: immutable(uint256) + +@deploy +def __init__(): + MY_IMMUTABLE = 9 + +@internal +def get_counter() -> uint256: + return MY_IMMUTABLE + """ + lib2 = """ +import lib1 + +# forgot `uses: lib1`! + +counter: uint256 + +@internal +def foo() -> uint256: + return lib1.get_counter() + """ + main = """ +import lib1 +import lib2 + +initializes: lib1 + +@deploy +def __init__(): + lib1.counter = 100 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "Cannot access `lib1` state!" + + expected_hint = "add `uses: lib1` or `initializes: lib1` as a " + expected_hint += "top-level statement to your contract" + assert e.value._hint == expected_hint + + +def test_missing_uses_for_hashmap(make_input_bundle): + lib1 = """ +counter: HashMap[uint256, HashMap[uint256, uint256]] + """ + lib2 = """ +import lib1 + +# forgot `uses: lib1`! + +@internal +def foo() -> uint256: + return lib1.counter[1][2] + """ + main = """ +import lib1 +import lib2 + +initializes: lib1 + +@deploy +def __init__(): + lib1.counter = 100 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "Cannot access `lib1` state!" + + expected_hint = "add `uses: lib1` or `initializes: lib1` as a " + expected_hint += "top-level statement to your contract" + assert e.value._hint == expected_hint + + +def test_missing_uses_for_tuple(make_input_bundle): + lib1 = """ +counter: HashMap[uint256, HashMap[uint256, uint256]] + """ + lib2 = """ +import lib1 + +interface Foo: + def foo() -> (uint256, uint256): nonpayable + +something: uint256 + +# forgot `uses: lib1`! + +@internal +def foo() -> uint256: + lib1.counter[1][2], self.something = Foo(msg.sender).foo() + """ + main = """ +import lib1 +import lib2 + +initializes: lib1 +initializes: lib2 + +@deploy +def __init__(): + lib1.counter = 100 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "Cannot access `lib1` state!" + + expected_hint = "add `uses: lib1` or `initializes: lib1` as a " + expected_hint += "top-level statement to your contract" + assert e.value._hint == expected_hint + + +def test_missing_uses_for_tuple_function_call(make_input_bundle): + lib1 = """ +counter: HashMap[uint256, HashMap[uint256, uint256]] + +something: uint256 + +interface Foo: + def foo() -> (uint256, uint256): nonpayable + +@internal +def write_tuple(): + self.counter[1][2], self.something = Foo(msg.sender).foo() + """ + lib2 = """ +import lib1 + +# forgot `uses: lib1`! +@internal +def foo(): + lib1.write_tuple() + """ + main = """ +import lib1 +import lib2 + +initializes: lib1 +initializes: lib2 + +@deploy +def __init__(): + lib1.counter = 100 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "Cannot access `lib1` state!" + + expected_hint = "add `uses: lib1` or `initializes: lib1` as a " + expected_hint += "top-level statement to your contract" + assert e.value._hint == expected_hint + + +def test_missing_uses_function_call(make_input_bundle): + # test missing uses through function call + lib1 = """ +counter: uint256 + +@internal +def update_counter(new_value: uint256): + self.counter = new_value + """ + lib2 = """ +import lib1 + +# forgot `uses: lib1`! + +counter: uint256 + +@internal +def foo(): + lib1.update_counter(lib1.counter + 1) + """ + main = """ +import lib1 +import lib2 + +initializes: lib2 +initializes: lib1 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "Cannot access `lib1` state!" + + expected_hint = "add `uses: lib1` or `initializes: lib1` as a " + expected_hint += "top-level statement to your contract" + assert e.value._hint == expected_hint + + +def test_missing_uses_nested_attribute(make_input_bundle): + # test missing uses through nested attribute access + lib1 = """ +counter: uint256 + """ + lib2 = """ +import lib1 + +counter: uint256 + +@internal +def foo(): + pass + """ + main = """ +import lib1 +import lib2 + +initializes: lib1 + +# did not `use` or `initialize` lib2! + +@external +def foo(new_value: uint256): + # cannot access lib1 state through lib2 + lib2.lib1.counter = new_value + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "Cannot access `lib2` state!" + + expected_hint = "add `uses: lib2` or `initializes: lib2` as a " + expected_hint += "top-level statement to your contract" + assert e.value._hint == expected_hint + + +def test_missing_uses_nested_attribute_function_call(make_input_bundle): + # test missing uses through nested attribute access + lib1 = """ +counter: uint256 + +@internal +def update_counter(new_value: uint256): + self.counter = new_value + """ + lib2 = """ +import lib1 + +counter: uint256 + +@internal +def foo(): + pass + """ + main = """ +import lib1 +import lib2 + +initializes: lib1 + +# did not `use` or `initialize` lib2! + +@external +def foo(new_value: uint256): + # cannot access lib1 state through lib2 + lib2.lib1.update_counter(new_value) + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "Cannot access `lib2` state!" + + expected_hint = "add `uses: lib2` or `initializes: lib2` as a " + expected_hint += "top-level statement to your contract" + assert e.value._hint == expected_hint + + +def test_uses_skip_import(make_input_bundle): + lib1 = """ +counter: uint256 + """ + lib2 = """ +import lib1 + +@internal +def foo(): + pass + """ + main = """ +import lib1 +import lib2 + +initializes: lib2 + +@external +def foo(new_value: uint256): + # can access lib1 state through lib2? + lib2.lib1.counter = new_value + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "Cannot access `lib1` state!" + + expected_hint = "add `uses: lib1` or `initializes: lib1` as a " + expected_hint += "top-level statement to your contract" + assert e.value._hint == expected_hint + + +def test_invalid_uses(make_input_bundle): + lib1 = """ +counter: uint256 + """ + lib2 = """ +import lib1 + +uses: lib1 # not necessary! + +counter: uint256 + +@internal +def foo(): + pass + """ + main = """ +import lib1 +import lib2 + +initializes: lib2[lib1 := lib1] +initializes: lib1 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(BorrowException) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "`lib1` is declared as used, but it is not actually used in lib2.vy!" + assert e.value._hint == "delete `uses: lib1`" + + +def test_invalid_uses2(make_input_bundle): + # test a more complicated invalid uses + lib1 = """ +counter: uint256 + +@internal +def foo(addr: address): + # sends value -- modifies ethereum state + to_send_value: uint256 = 100 + raw_call(addr, b"someFunction()", value=to_send_value) + """ + lib2 = """ +import lib1 + +uses: lib1 # not necessary! + +counter: uint256 + +@internal +def foo(): + lib1.foo(msg.sender) + """ + main = """ +import lib1 +import lib2 + +initializes: lib2[lib1 := lib1] +initializes: lib1 + +@external +def foo(): + lib2.foo() + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(BorrowException) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "`lib1` is declared as used, but it is not actually used in lib2.vy!" + assert e.value._hint == "delete `uses: lib1`" + + +def test_initializes_uses_conflict(make_input_bundle): + lib1 = """ +counter: uint256 + """ + main = """ +import lib1 + +initializes: lib1 +uses: lib1 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + + with pytest.raises(StructureException) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "ownership already set to `initializes`" + + +def test_uses_initializes_conflict(make_input_bundle): + lib1 = """ +counter: uint256 + """ + main = """ +import lib1 + +uses: lib1 +initializes: lib1 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + + with pytest.raises(StructureException) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "ownership already set to `uses`" + + +def test_uses_twice(make_input_bundle): + lib1 = """ +counter: uint256 + """ + main = """ +import lib1 + +uses: lib1 + +random_variable: constant(uint256) = 3 + +uses: lib1 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + + with pytest.raises(StructureException) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "ownership already set to `uses`" + + +def test_initializes_twice(make_input_bundle): + lib1 = """ +counter: uint256 + """ + main = """ +import lib1 + +initializes: lib1 + +random_variable: constant(uint256) = 3 + +initializes: lib1 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + + with pytest.raises(StructureException) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "ownership already set to `initializes`" + + +def test_no_initialize_unused_module(make_input_bundle): + lib1 = """ +counter: uint256 + +@internal +def set_counter(new_value: uint256): + self.counter = new_value + +@internal +@pure +def add(x: uint256, y: uint256) -> uint256: + return x + y + """ + main = """ +import lib1 + +# not needed: `initializes: lib1` + +@external +def do_add(x: uint256, y: uint256) -> uint256: + return lib1.add(x, y) + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + assert compile_code(main, input_bundle=input_bundle) is not None + + +def test_no_initialize_unused_module2(make_input_bundle): + # slightly more complicated + lib1 = """ +counter: uint256 + +@internal +def set_counter(new_value: uint256): + self.counter = new_value + +@internal +@pure +def add(x: uint256, y: uint256) -> uint256: + return x + y + """ + lib2 = """ +import lib1 + +@internal +@pure +def addmul(x: uint256, y: uint256, z: uint256) -> uint256: + return lib1.add(x, y) * z + """ + main = """ +import lib1 +import lib2 + +@external +def do_addmul(x: uint256, y: uint256) -> uint256: + return lib2.addmul(x, y, 5) + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + assert compile_code(main, input_bundle=input_bundle) is not None + + +def test_init_uninitialized_function(make_input_bundle): + lib1 = """ +counter: uint256 + +@deploy +def __init__(): + pass + """ + main = """ +import lib1 + +# missing `initializes: lib1`! + +@deploy +def __init__(): + lib1.__init__() + """ + + input_bundle = make_input_bundle({"lib1.vy": lib1}) + with pytest.raises(InitializerException) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "tried to initialize `lib1`, but it is not in initializer list!" + assert e.value._hint == "add `initializes: lib1` as a top-level statement to your contract" + + +def test_init_uninitialized_function2(make_input_bundle): + # test that we can't call module.__init__() even when we call `uses` + lib1 = """ +counter: uint256 + +@deploy +def __init__(): + pass + """ + main = """ +import lib1 + +uses: lib1 +# missing `initializes: lib1`! + +@deploy +def __init__(): + lib1.__init__() + """ + + input_bundle = make_input_bundle({"lib1.vy": lib1}) + with pytest.raises(InitializerException) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "tried to initialize `lib1`, but it is not in initializer list!" + assert e.value._hint == "add `initializes: lib1` as a top-level statement to your contract" + + +def test_noinit_initialized_function(make_input_bundle): + lib1 = """ +counter: uint256 + +@deploy +def __init__(): + self.counter = 5 + """ + main = """ +import lib1 + +initializes: lib1 + +@deploy +def __init__(): + pass # missing `lib1.__init__()`! + """ + + input_bundle = make_input_bundle({"lib1.vy": lib1}) + with pytest.raises(InitializerException) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "not initialized!" + assert e.value._hint == "add `lib1.__init__()` to your `__init__()` function" + + +def test_noinit_initialized_function2(make_input_bundle): + lib1 = """ +counter: uint256 + +@deploy +def __init__(): + self.counter = 5 + """ + main = """ +import lib1 + +initializes: lib1 + +# missing `lib1.__init__()`! + """ + + input_bundle = make_input_bundle({"lib1.vy": lib1}) + with pytest.raises(InitializerException) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "not initialized!" + assert e.value._hint == "add `lib1.__init__()` to your `__init__()` function" + + +def test_ownership_decl_errors_not_swallowed(make_input_bundle): + lib1 = """ +counter: uint256 + """ + main = """ +import lib1 +# forgot to import lib2 + +uses: (lib1, lib2) # should get UndeclaredDefinition + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + with pytest.raises(UndeclaredDefinition) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "'lib2' has not been declared. " diff --git a/tests/functional/syntax/test_address_code.py b/tests/functional/syntax/test_address_code.py index fa6ed20117..5873eb5af8 100644 --- a/tests/functional/syntax/test_address_code.py +++ b/tests/functional/syntax/test_address_code.py @@ -165,7 +165,7 @@ def test_address_code_self_success(get_contract, optimize): code = """ code_deployment: public(Bytes[32]) -@external +@deploy def __init__(): self.code_deployment = slice(self.code, 0, 32) @@ -186,7 +186,7 @@ def test_address_code_self_runtime_error_deployment(get_contract): code = """ dummy: public(Bytes[1000000]) -@external +@deploy def __init__(): self.dummy = slice(self.code, 0, 1000000) """ diff --git a/tests/functional/syntax/test_codehash.py b/tests/functional/syntax/test_codehash.py index c2d9a2e274..8aada22da7 100644 --- a/tests/functional/syntax/test_codehash.py +++ b/tests/functional/syntax/test_codehash.py @@ -11,7 +11,7 @@ def test_get_extcodehash(get_contract, evm_version, optimize): code = """ a: address -@external +@deploy def __init__(): self.a = self diff --git a/tests/functional/syntax/test_constants.py b/tests/functional/syntax/test_constants.py index 57922f28e2..63abf24485 100644 --- a/tests/functional/syntax/test_constants.py +++ b/tests/functional/syntax/test_constants.py @@ -94,7 +94,7 @@ VAL: immutable(uint256) VAL: uint256 -@external +@deploy def __init__(): VAL = 1 """, @@ -106,7 +106,7 @@ def __init__(): VAL: uint256 VAL: immutable(uint256) -@external +@deploy def __init__(): VAL = 1 """, diff --git a/tests/functional/syntax/test_immutables.py b/tests/functional/syntax/test_immutables.py index 1027d9fe66..59fb1a69d9 100644 --- a/tests/functional/syntax/test_immutables.py +++ b/tests/functional/syntax/test_immutables.py @@ -8,7 +8,7 @@ """ VALUE: immutable(uint256) -@external +@deploy def __init__(): pass """, @@ -25,7 +25,7 @@ def get_value() -> uint256: """ VALUE: immutable(uint256) = 3 -@external +@deploy def __init__(): pass """, @@ -33,7 +33,7 @@ def __init__(): """ VALUE: immutable(uint256) -@external +@deploy def __init__(): VALUE = 0 @@ -45,7 +45,7 @@ def set_value(_value: uint256): """ VALUE: immutable(uint256) -@external +@deploy def __init__(_value: uint256): VALUE = _value * 3 VALUE = VALUE + 1 @@ -54,7 +54,7 @@ def __init__(_value: uint256): """ VALUE: immutable(public(uint256)) -@external +@deploy def __init__(_value: uint256): VALUE = _value * 3 """, @@ -85,7 +85,7 @@ def test_compilation_simple_usage(typ): code = f""" VALUE: immutable({typ}) -@external +@deploy def __init__(_value: {typ}): VALUE = _value @@ -103,7 +103,7 @@ def get_value() -> {typ}: """ VALUE: immutable(uint256) -@external +@deploy def __init__(_value: uint256): VALUE = _value * 3 x: uint256 = VALUE + 1 @@ -121,7 +121,7 @@ def test_compilation_success(good_code): """ imm: immutable(uint256) -@external +@deploy def __init__(x: uint256): self.imm = x """, @@ -131,7 +131,7 @@ def __init__(x: uint256): """ imm: immutable(uint256) -@external +@deploy def __init__(x: uint256): x = imm @@ -145,7 +145,7 @@ def report(): """ imm: immutable(uint256) -@external +@deploy def __init__(x: uint256): imm = x @@ -163,7 +163,7 @@ def report(): x: immutable(Foo) -@external +@deploy def __init__(): x = Foo({a:1}) diff --git a/tests/functional/syntax/test_init.py b/tests/functional/syntax/test_init.py new file mode 100644 index 0000000000..389b5ad681 --- /dev/null +++ b/tests/functional/syntax/test_init.py @@ -0,0 +1,64 @@ +import pytest + +from vyper.compiler import compile_code +from vyper.exceptions import FunctionDeclarationException + +good_list = [ + """ +@deploy +def __init__(): + pass + """, + """ +@deploy +@payable +def __init__(): + pass + """, + """ +counter: uint256 +SOME_IMMUTABLE: immutable(uint256) + +@deploy +def __init__(): + SOME_IMMUTABLE = 5 + self.counter = 1 + """, +] + + +@pytest.mark.parametrize("code", good_list) +def test_good_init_funcs(code): + assert compile_code(code) is not None + + +fail_list = [ + """ +@internal +def __init__(): + pass + """, + """ +@deploy +@view +def __init__(): + pass + """, + """ +@deploy +@pure +def __init__(): + pass + """, + """ +@deploy +def some_function(): # for now, only __init__() functions can be marked @deploy + pass + """, +] + + +@pytest.mark.parametrize("code", fail_list) +def test_bad_init_funcs(code): + with pytest.raises(FunctionDeclarationException): + compile_code(code) diff --git a/tests/functional/syntax/test_interfaces.py b/tests/functional/syntax/test_interfaces.py index 584e497534..a07ec4e3dc 100644 --- a/tests/functional/syntax/test_interfaces.py +++ b/tests/functional/syntax/test_interfaces.py @@ -304,7 +304,7 @@ def some_func(): nonpayable my_interface: MyInterface[3] idx: uint256 -@external +@deploy def __init__(): self.my_interface[self.idx] = MyInterface(empty(address)) """, @@ -348,7 +348,7 @@ def foo() -> uint256: view foo: public(immutable(uint256)) -@external +@deploy def __init__(x: uint256): foo = x """, diff --git a/tests/functional/syntax/test_public.py b/tests/functional/syntax/test_public.py index 71bff753f4..217fcea998 100644 --- a/tests/functional/syntax/test_public.py +++ b/tests/functional/syntax/test_public.py @@ -10,7 +10,7 @@ x: public(constant(int128)) = 0 y: public(immutable(int128)) -@external +@deploy def __init__(): y = 0 """, diff --git a/tests/functional/syntax/test_tuple_assign.py b/tests/functional/syntax/test_tuple_assign.py index 49b63ee614..bb23804e30 100644 --- a/tests/functional/syntax/test_tuple_assign.py +++ b/tests/functional/syntax/test_tuple_assign.py @@ -92,7 +92,7 @@ def test(a: bytes32) -> (bytes32, uint256, int128): """ B: immutable(uint256) -@external +@deploy def __init__(b: uint256): B = b diff --git a/tests/unit/ast/test_ast_dict.py b/tests/unit/ast/test_ast_dict.py index 20390f3d5e..9fec61cb90 100644 --- a/tests/unit/ast/test_ast_dict.py +++ b/tests/unit/ast/test_ast_dict.py @@ -109,16 +109,6 @@ def foo() -> uint256: "node_id": 9, "src": "48:15:0", "ast_type": "ImplementsDecl", - "target": { - "col_offset": 0, - "end_col_offset": 10, - "node_id": 10, - "src": "48:10:0", - "ast_type": "Name", - "end_lineno": 5, - "lineno": 5, - "id": "implements", - }, "end_lineno": 5, "lineno": 5, } diff --git a/tests/unit/cli/storage_layout/test_storage_layout.py b/tests/unit/cli/storage_layout/test_storage_layout.py index 1aa8901881..f0ee25f747 100644 --- a/tests/unit/cli/storage_layout/test_storage_layout.py +++ b/tests/unit/cli/storage_layout/test_storage_layout.py @@ -56,7 +56,7 @@ def test_storage_and_immutables_layout(): SYMBOL: immutable(String[32]) DECIMALS: immutable(uint8) -@external +@deploy def __init__(): SYMBOL = "VYPR" DECIMALS = 18 @@ -72,3 +72,251 @@ def __init__(): out = compile_code(code, output_formats=["layout"]) assert out["layout"] == expected_layout + + +def test_storage_layout_module(make_input_bundle): + lib1 = """ +supply: uint256 +SYMBOL: immutable(String[32]) +DECIMALS: immutable(uint8) + +@deploy +def __init__(): + SYMBOL = "VYPR" + DECIMALS = 18 + """ + code = """ +import lib1 as a_library + +counter: uint256 +some_immutable: immutable(DynArray[uint256, 10]) + +counter2: uint256 + +initializes: a_library + +@deploy +def __init__(): + some_immutable = [1, 2, 3] + a_library.__init__() + """ + + input_bundle = make_input_bundle({"lib1.vy": lib1}) + + expected_layout = { + "code_layout": { + "some_immutable": {"length": 352, "offset": 0, "type": "DynArray[uint256, 10]"}, + "a_library": { + "DECIMALS": {"length": 32, "offset": 416, "type": "uint8"}, + "SYMBOL": {"length": 64, "offset": 352, "type": "String[32]"}, + }, + }, + "storage_layout": { + "counter": {"slot": 0, "type": "uint256"}, + "counter2": {"slot": 1, "type": "uint256"}, + "a_library": {"supply": {"slot": 2, "type": "uint256"}}, + }, + } + + out = compile_code(code, input_bundle=input_bundle, output_formats=["layout"]) + assert out["layout"] == expected_layout + + +def test_storage_layout_module2(make_input_bundle): + # test module storage layout, but initializes is in a different order + lib1 = """ +supply: uint256 +SYMBOL: immutable(String[32]) +DECIMALS: immutable(uint8) + +@deploy +def __init__(): + SYMBOL = "VYPR" + DECIMALS = 18 + """ + code = """ +import lib1 as a_library + +counter: uint256 +some_immutable: immutable(DynArray[uint256, 10]) + +initializes: a_library + +counter2: uint256 + +@deploy +def __init__(): + a_library.__init__() + some_immutable = [1, 2, 3] + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + + expected_layout = { + "code_layout": { + "some_immutable": {"length": 352, "offset": 0, "type": "DynArray[uint256, 10]"}, + "a_library": { + "SYMBOL": {"length": 64, "offset": 352, "type": "String[32]"}, + "DECIMALS": {"length": 32, "offset": 416, "type": "uint8"}, + }, + }, + "storage_layout": { + "counter": {"slot": 0, "type": "uint256"}, + "a_library": {"supply": {"slot": 1, "type": "uint256"}}, + "counter2": {"slot": 2, "type": "uint256"}, + }, + } + + out = compile_code(code, input_bundle=input_bundle, output_formats=["layout"]) + assert out["layout"] == expected_layout + + +def test_storage_layout_module_uses(make_input_bundle): + # test module storage layout, with initializes/uses + lib1 = """ +supply: uint256 +SYMBOL: immutable(String[32]) +DECIMALS: immutable(uint8) + +@deploy +def __init__(): + SYMBOL = "VYPR" + DECIMALS = 18 + """ + lib2 = """ +import lib1 + +uses: lib1 + +storage_variable: uint256 +immutable_variable: immutable(uint256) + +@deploy +def __init__(s: uint256): + immutable_variable = s + +@internal +def decimals() -> uint8: + return lib1.DECIMALS + """ + code = """ +import lib1 as a_library +import lib2 + +counter: uint256 +some_immutable: immutable(DynArray[uint256, 10]) + +# for fun: initialize lib2 in front of lib1 +initializes: lib2[lib1 := a_library] + +counter2: uint256 + +initializes: a_library + +@deploy +def __init__(): + a_library.__init__() + some_immutable = [1, 2, 3] + + lib2.__init__(17) + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + expected_layout = { + "code_layout": { + "some_immutable": {"length": 352, "offset": 0, "type": "DynArray[uint256, 10]"}, + "lib2": {"immutable_variable": {"length": 32, "offset": 352, "type": "uint256"}}, + "a_library": { + "SYMBOL": {"length": 64, "offset": 384, "type": "String[32]"}, + "DECIMALS": {"length": 32, "offset": 448, "type": "uint8"}, + }, + }, + "storage_layout": { + "counter": {"slot": 0, "type": "uint256"}, + "lib2": {"storage_variable": {"slot": 1, "type": "uint256"}}, + "counter2": {"slot": 2, "type": "uint256"}, + "a_library": {"supply": {"slot": 3, "type": "uint256"}}, + }, + } + + out = compile_code(code, input_bundle=input_bundle, output_formats=["layout"]) + assert out["layout"] == expected_layout + + +def test_storage_layout_module_nested_initializes(make_input_bundle): + # test module storage layout, with initializes in an imported module + lib1 = """ +supply: uint256 +SYMBOL: immutable(String[32]) +DECIMALS: immutable(uint8) + +@deploy +def __init__(): + SYMBOL = "VYPR" + DECIMALS = 18 + """ + lib2 = """ +import lib1 + +initializes: lib1 + +storage_variable: uint256 +immutable_variable: immutable(uint256) + +@deploy +def __init__(s: uint256): + immutable_variable = s + lib1.__init__() + +@internal +def decimals() -> uint8: + return lib1.DECIMALS + """ + code = """ +import lib1 as a_library +import lib2 + +counter: uint256 +some_immutable: immutable(DynArray[uint256, 10]) + +# for fun: initialize lib2 in front of lib1 +initializes: lib2 + +counter2: uint256 + +uses: a_library + +@deploy +def __init__(): + some_immutable = [1, 2, 3] + + lib2.__init__(17) + +@external +def foo() -> uint256: + return a_library.supply + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + expected_layout = { + "code_layout": { + "some_immutable": {"length": 352, "offset": 0, "type": "DynArray[uint256, 10]"}, + "lib2": { + "lib1": { + "SYMBOL": {"length": 64, "offset": 352, "type": "String[32]"}, + "DECIMALS": {"length": 32, "offset": 416, "type": "uint8"}, + }, + "immutable_variable": {"length": 32, "offset": 448, "type": "uint256"}, + }, + }, + "storage_layout": { + "counter": {"slot": 0, "type": "uint256"}, + "lib2": { + "lib1": {"supply": {"slot": 1, "type": "uint256"}}, + "storage_variable": {"slot": 2, "type": "uint256"}, + }, + "counter2": {"slot": 3, "type": "uint256"}, + }, + } + + out = compile_code(code, input_bundle=input_bundle, output_formats=["layout"]) + assert out["layout"] == expected_layout diff --git a/tests/unit/compiler/asm/test_asm_optimizer.py b/tests/unit/compiler/asm/test_asm_optimizer.py index b2851e908a..ce32249202 100644 --- a/tests/unit/compiler/asm/test_asm_optimizer.py +++ b/tests/unit/compiler/asm/test_asm_optimizer.py @@ -20,7 +20,7 @@ def runtime_only(): def bar(): self.runtime_only() -@external +@deploy def __init__(): self.ctor_only() """, @@ -44,7 +44,7 @@ def ctor_only(): def bar(): self.foo() -@external +@deploy def __init__(): self.ctor_only() """, @@ -65,7 +65,7 @@ def runtime_only(): def bar(): self.runtime_only() -@external +@deploy def __init__(): self.ctor_only() """, @@ -73,6 +73,9 @@ def __init__(): # check dead code eliminator works on unreachable functions +# CMC 2024-02-05 this is not really the asm eliminator anymore, +# it happens during function code generation in module.py. so we don't +# need to test this using asm anymore. @pytest.mark.parametrize("code", codes) def test_dead_code_eliminator(code): c = CompilerData(code, settings=Settings(optimize=OptimizationLevel.NONE)) @@ -88,20 +91,9 @@ def test_dead_code_eliminator(code): assert any(ctor_only in instr for instr in initcode_asm) assert all(runtime_only not in instr for instr in initcode_asm) - # all labels should be in unoptimized runtime asm - for s in (ctor_only, runtime_only): - assert any(s in instr for instr in runtime_asm) - - c = CompilerData(code, settings=Settings(optimize=OptimizationLevel.GAS)) - initcode_asm = [i for i in c.assembly if isinstance(i, str)] - runtime_asm = [i for i in c.assembly_runtime if isinstance(i, str)] - - # ctor only label should not be in runtime code + assert any(runtime_only in instr for instr in runtime_asm) assert all(ctor_only not in instr for instr in runtime_asm) - # runtime only label should not be in initcode asm - assert all(runtime_only not in instr for instr in initcode_asm) - def test_library_code_eliminator(make_input_bundle): library = """ diff --git a/tests/unit/compiler/test_bytecode_runtime.py b/tests/unit/compiler/test_bytecode_runtime.py index 613ee4d2b8..64cee3a75c 100644 --- a/tests/unit/compiler/test_bytecode_runtime.py +++ b/tests/unit/compiler/test_bytecode_runtime.py @@ -35,7 +35,7 @@ def foo5(): has_immutables = """ A_GOOD_PRIME: public(immutable(uint256)) -@external +@deploy def __init__(): A_GOOD_PRIME = 967 """ diff --git a/tests/unit/semantics/test_storage_slots.py b/tests/unit/semantics/test_storage_slots.py index ea2b2fe559..3620ef64b9 100644 --- a/tests/unit/semantics/test_storage_slots.py +++ b/tests/unit/semantics/test_storage_slots.py @@ -25,7 +25,7 @@ h: public(int256[1]) -@external +@deploy def __init__(): self.a = StructOne({a: "ok", b: [4,5,6]}) self.b = [7, 8] @@ -110,6 +110,6 @@ def test_allocator_overflow(get_contract): """ with pytest.raises( StorageLayoutException, - match=f"Invalid storage slot for var y, tried to allocate slots 1 through {2**256}", + match=f"Invalid storage slot, tried to allocate slots 1 through {2**256}", ): get_contract(code) diff --git a/vyper/ast/__init__.py b/vyper/ast/__init__.py index bc08626b59..0ae93e9710 100644 --- a/vyper/ast/__init__.py +++ b/vyper/ast/__init__.py @@ -5,7 +5,7 @@ from . import nodes, validation from .natspec import parse_natspec -from .nodes import compare_nodes +from .nodes import compare_nodes, as_tuple from .utils import ast_to_dict from .parse import parse_to_ast, parse_to_ast_with_settings @@ -15,6 +15,5 @@ ): setattr(sys.modules[__name__], name, obj) - # required to avoid circular dependency from . import expansion # noqa: E402 diff --git a/vyper/ast/grammar.lark b/vyper/ast/grammar.lark index 84429501e1..5ad465a1f1 100644 --- a/vyper/ast/grammar.lark +++ b/vyper/ast/grammar.lark @@ -182,13 +182,9 @@ loop_variable: NAME ":" type loop_iterator: _expr for_stmt: "for" loop_variable "in" loop_iterator ":" body -// ternary operator -ternary: _expr "if" _expr "else" _expr - // Expressions _expr: operation | dict - | ternary get_item: (variable_access | list) "[" _expr "]" get_attr: variable_access "." NAME @@ -214,7 +210,15 @@ dict: "{" "}" | "{" (NAME ":" _expr) ("," (NAME ":" _expr))* [","] "}" // See https://docs.python.org/3/reference/expressions.html#operator-precedence // NOTE: The recursive cycle here helps enforce operator precedence // Precedence goes up the lower down you go -?operation: bool_or +?operation: assignment_expr + +// "walrus" operator +?assignment_expr: ternary + | NAME ":=" assignment_expr + +// ternary operator +?ternary: bool_or + | ternary "if" ternary "else" ternary _AND: "and" _OR: "or" diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index 054145d33b..c4bce814a4 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -83,8 +83,20 @@ def get_node( if ast_struct["value"] is not None: _raise_syntax_exc("`implements` cannot have a value assigned", ast_struct) ast_struct["ast_type"] = "ImplementsDecl" + + # Replace "uses:" `AnnAssign` nodes with `UsesDecl` + elif getattr(ast_struct["target"], "id", None) == "uses": + if ast_struct["value"] is not None: + _raise_syntax_exc("`uses` cannot have a value assigned", ast_struct) + ast_struct["ast_type"] = "UsesDecl" + + # Replace "initializes:" `AnnAssign` nodes with `InitializesDecl` + elif getattr(ast_struct["target"], "id", None) == "initializes": + if ast_struct["value"] is not None: + _raise_syntax_exc("`initializes` cannot have a value assigned", ast_struct) + ast_struct["ast_type"] = "InitializesDecl" + # Replace state and local variable declarations `AnnAssign` with `VariableDecl` - # Parent node is required for context to determine whether replacement should happen. else: ast_struct["ast_type"] = "VariableDecl" @@ -730,6 +742,20 @@ def is_terminus(self): return self.value.is_terminus +class NamedExpr(Stmt): + __slots__ = ("target", "value") + + def validate(self): + # module[dep1 := dep2] + + # XXX: better error messages + if not isinstance(self.target, Name): + raise StructureException("not a Name") + + if not isinstance(self.value, Name): + raise StructureException("not a Name") + + class Log(Stmt): __slots__ = ("value",) @@ -756,6 +782,11 @@ class StructDef(TopLevel): class ExprNode(VyperNode): __slots__ = ("_expr_info",) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self._expr_info = None + class Constant(ExprNode): # inherited class for all simple constant node types @@ -1383,17 +1414,13 @@ class ImplementsDecl(Stmt): """ An `implements` declaration. - Excludes `simple` and `value` attributes from Python `AnnAssign` node. - Attributes ---------- - target : Name - Name node for the `implements` keyword annotation : Name Name node for the interface to be implemented """ - __slots__ = ("target", "annotation") + __slots__ = ("annotation",) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -1402,6 +1429,72 @@ def __init__(self, *args, **kwargs): raise StructureException("invalid implements", self.annotation) +def as_tuple(node: VyperNode): + """ + Convenience function for some AST nodes which allow either a Tuple + or single elements. Returns a python tuple of AST nodes. + """ + if isinstance(node, Tuple): + return node.elements + else: + return (node,) + + +class UsesDecl(Stmt): + """ + A `uses` declaration. + + Attributes + ---------- + annotation : Name | Attribute | Tuple + The module(s) which this uses + """ + + __slots__ = ("annotation",) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + items = as_tuple(self.annotation) + for item in items: + if not isinstance(item, (Name, Attribute)): + raise StructureException("invalid uses", item) + + +class InitializesDecl(Stmt): + """ + An `initializes` declaration. + + Attributes + ---------- + annotation : Name | Attribute | Subscript + An imported module which this module initializes + """ + + __slots__ = ("annotation",) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + module_ref = self.annotation + if isinstance(module_ref, Subscript): + dependencies = as_tuple(module_ref.slice) + module_ref = module_ref.value + + for item in dependencies: + if not isinstance(item, NamedExpr): + raise StructureException( + "invalid dependency (hint: should be [dependency := dependency]", item + ) + if not isinstance(item.target, (Name, Attribute)): + raise StructureException("invalid module", item.target) + if not isinstance(item.value, (Name, Attribute)): + raise StructureException("invalid module", item.target) + + if not isinstance(module_ref, (Name, Attribute)): + raise StructureException("invalid module", module_ref) + + class If(Stmt): __slots__ = ("test", "body", "orelse") diff --git a/vyper/ast/nodes.pyi b/vyper/ast/nodes.pyi index f71ed67821..7f863a8db9 100644 --- a/vyper/ast/nodes.pyi +++ b/vyper/ast/nodes.pyi @@ -101,7 +101,8 @@ class StructDef(VyperNode): body: list = ... name: str = ... -class ExprNode(VyperNode): ... +class ExprNode(VyperNode): + _expr_info: Any = ... class Constant(VyperNode): value: Any = ... @@ -145,19 +146,19 @@ class Name(VyperNode): _type: str = ... class Expr(VyperNode): - value: VyperNode = ... + value: ExprNode = ... class UnaryOp(ExprNode): op: VyperNode = ... - operand: VyperNode = ... + operand: ExprNode = ... class USub(VyperNode): ... class Not(VyperNode): ... class BinOp(ExprNode): - left: VyperNode = ... op: VyperNode = ... - right: VyperNode = ... + left: ExprNode = ... + right: ExprNode = ... class Add(VyperNode): ... class Sub(VyperNode): ... @@ -173,15 +174,15 @@ class BitXor(VyperNode): ... class BoolOp(ExprNode): op: VyperNode = ... - values: list[VyperNode] = ... + values: list[ExprNode] = ... class And(VyperNode): ... class Or(VyperNode): ... class Compare(ExprNode): op: VyperNode = ... - left: VyperNode = ... - right: VyperNode = ... + left: ExprNode = ... + right: ExprNode = ... class Eq(VyperNode): ... class NotEq(VyperNode): ... @@ -195,13 +196,13 @@ class NotIn(VyperNode): ... class Call(ExprNode): args: list = ... keywords: list = ... - func: VyperNode = ... + func: ExprNode = ... class keyword(VyperNode): ... class Attribute(VyperNode): attr: str = ... - value: VyperNode = ... + value: ExprNode = ... class Subscript(VyperNode): slice: VyperNode = ... @@ -224,8 +225,8 @@ class VariableDecl(VyperNode): class AugAssign(VyperNode): op: VyperNode = ... - target: VyperNode = ... - value: VyperNode = ... + target: ExprNode = ... + value: ExprNode = ... class Raise(VyperNode): ... class Assert(VyperNode): ... @@ -245,6 +246,12 @@ class ImplementsDecl(VyperNode): target: Name = ... annotation: Name = ... +class UsesDecl(VyperNode): + annotation: VyperNode = ... + +class InitializesDecl(VyperNode): + annotation: VyperNode = ... + class If(VyperNode): body: list = ... orelse: list = ... @@ -254,6 +261,10 @@ class IfExp(ExprNode): body: ExprNode = ... orelse: ExprNode = ... +class NamedExpr(ExprNode): + target: Name = ... + value: ExprNode = ... + class For(VyperNode): target: ExprNode iter: ExprNode diff --git a/vyper/ast/parse.py b/vyper/ast/parse.py index fc99af901b..a10a840da0 100644 --- a/vyper/ast/parse.py +++ b/vyper/ast/parse.py @@ -278,8 +278,8 @@ def visit_For(self, node): # specific error message than "invalid type annotation" raise SyntaxException( "missing type annotation\n\n" - "(hint: did you mean something like " - f"`for {node.target.id}: uint256 in ...`?)\n", + " (hint: did you mean something like " + f"`for {node.target.id}: uint256 in ...`?)", self._source_code, node.lineno, node.col_offset, diff --git a/vyper/builtins/_signatures.py b/vyper/builtins/_signatures.py index d2aefb2fd4..6e6cf4c662 100644 --- a/vyper/builtins/_signatures.py +++ b/vyper/builtins/_signatures.py @@ -85,13 +85,16 @@ class BuiltinFunctionT(VyperType): _kwargs: dict[str, KwargSettings] = {} _modifiability: Modifiability = Modifiability.MODIFIABLE _return_type: Optional[VyperType] = None + _equality_attrs = ("_id",) _is_terminus = False - # helper function to deal with TYPE_DEFINITIONs + @property + def modifiability(self): + return self._modifiability + + # helper function to deal with TYPE_Ts def _validate_single(self, arg: vy_ast.VyperNode, expected_type: VyperType) -> None: - # TODO using "TYPE_DEFINITION" is a kludge in derived classes, - # refactor me. - if expected_type == "TYPE_DEFINITION": + if TYPE_T.any().compare_type(expected_type): # try to parse the type - call type_from_annotation # for its side effects (will throw if is not a type) type_from_annotation(arg) @@ -130,7 +133,7 @@ def _validate_arg_types(self, node: vy_ast.Call) -> None: get_exact_type_from_node(arg) def check_modifiability_for_call(self, node: vy_ast.Call, modifiability: Modifiability) -> bool: - return self._modifiability >= modifiability + return self._modifiability <= modifiability def fetch_call_return(self, node: vy_ast.Call) -> Optional[VyperType]: self._validate_arg_types(node) diff --git a/vyper/builtins/_utils.py b/vyper/builtins/_utils.py index 72b05f15e3..3fad225b48 100644 --- a/vyper/builtins/_utils.py +++ b/vyper/builtins/_utils.py @@ -1,7 +1,7 @@ from vyper.ast import parse_to_ast from vyper.codegen.context import Context from vyper.codegen.stmt import parse_body -from vyper.semantics.analysis.local import FunctionNodeVisitor +from vyper.semantics.analysis.local import FunctionAnalyzer from vyper.semantics.namespace import Namespace, override_global_namespace from vyper.semantics.types.function import ContractFunctionT, FunctionVisibility, StateMutability from vyper.semantics.types.module import ModuleT @@ -25,9 +25,7 @@ def generate_inline_function(code, variables, variables_2, memory_allocator): ast_code.body[0]._metadata["func_type"] = ContractFunctionT( "sqrt_builtin", [], [], None, FunctionVisibility.INTERNAL, StateMutability.NONPAYABLE ) - # The FunctionNodeVisitor's constructor performs semantic checks - # annotate the AST as side effects. - analyzer = FunctionNodeVisitor(ast_code, ast_code.body[0], namespace) + analyzer = FunctionAnalyzer(ast_code, ast_code.body[0], namespace) analyzer.analyze() new_context = Context( diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index 50ab4dacd8..7575f4d77e 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -113,10 +113,7 @@ class TypenameFoldedFunctionT(FoldedFunctionT): # Base class for builtin functions that: # (1) take a typename as the only argument; and # (2) should always be folded. - - # "TYPE_DEFINITION" is a placeholder value for a type definition string, and - # will be replaced by a `TypeTypeDefinition` object in `infer_arg_types`. - _inputs = [("typename", "TYPE_DEFINITION")] + _inputs = [("typename", TYPE_T.any())] def fetch_call_return(self, node): type_ = self.infer_arg_types(node)[0].typedef @@ -711,7 +708,7 @@ def build_IR(self, expr, args, kwargs, context): class MethodID(FoldedFunctionT): _id = "method_id" _inputs = [("value", StringT.any())] - _kwargs = {"output_type": KwargSettings("TYPE_DEFINITION", BytesT(4))} + _kwargs = {"output_type": KwargSettings(TYPE_T.any(), BytesT(4))} def _try_fold(self, node): validate_call_args(node, 1, ["output_type"]) @@ -848,10 +845,7 @@ def _storage_element_getter(index): class Extract32(BuiltinFunctionT): _id = "extract32" _inputs = [("b", BytesT.any()), ("start", IntegerT.unsigneds())] - # "TYPE_DEFINITION" is a placeholder value for a type definition string, and - # will be replaced by a `TYPE_T` object in `infer_kwarg_types` - # (note that it is ignored in _validate_arg_types) - _kwargs = {"output_type": KwargSettings("TYPE_DEFINITION", BYTES32_T)} + _kwargs = {"output_type": KwargSettings(TYPE_T.any(), BYTES32_T)} def fetch_call_return(self, node): self._validate_arg_types(node) @@ -1976,18 +1970,22 @@ def build_IR(self, expr, args, kwargs, context): class UnsafeAdd(_UnsafeMath): + _id = "unsafe_add" op = "add" class UnsafeSub(_UnsafeMath): + _id = "unsafe_sub" op = "sub" class UnsafeMul(_UnsafeMath): + _id = "unsafe_mul" op = "mul" class UnsafeDiv(_UnsafeMath): + _id = "unsafe_div" op = "div" @@ -2474,7 +2472,7 @@ def build_IR(self, expr, args, kwargs, context): class ABIDecode(BuiltinFunctionT): _id = "_abi_decode" - _inputs = [("data", BytesT.any()), ("output_type", "TYPE_DEFINITION")] + _inputs = [("data", BytesT.any()), ("output_type", TYPE_T.any())] _kwargs = {"unwrap_tuple": KwargSettings(BoolT(), True, require_literal=True)} def fetch_call_return(self, node): diff --git a/vyper/codegen/context.py b/vyper/codegen/context.py index 4f644841f4..af01c5b504 100644 --- a/vyper/codegen/context.py +++ b/vyper/codegen/context.py @@ -44,7 +44,7 @@ def __repr__(self): return f"VariableRecord({ret})" -# Contains arguments, variables, etc +# compilation context for a function class Context: def __init__( self, @@ -59,19 +59,12 @@ def __init__( # In-memory variables, in the form (name, memory location, type) self.vars = vars_ or {} - # Global variables, in the form (name, storage location, type) - self.globals = module_ctx.variables - # Variables defined in for loops, e.g. for i in range(6): ... self.forvars = forvars or {} # Is the function constant? self.constancy = constancy - # Whether body is currently in an assert statement - # XXX: dead, never set to True - self.in_assertion = False - # Whether we are currently parsing a range expression self.in_range_expr = False @@ -87,6 +80,10 @@ def __init__( # Not intended to be accessed directly self.memory_allocator = memory_allocator + # save the starting memory location so we can find out (later) + # how much memory this function uses. + self.starting_memory = memory_allocator.next_mem + # Incremented values, used for internal IDs self._internal_var_iter = 0 self._scope_id_iter = 0 @@ -95,7 +92,7 @@ def __init__( self.is_ctor_context = is_ctor_context def is_constant(self): - return self.constancy is Constancy.Constant or self.in_assertion or self.in_range_expr + return self.constancy is Constancy.Constant or self.in_range_expr def check_is_not_constant(self, err, expr): if self.is_constant(): @@ -250,9 +247,7 @@ def lookup_var(self, varname): # Pretty print constancy for error messages def pp_constancy(self): - if self.in_assertion: - return "an assertion" - elif self.in_range_expr: + if self.in_range_expr: return "a range expression" elif self.constancy == Constancy.Constant: return "a constant function" diff --git a/vyper/codegen/core.py b/vyper/codegen/core.py index c3215f8c16..1a090ac316 100644 --- a/vyper/codegen/core.py +++ b/vyper/codegen/core.py @@ -3,9 +3,18 @@ from vyper.codegen.ir_node import Encoding, IRnode from vyper.compiler.settings import OptimizationLevel -from vyper.evm.address_space import CALLDATA, DATA, IMMUTABLES, MEMORY, STORAGE, TRANSIENT +from vyper.evm.address_space import ( + CALLDATA, + DATA, + IMMUTABLES, + MEMORY, + STORAGE, + TRANSIENT, + AddrSpace, +) from vyper.evm.opcodes import version_check from vyper.exceptions import CompilerPanic, TypeCheckFailure, TypeMismatch +from vyper.semantics.data_locations import DataLocation from vyper.semantics.types import ( AddressT, BoolT, @@ -100,6 +109,36 @@ def _codecopy_gas_bound(num_bytes): return GAS_COPY_WORD * ceil32(num_bytes) // 32 +def data_location_to_address_space(s: DataLocation, is_ctor_ctx: bool) -> AddrSpace: + if s == DataLocation.MEMORY: + return MEMORY + if s == DataLocation.STORAGE: + return STORAGE + if s == DataLocation.TRANSIENT: + return TRANSIENT + if s == DataLocation.CODE: + if is_ctor_ctx: + return IMMUTABLES + return DATA + + raise CompilerPanic("unreachable!") # pragma: nocover + + +def address_space_to_data_location(s: AddrSpace) -> DataLocation: + if s == MEMORY: + return DataLocation.MEMORY + if s == STORAGE: + return DataLocation.STORAGE + if s == TRANSIENT: + return DataLocation.TRANSIENT + if s in (IMMUTABLES, DATA): + return DataLocation.CODE + if s == CALLDATA: + return DataLocation.CALLDATA + + raise CompilerPanic("unreachable!") # pragma: nocover + + # Copy byte array word-for-word (including layout) # TODO make this a private function def make_byte_array_copier(dst, src): @@ -482,14 +521,10 @@ def _get_element_ptr_tuplelike(parent, key): return _getelemptr_abi_helper(parent, member_t, ofst) - if parent.location.word_addressable: - for i in range(index): - ofst += typ.member_types[attrs[i]].storage_size_in_words - elif parent.location.byte_addressable: - for i in range(index): - ofst += typ.member_types[attrs[i]].memory_bytes_required - else: - raise CompilerPanic(f"bad location {parent.location}") # pragma: notest + data_location = address_space_to_data_location(parent.location) + for i in range(index): + t = typ.member_types[attrs[i]] + ofst += t.get_size_in(data_location) return IRnode.from_list( add_ofst(parent, ofst), @@ -550,12 +585,8 @@ def _get_element_ptr_array(parent, key, array_bounds_check): return _getelemptr_abi_helper(parent, subtype, ofst) - if parent.location.word_addressable: - element_size = subtype.storage_size_in_words - elif parent.location.byte_addressable: - element_size = subtype.memory_bytes_required - else: - raise CompilerPanic("unreachable") # pragma: notest + data_location = address_space_to_data_location(parent.location) + element_size = subtype.get_size_in(data_location) ofst = _mul(ix, element_size) diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index f4c7948382..335cfefb87 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -6,6 +6,7 @@ from vyper.codegen import external_call, self_call from vyper.codegen.core import ( clamp, + data_location_to_address_space, ensure_in_memory, get_dyn_array_count, get_element_ptr, @@ -23,7 +24,7 @@ ) from vyper.codegen.ir_node import IRnode from vyper.codegen.keccak256_helper import keccak256_helper -from vyper.evm.address_space import DATA, IMMUTABLES, MEMORY, STORAGE, TRANSIENT +from vyper.evm.address_space import MEMORY from vyper.evm.opcodes import version_check from vyper.exceptions import ( CodegenPanic, @@ -185,26 +186,24 @@ def parse_Name(self): ret._referenced_variables = {var} return ret - # TODO: use self.expr._expr_info - elif self.expr.id in self.context.globals: - varinfo = self.context.globals[self.expr.id] - + elif (varinfo := self.expr._expr_info.var_info) is not None: if varinfo.is_constant: return Expr.parse_value_expr(varinfo.decl_node.value, self.context) assert varinfo.is_immutable, "not an immutable!" - ofst = varinfo.position.offset + mutable = self.context.is_ctor_context - if self.context.is_ctor_context: - mutable = True - location = IMMUTABLES - else: - mutable = False - location = DATA + location = data_location_to_address_space( + varinfo.location, self.context.is_ctor_context + ) ret = IRnode.from_list( - ofst, typ=varinfo.typ, location=location, annotation=self.expr.id, mutable=mutable + varinfo.position.position, + typ=varinfo.typ, + location=location, + annotation=self.expr.id, + mutable=mutable, ) ret._referenced_variables = {varinfo} return ret @@ -265,9 +264,13 @@ def parse_Attribute(self): return IRnode.from_list(["~selfcode"], typ=BytesT(0)) return IRnode.from_list(["~extcode", addr], typ=BytesT(0)) # self.x: global attribute - elif isinstance(self.expr.value, vy_ast.Name) and self.expr.value.id == "self": - varinfo = self.context.globals[self.expr.attr] - location = TRANSIENT if varinfo.is_transient else STORAGE + elif (varinfo := self.expr._expr_info.var_info) is not None: + if varinfo.is_constant: + return Expr.parse_value_expr(varinfo.decl_node.value, self.context) + + location = data_location_to_address_space( + varinfo.location, self.context.is_ctor_context + ) ret = IRnode.from_list( varinfo.position.position, @@ -700,7 +703,7 @@ def parse_Call(self): return pop_dyn_array(darray, return_popped_item=True) if isinstance(func_type, ContractFunctionT): - if func_type.is_internal: + if func_type.is_internal or func_type.is_constructor: return self_call.ir_for_self_call(self.expr, self.context) else: return external_call.ir_for_external_call(self.expr, self.context) diff --git a/vyper/codegen/function_definitions/__init__.py b/vyper/codegen/function_definitions/__init__.py index 94617bef35..254b4df72c 100644 --- a/vyper/codegen/function_definitions/__init__.py +++ b/vyper/codegen/function_definitions/__init__.py @@ -1 +1,4 @@ -from .common import FuncIR, generate_ir_for_function # noqa +from .external_function import generate_ir_for_external_function +from .internal_function import generate_ir_for_internal_function + +__all__ = [generate_ir_for_internal_function, generate_ir_for_external_function] # type: ignore diff --git a/vyper/codegen/function_definitions/common.py b/vyper/codegen/function_definitions/common.py index 5877ff3d13..d017ba7b81 100644 --- a/vyper/codegen/function_definitions/common.py +++ b/vyper/codegen/function_definitions/common.py @@ -2,17 +2,14 @@ from functools import cached_property from typing import Optional -import vyper.ast as vy_ast from vyper.codegen.context import Constancy, Context -from vyper.codegen.function_definitions.external_function import generate_ir_for_external_function -from vyper.codegen.function_definitions.internal_function import generate_ir_for_internal_function from vyper.codegen.ir_node import IRnode from vyper.codegen.memory_allocator import MemoryAllocator -from vyper.exceptions import CompilerPanic +from vyper.evm.opcodes import version_check from vyper.semantics.types import VyperType -from vyper.semantics.types.function import ContractFunctionT +from vyper.semantics.types.function import ContractFunctionT, StateMutability from vyper.semantics.types.module import ModuleT -from vyper.utils import MemoryPositions, calc_mem_gas +from vyper.utils import MemoryPositions @dataclass @@ -53,9 +50,11 @@ def ir_identifier(self) -> str: return f"{self.visibility} {function_id} {name}({argz})" def set_frame_info(self, frame_info: FrameInfo) -> None: + # XXX: when can this happen? if self.frame_info is not None: - raise CompilerPanic(f"frame_info already set for {self.func_t}!") - self.frame_info = frame_info + assert frame_info == self.frame_info + else: + self.frame_info = frame_info @property # common entry point for external function with kwargs @@ -64,13 +63,15 @@ def external_function_base_entry_label(self) -> str: return self.ir_identifier + "_common" def internal_function_label(self, is_ctor_context: bool = False) -> str: - assert self.func_t.is_internal, "uh oh, should be internal" - suffix = "_deploy" if is_ctor_context else "_runtime" - return self.ir_identifier + suffix + f = self.func_t + assert f.is_internal or f.is_constructor, "uh oh, should be internal" + if f.is_constructor: + # sanity check - imported init functions only callable from main init + assert is_ctor_context -class FuncIR: - pass + suffix = "_deploy" if is_ctor_context else "_runtime" + return self.ir_identifier + suffix @dataclass @@ -80,7 +81,7 @@ class EntryPointInfo: ir_node: IRnode # the ir for this entry point def __post_init__(self): - # ABI v2 property guaranteed by the spec. + # sanity check ABI v2 properties guaranteed by the spec. # https://docs.soliditylang.org/en/v0.8.21/abi-spec.html#formal-specification-of-the-encoding states: # noqa: E501 # > Note that for any X, len(enc(X)) is a multiple of 32. assert self.min_calldatasize >= 4 @@ -88,34 +89,28 @@ def __post_init__(self): @dataclass -class ExternalFuncIR(FuncIR): +class ExternalFuncIR: entry_points: dict[str, EntryPointInfo] # map from abi sigs to entry points common_ir: IRnode # the "common" code for the function @dataclass -class InternalFuncIR(FuncIR): +class InternalFuncIR: func_ir: IRnode # the code for the function -# TODO: should split this into external and internal ir generation? -def generate_ir_for_function( - code: vy_ast.FunctionDef, module_ctx: ModuleT, is_ctor_context: bool = False -) -> FuncIR: - """ - Parse a function and produce IR code for the function, includes: - - Signature method if statement - - Argument handling - - Clamping and copying of arguments - - Function body - """ - func_t = code._metadata["func_type"] - - # generate _FuncIRInfo +def init_ir_info(func_t: ContractFunctionT): + # initialize IRInfo on the function func_t._ir_info = _FuncIRInfo(func_t) - callees = func_t.called_functions +def initialize_context( + func_t: ContractFunctionT, module_ctx: ModuleT, is_ctor_context: bool = False +): + init_ir_info(func_t) + + # calculate starting frame + callees = func_t.called_functions # we start our function frame from the largest callee frame max_callee_frame_size = 0 for c_func_t in callees: @@ -126,7 +121,7 @@ def generate_ir_for_function( memory_allocator = MemoryAllocator(allocate_start) - context = Context( + return Context( vars_=None, module_ctx=module_ctx, memory_allocator=memory_allocator, @@ -135,38 +130,41 @@ def generate_ir_for_function( is_ctor_context=is_ctor_context, ) - if func_t.is_internal: - ret: FuncIR = InternalFuncIR(generate_ir_for_internal_function(code, func_t, context)) - func_t._ir_info.gas_estimate = ret.func_ir.gas # type: ignore - else: - kwarg_handlers, common = generate_ir_for_external_function(code, func_t, context) - entry_points = { - k: EntryPointInfo(func_t, mincalldatasize, ir_node) - for k, (mincalldatasize, ir_node) in kwarg_handlers.items() - } - ret = ExternalFuncIR(entry_points, common) - # note: this ignores the cost of traversing selector table - func_t._ir_info.gas_estimate = ret.common_ir.gas +def tag_frame_info(func_t, context): frame_size = context.memory_allocator.size_of_mem - MemoryPositions.RESERVED_MEMORY + frame_start = context.starting_memory - frame_info = FrameInfo(allocate_start, frame_size, context.vars) + frame_info = FrameInfo(frame_start, frame_size, context.vars) + func_t._ir_info.set_frame_info(frame_info) - # XXX: when can this happen? - if func_t._ir_info.frame_info is None: - func_t._ir_info.set_frame_info(frame_info) - else: - assert frame_info == func_t._ir_info.frame_info - - if not func_t.is_internal: - # adjust gas estimate to include cost of mem expansion - # frame_size of external function includes all private functions called - # (note: internal functions do not need to adjust gas estimate since - mem_expansion_cost = calc_mem_gas(func_t._ir_info.frame_info.mem_used) # type: ignore - ret.common_ir.add_gas_estimate += mem_expansion_cost # type: ignore - ret.common_ir.passthrough_metadata["func_t"] = func_t # type: ignore - ret.common_ir.passthrough_metadata["frame_info"] = frame_info # type: ignore + return frame_info + + +def get_nonreentrant_lock(func_t): + if not func_t.nonreentrant: + return ["pass"], ["pass"] + + nkey = func_t.reentrancy_key_position.position + + LOAD, STORE = "sload", "sstore" + if version_check(begin="cancun"): + LOAD, STORE = "tload", "tstore" + + if version_check(begin="berlin"): + # any nonzero values would work here (see pricing as of net gas + # metering); these values are chosen so that downgrading to the + # 0,1 scheme (if it is somehow necessary) is safe. + final_value, temp_value = 3, 2 else: - ret.func_ir.passthrough_metadata["frame_info"] = frame_info # type: ignore + final_value, temp_value = 0, 1 + + check_notset = ["assert", ["ne", temp_value, [LOAD, nkey]]] - return ret + if func_t.mutability == StateMutability.VIEW: + return [check_notset], [["seq"]] + + else: + pre = ["seq", check_notset, [STORE, nkey, temp_value]] + post = [STORE, nkey, final_value] + return [pre], [post] diff --git a/vyper/codegen/function_definitions/external_function.py b/vyper/codegen/function_definitions/external_function.py index 65276469e7..b380eab2ce 100644 --- a/vyper/codegen/function_definitions/external_function.py +++ b/vyper/codegen/function_definitions/external_function.py @@ -2,12 +2,19 @@ from vyper.codegen.context import Context, VariableRecord from vyper.codegen.core import get_element_ptr, getpos, make_setter, needs_clamp from vyper.codegen.expr import Expr -from vyper.codegen.function_definitions.utils import get_nonreentrant_lock +from vyper.codegen.function_definitions.common import ( + EntryPointInfo, + ExternalFuncIR, + get_nonreentrant_lock, + initialize_context, + tag_frame_info, +) from vyper.codegen.ir_node import Encoding, IRnode from vyper.codegen.stmt import parse_body from vyper.evm.address_space import CALLDATA, DATA, MEMORY from vyper.semantics.types import TupleT from vyper.semantics.types.function import ContractFunctionT +from vyper.utils import calc_mem_gas # register function args with the local calling context. @@ -51,7 +58,7 @@ def _register_function_args(func_t: ContractFunctionT, context: Context) -> list def _generate_kwarg_handlers( func_t: ContractFunctionT, context: Context -) -> dict[str, tuple[int, IRnode]]: +) -> dict[str, EntryPointInfo]: # generate kwarg handlers. # since they might come in thru calldata or be default, # allocate them in memory and then fill it in based on calldata or default, @@ -126,34 +133,54 @@ def handler_for(calldata_kwargs, default_kwargs): default_kwargs = keyword_args[i:] sig, calldata_min_size, ir_node = handler_for(calldata_kwargs, default_kwargs) - ret[sig] = calldata_min_size, ir_node + assert sig not in ret + ret[sig] = EntryPointInfo(func_t, calldata_min_size, ir_node) sig, calldata_min_size, ir_node = handler_for(keyword_args, []) - ret[sig] = calldata_min_size, ir_node + assert sig not in ret + ret[sig] = EntryPointInfo(func_t, calldata_min_size, ir_node) return ret -def generate_ir_for_external_function(code, func_t, context): +def _adjust_gas_estimate(func_t, common_ir): + # adjust gas estimate to include cost of mem expansion + # frame_size of external function includes all private functions called + # (note: internal functions do not need to adjust gas estimate since + frame_info = func_t._ir_info.frame_info + + mem_expansion_cost = calc_mem_gas(frame_info.mem_used) + common_ir.add_gas_estimate += mem_expansion_cost + func_t._ir_info.gas_estimate = common_ir.gas + + # pass metadata through for venom pipeline: + common_ir.passthrough_metadata["func_t"] = func_t + common_ir.passthrough_metadata["frame_info"] = frame_info + + +def generate_ir_for_external_function(code, compilation_target): # TODO type hints: # def generate_ir_for_external_function( # code: vy_ast.FunctionDef, - # func_t: ContractFunctionT, - # context: Context, + # compilation_target: ModuleT, # ) -> IRnode: """ Return the IR for an external function. Returns IR for the body of the function, handle kwargs and exit the function. Also returns metadata required for `module.py` to construct the selector table. """ + func_t = code._metadata["func_type"] + assert func_t.is_external or func_t.is_constructor # sanity check + + context = initialize_context(func_t, compilation_target, func_t.is_constructor) nonreentrant_pre, nonreentrant_post = get_nonreentrant_lock(func_t) # generate handlers for base args and register the variable records handle_base_args = _register_function_args(func_t, context) # generate handlers for kwargs and register the variable records - kwarg_handlers = _generate_kwarg_handlers(func_t, context) + entry_points = _generate_kwarg_handlers(func_t, context) body = ["seq"] # once optional args have been handled, @@ -185,4 +212,8 @@ def generate_ir_for_external_function(code, func_t, context): # besides any kwarg handling func_common_ir = IRnode.from_list(["seq", body, exit_], source_pos=getpos(code)) - return kwarg_handlers, func_common_ir + tag_frame_info(func_t, context) + + _adjust_gas_estimate(func_t, func_common_ir) + + return ExternalFuncIR(entry_points, func_common_ir) diff --git a/vyper/codegen/function_definitions/internal_function.py b/vyper/codegen/function_definitions/internal_function.py index cf01dbdab4..0cf9850b70 100644 --- a/vyper/codegen/function_definitions/internal_function.py +++ b/vyper/codegen/function_definitions/internal_function.py @@ -1,23 +1,25 @@ from vyper import ast as vy_ast -from vyper.codegen.context import Context -from vyper.codegen.function_definitions.utils import get_nonreentrant_lock +from vyper.codegen.function_definitions.common import ( + InternalFuncIR, + get_nonreentrant_lock, + initialize_context, + tag_frame_info, +) from vyper.codegen.ir_node import IRnode from vyper.codegen.stmt import parse_body -from vyper.semantics.types.function import ContractFunctionT def generate_ir_for_internal_function( - code: vy_ast.FunctionDef, func_t: ContractFunctionT, context: Context -) -> IRnode: + code: vy_ast.FunctionDef, module_ctx, is_ctor_context: bool +) -> InternalFuncIR: """ Parse a internal function (FuncDef), and produce full function body. :param func_t: the ContractFunctionT :param code: ast of function - :param context: current calling context + :param compilation_target: current calling context :return: function body in IR """ - # The calling convention is: # Caller fills in argument buffer # Caller provides return address, return buffer on the stack @@ -37,13 +39,19 @@ def generate_ir_for_internal_function( # situation like the following is easy to bork: # x: T[2] = [self.generate_T(), self.generate_T()] - # Get nonreentrant lock + func_t = code._metadata["func_type"] + + # sanity check + assert func_t.is_internal or func_t.is_constructor + + context = initialize_context(func_t, module_ctx, is_ctor_context) for arg in func_t.arguments: # allocate a variable for every arg, setting mutability # to True to allow internal function arguments to be mutable context.new_variable(arg.name, arg.typ, is_mutable=True) + # Get nonreentrant lock nonreentrant_pre, nonreentrant_post = get_nonreentrant_lock(func_t) function_entry_label = func_t._ir_info.internal_function_label(context.is_ctor_context) @@ -69,5 +77,13 @@ def generate_ir_for_internal_function( ] ir_node = IRnode.from_list(["seq", body, cleanup_routine]) + + # tag gas estimate and frame info + func_t._ir_info.gas_estimate = ir_node.gas + frame_info = tag_frame_info(func_t, context) + + # pass metadata through for venom pipeline: + ir_node.passthrough_metadata["frame_info"] = frame_info ir_node.passthrough_metadata["func_t"] = func_t - return ir_node + + return InternalFuncIR(ir_node) diff --git a/vyper/codegen/function_definitions/utils.py b/vyper/codegen/function_definitions/utils.py deleted file mode 100644 index f524ec6e88..0000000000 --- a/vyper/codegen/function_definitions/utils.py +++ /dev/null @@ -1,31 +0,0 @@ -from vyper.evm.opcodes import version_check -from vyper.semantics.types.function import StateMutability - - -def get_nonreentrant_lock(func_type): - if not func_type.nonreentrant: - return ["pass"], ["pass"] - - nkey = func_type.reentrancy_key_position.position - - LOAD, STORE = "sload", "sstore" - if version_check(begin="cancun"): - LOAD, STORE = "tload", "tstore" - - if version_check(begin="berlin"): - # any nonzero values would work here (see pricing as of net gas - # metering); these values are chosen so that downgrading to the - # 0,1 scheme (if it is somehow necessary) is safe. - final_value, temp_value = 3, 2 - else: - final_value, temp_value = 0, 1 - - check_notset = ["assert", ["ne", temp_value, [LOAD, nkey]]] - - if func_type.mutability == StateMutability.VIEW: - return [check_notset], [["seq"]] - - else: - pre = ["seq", check_notset, [STORE, nkey, temp_value]] - post = [STORE, nkey, final_value] - return [pre], [post] diff --git a/vyper/codegen/module.py b/vyper/codegen/module.py index 98395a6a0c..fef4f23949 100644 --- a/vyper/codegen/module.py +++ b/vyper/codegen/module.py @@ -4,7 +4,10 @@ from vyper.codegen import core, jumptable_utils from vyper.codegen.core import shr -from vyper.codegen.function_definitions import generate_ir_for_function +from vyper.codegen.function_definitions import ( + generate_ir_for_external_function, + generate_ir_for_internal_function, +) from vyper.codegen.ir_node import IRnode from vyper.compiler.settings import _is_debug_mode from vyper.exceptions import CompilerPanic @@ -89,7 +92,7 @@ def _ir_for_fallback_or_ctor(func_ast, *args, **kwargs): callvalue_check = ["assert", ["iszero", "callvalue"]] ret.append(IRnode.from_list(callvalue_check, error_msg="nonpayable check")) - func_ir = generate_ir_for_function(func_ast, *args, **kwargs) + func_ir = generate_ir_for_external_function(func_ast, *args, **kwargs) assert len(func_ir.entry_points) == 1 # add a goto to make the function entry look like other functions @@ -101,7 +104,7 @@ def _ir_for_fallback_or_ctor(func_ast, *args, **kwargs): def _ir_for_internal_function(func_ast, *args, **kwargs): - return generate_ir_for_function(func_ast, *args, **kwargs).func_ir + return generate_ir_for_internal_function(func_ast, *args, **kwargs).func_ir def _generate_external_entry_points(external_functions, module_ctx): @@ -109,7 +112,7 @@ def _generate_external_entry_points(external_functions, module_ctx): sig_of = {} # reverse map from method ids to abi sig for code in external_functions: - func_ir = generate_ir_for_function(code, module_ctx) + func_ir = generate_ir_for_external_function(code, module_ctx) for abi_sig, entry_point in func_ir.entry_points.items(): method_id = method_id_int(abi_sig) assert abi_sig not in entry_points @@ -424,12 +427,13 @@ def _selector_section_linear(external_functions, module_ctx): # take a ModuleT, and generate the runtime and deploy IR def generate_ir_for_module(module_ctx: ModuleT) -> tuple[IRnode, IRnode]: + # XXX: rename `module_ctx` to `compilation_target` # order functions so that each function comes after all of its callees function_defs = _topsort(module_ctx.function_defs) reachable = _globally_reachable_functions(module_ctx.function_defs) runtime_functions = [f for f in function_defs if not _is_constructor(f)] - init_function = next((f for f in function_defs if _is_constructor(f)), None) + init_function = next((f for f in module_ctx.function_defs if _is_constructor(f)), None) internal_functions = [f for f in runtime_functions if _is_internal(f)] @@ -475,24 +479,21 @@ def generate_ir_for_module(module_ctx: ModuleT) -> tuple[IRnode, IRnode]: deploy_code: List[Any] = ["seq"] immutables_len = module_ctx.immutable_section_bytes - if init_function: + if init_function is not None: # cleanly rerun codegen for internal functions with `is_ctor_ctx=True` init_func_t = init_function._metadata["func_type"] ctor_internal_func_irs = [] - internal_functions = [f for f in runtime_functions if _is_internal(f)] - for f in internal_functions: - func_t = f._metadata["func_type"] - if func_t not in init_func_t.reachable_internal_functions: - # unreachable code, delete it - continue - - func_ir = _ir_for_internal_function(f, module_ctx, is_ctor_context=True) + + reachable_from_ctor = init_func_t.reachable_internal_functions + for func_t in reachable_from_ctor: + fn_ast = func_t.ast_def + func_ir = _ir_for_internal_function(fn_ast, module_ctx, is_ctor_context=True) ctor_internal_func_irs.append(func_ir) # generate init_func_ir after callees to ensure they have analyzed # memory usage. # TODO might be cleaner to separate this into an _init_ir helper func - init_func_ir = _ir_for_fallback_or_ctor(init_function, module_ctx, is_ctor_context=True) + init_func_ir = _ir_for_fallback_or_ctor(init_function, module_ctx) # pass the amount of memory allocated for the init function # so that deployment does not clobber while preparing immutables diff --git a/vyper/codegen/stmt.py b/vyper/codegen/stmt.py index 7d4938f287..e6baea75f7 100644 --- a/vyper/codegen/stmt.py +++ b/vyper/codegen/stmt.py @@ -144,7 +144,7 @@ def parse_Call(self): return pop_dyn_array(darray, return_popped_item=False) if isinstance(func_type, ContractFunctionT): - if func_type.is_internal: + if func_type.is_internal or func_type.is_constructor: return self_call.ir_for_self_call(self.stmt, self.context) else: return external_call.ir_for_external_call(self.stmt, self.context) diff --git a/vyper/compiler/phases.py b/vyper/compiler/phases.py index 5b7decec7b..f7eccdf214 100644 --- a/vyper/compiler/phases.py +++ b/vyper/compiler/phases.py @@ -152,23 +152,18 @@ def vyper_module(self): return self._generate_ast @cached_property - def _annotated_module(self): - return generate_annotated_ast( - self.vyper_module, self.input_bundle, self.storage_layout_override - ) - - @property def annotated_vyper_module(self) -> vy_ast.Module: - module, storage_layout = self._annotated_module - return module + return generate_annotated_ast(self.vyper_module, self.input_bundle) - @property + @cached_property def storage_layout(self) -> StorageLayout: - module, storage_layout = self._annotated_module - return storage_layout + module_ast = self.annotated_vyper_module + return set_data_positions(module_ast, self.storage_layout_override) @property def global_ctx(self) -> ModuleT: + # ensure storage layout is computed + _ = self.storage_layout return self.annotated_vyper_module._metadata["type"] @cached_property @@ -243,11 +238,7 @@ def blueprint_bytecode(self) -> bytes: return deploy_bytecode + blueprint_bytecode -def generate_annotated_ast( - vyper_module: vy_ast.Module, - input_bundle: InputBundle, - storage_layout_overrides: StorageLayout = None, -) -> tuple[vy_ast.Module, StorageLayout]: +def generate_annotated_ast(vyper_module: vy_ast.Module, input_bundle: InputBundle) -> vy_ast.Module: """ Validates and annotates the Vyper AST. @@ -268,9 +259,7 @@ def generate_annotated_ast( # note: validate_semantics does type inference on the AST validate_semantics(vyper_module, input_bundle) - symbol_tables = set_data_positions(vyper_module, storage_layout_overrides) - - return vyper_module, symbol_tables + return vyper_module def generate_ir_nodes(global_ctx: ModuleT, optimize: OptimizationLevel) -> tuple[IRnode, IRnode]: diff --git a/vyper/evm/address_space.py b/vyper/evm/address_space.py index 85a75c3c23..fcbd4bcf63 100644 --- a/vyper/evm/address_space.py +++ b/vyper/evm/address_space.py @@ -28,14 +28,6 @@ class AddrSpace: # TODO maybe make positional instead of defaulting to None store_op: Optional[str] = None - @property - def word_addressable(self) -> bool: - return self.word_scale == 1 - - @property - def byte_addressable(self) -> bool: - return self.word_scale == 32 - # alternative: # class Memory(AddrSpace): diff --git a/vyper/exceptions.py b/vyper/exceptions.py index 04667aaa59..53ad6f7bb8 100644 --- a/vyper/exceptions.py +++ b/vyper/exceptions.py @@ -31,7 +31,7 @@ class _BaseVyperException(Exception): order to display source annotations in the error string. """ - def __init__(self, message="Error Message not found.", *items): + def __init__(self, message="Error Message not found.", *items, hint=None): """ Exception initializer. @@ -47,7 +47,9 @@ def __init__(self, message="Error Message not found.", *items): A single tuple of (lineno, col_offset) is also understood to support the old API, but new exceptions should not use this approach. """ - self.message = message + self._message = message + self._hint = hint + self.lineno = None self.col_offset = None self.annotations = None @@ -77,6 +79,13 @@ def with_annotation(self, *annotations): exc.annotations = annotations return exc + @property + def message(self): + msg = self._message + if self._hint: + msg += f"\n\n (hint: {self._hint})" + return msg + def __str__(self): from vyper import ast as vy_ast from vyper.utils import annotate_source_code @@ -131,7 +140,7 @@ def __str__(self): annotation_list.append(node_msg) annotation_msg = "\n".join(annotation_list) - return f"{self.message}\n{annotation_msg}" + return f"{self.message}\n\n{annotation_msg}" class VyperException(_BaseVyperException): @@ -252,6 +261,14 @@ class ImmutableViolation(VyperException): """Modifying an immutable variable, constant, or definition.""" +class InitializerException(VyperException): + """An issue with initializing/constructing a module""" + + +class BorrowException(VyperException): + """An issue with borrowing/using a module""" + + class StateAccessViolation(VyperException): """Violating the mutability of a function definition.""" @@ -369,7 +386,7 @@ def tag_exceptions(node, fallback_exception_type=CompilerPanic, note=None): except _BaseVyperException as e: if not e.annotations and not e.lineno: tb = e.__traceback__ - raise e.with_annotation(node).with_traceback(tb) + raise e.with_annotation(node).with_traceback(tb) from None raise e from None except Exception as e: tb = e.__traceback__ diff --git a/vyper/semantics/analysis/__init__.py b/vyper/semantics/analysis/__init__.py index 7b52a68e92..e23b2d2aa4 100644 --- a/vyper/semantics/analysis/__init__.py +++ b/vyper/semantics/analysis/__init__.py @@ -1,4 +1,4 @@ from .. import types # break a dependency cycle. -from .module import validate_semantics +from .global_ import validate_semantics __all__ = ["validate_semantics"] diff --git a/vyper/semantics/analysis/base.py b/vyper/semantics/analysis/base.py index bb6d9ad9f7..2086e5f9da 100644 --- a/vyper/semantics/analysis/base.py +++ b/vyper/semantics/analysis/base.py @@ -1,84 +1,29 @@ import enum -from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List, Optional, Union +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Dict, Optional, Union from vyper import ast as vy_ast from vyper.compiler.input_bundle import InputBundle -from vyper.exceptions import ( - CompilerPanic, - ImmutableViolation, - StateAccessViolation, - VyperInternalException, -) +from vyper.exceptions import CompilerPanic, StructureException from vyper.semantics.data_locations import DataLocation from vyper.semantics.types.base import VyperType +from vyper.utils import OrderedSet, StringEnum if TYPE_CHECKING: from vyper.semantics.types.module import InterfaceT, ModuleT -class _StringEnum(enum.Enum): - @staticmethod - def auto(): - return enum.auto() +class FunctionVisibility(StringEnum): + EXTERNAL = enum.auto() + INTERNAL = enum.auto() + DEPLOY = enum.auto() - # Must be first, or else won't work, specifies what .value is - def _generate_next_value_(name, start, count, last_values): - return name.lower() - # Override ValueError with our own internal exception - @classmethod - def _missing_(cls, value): - raise VyperInternalException(f"{value} is not a valid {cls.__name__}") - - @classmethod - def is_valid_value(cls, value: str) -> bool: - return value in set(o.value for o in cls) - - @classmethod - def options(cls) -> List["_StringEnum"]: - return list(cls) - - @classmethod - def values(cls) -> List[str]: - return [v.value for v in cls.options()] - - # Comparison operations - def __eq__(self, other: object) -> bool: - if not isinstance(other, self.__class__): - raise CompilerPanic("Can only compare like types.") - return self is other - - # Python normally does __ne__(other) ==> not self.__eq__(other) - - def __lt__(self, other: object) -> bool: - if not isinstance(other, self.__class__): - raise CompilerPanic("Can only compare like types.") - options = self.__class__.options() - return options.index(self) < options.index(other) # type: ignore - - def __le__(self, other: object) -> bool: - return self.__eq__(other) or self.__lt__(other) - - def __gt__(self, other: object) -> bool: - return not self.__le__(other) - - def __ge__(self, other: object) -> bool: - return self.__eq__(other) or self.__gt__(other) - - -class FunctionVisibility(_StringEnum): - # TODO: these can just be enum.auto() right? - EXTERNAL = _StringEnum.auto() - INTERNAL = _StringEnum.auto() - - -class StateMutability(_StringEnum): - # TODO: these can just be enum.auto() right? - PURE = _StringEnum.auto() - VIEW = _StringEnum.auto() - NONPAYABLE = _StringEnum.auto() - PAYABLE = _StringEnum.auto() +class StateMutability(StringEnum): + PURE = enum.auto() + VIEW = enum.auto() + NONPAYABLE = enum.auto() + PAYABLE = enum.auto() @classmethod def from_abi(cls, abi_dict: Dict) -> "StateMutability": @@ -103,71 +48,40 @@ def from_abi(cls, abi_dict: Dict) -> "StateMutability": # and variables) and Constancy (in codegen). context.Constancy can/should # probably be refactored away though as those kinds of checks should be done # during analysis. -class Modifiability(enum.IntEnum): - # is writeable/can result in arbitrary state or memory changes - MODIFIABLE = enum.auto() - - # could potentially add more fine-grained here as needed, like - # CONSTANT_AFTER_DEPLOY, TX_CONSTANT, BLOCK_CONSTANT, etc. +class Modifiability(StringEnum): + # compile-time / always constant + CONSTANT = enum.auto() # things that are constant within the current message call, including # block.*, msg.*, tx.* and immutables RUNTIME_CONSTANT = enum.auto() - # compile-time / always constant - CONSTANT = enum.auto() - - -class DataPosition: - _location: DataLocation - - -class CalldataOffset(DataPosition): - __slots__ = ("dynamic_offset", "static_offset") - _location = DataLocation.CALLDATA - - def __init__(self, static_offset, dynamic_offset=None): - self.static_offset = static_offset - self.dynamic_offset = dynamic_offset - - def __repr__(self): - if self.dynamic_offset is not None: - return f"" - else: - return f"" - - -class MemoryOffset(DataPosition): - __slots__ = ("offset",) - _location = DataLocation.MEMORY - - def __init__(self, offset): - self.offset = offset - - def __repr__(self): - return f"" - - -class StorageSlot(DataPosition): - __slots__ = ("position",) - _location = DataLocation.STORAGE + # could potentially add more fine-grained here as needed, like + # CONSTANT_AFTER_DEPLOY, TX_CONSTANT, BLOCK_CONSTANT, etc. - def __init__(self, position): - self.position = position + # is writeable/can result in arbitrary state or memory changes + MODIFIABLE = enum.auto() - def __repr__(self): - return f"" + @classmethod + def from_state_mutability(cls, mutability: StateMutability): + if mutability == StateMutability.PURE: + return cls.CONSTANT + if mutability == StateMutability.VIEW: + return cls.RUNTIME_CONSTANT + # sanity check in case more StateMutability levels are added in the future + assert mutability in (StateMutability.PAYABLE, StateMutability.NONPAYABLE) + return cls.MODIFIABLE -class CodeOffset(DataPosition): - __slots__ = ("offset",) - _location = DataLocation.CODE +@dataclass +class VarOffset: + position: int - def __init__(self, offset): - self.offset = offset - def __repr__(self): - return f"" +class ModuleOwnership(StringEnum): + NO_OWNERSHIP = enum.auto() # readable + USES = enum.auto() # writeable + INITIALIZES = enum.auto() # initializes # base class for things that are the "result" of analysis @@ -178,6 +92,9 @@ class AnalysisResult: @dataclass class ModuleInfo(AnalysisResult): module_t: "ModuleT" + alias: str + ownership: ModuleOwnership = ModuleOwnership.NO_OWNERSHIP + ownership_decl: Optional[vy_ast.VyperNode] = None @property def module_node(self): @@ -188,6 +105,16 @@ def module_node(self): def typ(self): return self.module_t + def set_ownership(self, module_ownership: ModuleOwnership, node: Optional[vy_ast.VyperNode]): + if self.ownership != ModuleOwnership.NO_OWNERSHIP: + raise StructureException( + f"ownership already set to `{self.ownership}`", node, self.ownership_decl + ) + self.ownership = module_ownership + + def __hash__(self): + return hash(id(self.module_t)) + @dataclass class ImportInfo(AnalysisResult): @@ -199,6 +126,21 @@ class ImportInfo(AnalysisResult): node: vy_ast.VyperNode +# analysis result of InitializesDecl +@dataclass +class InitializesInfo(AnalysisResult): + module_info: ModuleInfo + dependencies: list[ModuleInfo] + node: Optional[vy_ast.VyperNode] = None + + +# analysis result of UsesDecl +@dataclass +class UsesInfo(AnalysisResult): + used_modules: list[ModuleInfo] + node: Optional[vy_ast.VyperNode] = None + + @dataclass class VarInfo: """ @@ -221,22 +163,21 @@ def __hash__(self): return hash(id(self)) def __post_init__(self): + self.position = None self._modification_count = 0 - def set_position(self, position: DataPosition) -> None: - if hasattr(self, "position"): + def set_position(self, position: VarOffset) -> None: + if self.position is not None: raise CompilerPanic("Position was already assigned") - if self.location != position._location: - if self.location == DataLocation.UNSET: - self.location = position._location - elif self.is_transient and position._location == DataLocation.STORAGE: - # CMC 2023-12-31 - use same allocator for storage and transient - # for now, this should be refactored soon. - pass - else: - raise CompilerPanic("Incompatible locations") + assert isinstance(position, VarOffset) # sanity check self.position = position + def is_module_variable(self): + return self.location not in (DataLocation.UNSET, DataLocation.MEMORY) + + def get_size(self) -> int: + return self.typ.get_size_in(self.location) + @property def is_transient(self): return self.location == DataLocation.TRANSIENT @@ -260,9 +201,13 @@ class ExprInfo: typ: VyperType var_info: Optional[VarInfo] = None + module_info: Optional[ModuleInfo] = None location: DataLocation = DataLocation.UNSET modifiability: Modifiability = Modifiability.MODIFIABLE + # the chain of attribute parents for this expr + attribute_chain: list["ExprInfo"] = field(default_factory=list) + def __post_init__(self): should_match = ("typ", "location", "modifiability") if self.var_info is not None: @@ -270,65 +215,48 @@ def __post_init__(self): if getattr(self.var_info, attr) != getattr(self, attr): raise CompilerPanic("Bad analysis: non-matching {attr}: {self}") + self._writes: OrderedSet[VarInfo] = OrderedSet() + self._reads: OrderedSet[VarInfo] = OrderedSet() + + # find exprinfo in the attribute chain which has a varinfo + # e.x. `x` will return varinfo for `x` + # `module.foo` will return varinfo for `module.foo` + # `self.my_struct.x.y` will return varinfo for `self.my_struct` + def get_root_varinfo(self) -> Optional[VarInfo]: + for expr_info in self.attribute_chain + [self]: + if expr_info.var_info is not None: + return expr_info.var_info + return None + @classmethod - def from_varinfo(cls, var_info: VarInfo) -> "ExprInfo": + def from_varinfo(cls, var_info: VarInfo, attribute_chain=None) -> "ExprInfo": return cls( var_info.typ, var_info=var_info, location=var_info.location, modifiability=var_info.modifiability, + attribute_chain=attribute_chain or [], ) @classmethod - def from_moduleinfo(cls, module_info: ModuleInfo) -> "ExprInfo": - return cls(module_info.module_t) + def from_moduleinfo(cls, module_info: ModuleInfo, attribute_chain=None) -> "ExprInfo": + modifiability = Modifiability.RUNTIME_CONSTANT + if module_info.ownership >= ModuleOwnership.USES: + modifiability = Modifiability.MODIFIABLE - def copy_with_type(self, typ: VyperType) -> "ExprInfo": + return cls( + module_info.module_t, + module_info=module_info, + modifiability=modifiability, + attribute_chain=attribute_chain or [], + ) + + def copy_with_type(self, typ: VyperType, attribute_chain=None) -> "ExprInfo": """ Return a copy of the ExprInfo but with the type set to something else """ to_copy = ("location", "modifiability") fields = {k: getattr(self, k) for k in to_copy} + if attribute_chain is not None: + fields["attribute_chain"] = attribute_chain return self.__class__(typ=typ, **fields) - - def validate_modification(self, node: vy_ast.VyperNode, mutability: StateMutability) -> None: - """ - Validate an attempt to modify this value. - - Raises if the value is a constant or involves an invalid operation. - - Arguments - --------- - node : Assign | AugAssign | Call - Vyper ast node of the modifying action. - mutability: StateMutability - The mutability of the context (e.g., pure function) we are currently in - """ - if mutability <= StateMutability.VIEW and self.location == DataLocation.STORAGE: - raise StateAccessViolation( - f"Cannot modify storage in a {mutability.value} function", node - ) - - if self.location == DataLocation.CALLDATA: - raise ImmutableViolation("Cannot write to calldata", node) - - if self.modifiability == Modifiability.RUNTIME_CONSTANT: - if self.location == DataLocation.CODE: - if node.get_ancestor(vy_ast.FunctionDef).get("name") != "__init__": - raise ImmutableViolation("Immutable value cannot be written to", node) - - # special handling for immutable variables in the ctor - # TODO: we probably want to remove this restriction. - if self.var_info._modification_count: # type: ignore - raise ImmutableViolation( - "Immutable value cannot be modified after assignment", node - ) - self.var_info._modification_count += 1 # type: ignore - else: - raise ImmutableViolation("Environment variable cannot be written to", node) - - if self.modifiability == Modifiability.CONSTANT: - raise ImmutableViolation("Constant value cannot be written to", node) - - if isinstance(node, vy_ast.AugAssign): - self.typ.validate_numeric_op(node) diff --git a/vyper/semantics/analysis/constant_folding.py b/vyper/semantics/analysis/constant_folding.py index bfcc473d09..3522383167 100644 --- a/vyper/semantics/analysis/constant_folding.py +++ b/vyper/semantics/analysis/constant_folding.py @@ -113,7 +113,7 @@ def visit_Attribute(self, node) -> vy_ast.ExprNode: varinfo = module_t.get_member(node.attr, node) return varinfo.decl_node.value.get_folded_value() - except (VyperException, AttributeError): + except (VyperException, AttributeError, KeyError): raise UnfoldableNode("not a module") def visit_UnaryOp(self, node): diff --git a/vyper/semantics/analysis/data_positions.py b/vyper/semantics/analysis/data_positions.py index 88679a4b09..604bc6b594 100644 --- a/vyper/semantics/analysis/data_positions.py +++ b/vyper/semantics/analysis/data_positions.py @@ -1,11 +1,12 @@ -# TODO this module doesn't really belong in "validation" -from typing import Dict, List +from collections import defaultdict +from typing import Generic, TypeVar from vyper import ast as vy_ast -from vyper.exceptions import StorageLayoutException -from vyper.semantics.analysis.base import CodeOffset, StorageSlot +from vyper.evm.opcodes import version_check +from vyper.exceptions import CompilerPanic, StorageLayoutException +from vyper.semantics.analysis.base import VarOffset +from vyper.semantics.data_locations import DataLocation from vyper.typing import StorageLayout -from vyper.utils import ceil32 def set_data_positions( @@ -20,24 +21,76 @@ def set_data_positions( vyper_module : vy_ast.Module Top-level Vyper AST node that has already been annotated with type data. """ - code_offsets = set_code_offsets(vyper_module) - storage_slots = ( - set_storage_slots_with_overrides(vyper_module, storage_layout_overrides) - if storage_layout_overrides is not None - else set_storage_slots(vyper_module) - ) + if storage_layout_overrides is not None: + # extract code layout with no overrides + code_offsets = _allocate_layout_r(vyper_module, immutables_only=True)["code_layout"] + storage_slots = set_storage_slots_with_overrides(vyper_module, storage_layout_overrides) + return {"storage_layout": storage_slots, "code_layout": code_offsets} - return {"storage_layout": storage_slots, "code_layout": code_offsets} + ret = _allocate_layout_r(vyper_module) + assert isinstance(ret, defaultdict) + return dict(ret) # convert back to dict -class StorageAllocator: +_T = TypeVar("_T") +_K = TypeVar("_K") + + +class InsertableOnceDict(Generic[_T, _K], dict[_T, _K]): + def __setitem__(self, k, v): + if k in self: + raise ValueError(f"{k} is already in dict!") + super().__setitem__(k, v) + + +class SimpleAllocator: + def __init__(self, max_slot: int = 2**256, starting_slot: int = 0): + # Allocate storage slots from 0 + # note storage is word-addressable, not byte-addressable + self._slot = starting_slot + self._max_slot = max_slot + + def allocate_slot(self, n, var_name, node=None): + ret = self._slot + if self._slot + n >= self._max_slot: + raise StorageLayoutException( + f"Invalid storage slot, tried to allocate" + f" slots {self._slot} through {self._slot + n}", + node, + ) + self._slot += n + return ret + + +class Allocators: + storage_allocator: SimpleAllocator + transient_storage_allocator: SimpleAllocator + immutables_allocator: SimpleAllocator + + def __init__(self): + self.storage_allocator = SimpleAllocator(max_slot=2**256) + self.transient_storage_allocator = SimpleAllocator(max_slot=2**256) + self.immutables_allocator = SimpleAllocator(max_slot=0x6000) + + def get_allocator(self, location: DataLocation): + if location == DataLocation.STORAGE: + return self.storage_allocator + if location == DataLocation.TRANSIENT: + return self.transient_storage_allocator + if location == DataLocation.CODE: + return self.immutables_allocator + + raise CompilerPanic("unreachable") # pragma: nocover + + +class OverridingStorageAllocator: """ Keep track of which storage slots have been used. If there is a collision of storage slots, this will raise an error and fail to compile """ def __init__(self): - self.occupied_slots: Dict[int, str] = {} + self.occupied_slots: dict[int, str] = {} def reserve_slot_range(self, first_slot: int, n_slots: int, var_name: str) -> None: """ @@ -48,7 +101,7 @@ def reserve_slot_range(self, first_slot: int, n_slots: int, var_name: str) -> No list_to_check = [x + first_slot for x in range(n_slots)] self._reserve_slots(list_to_check, var_name) - def _reserve_slots(self, slots: List[int], var_name: str) -> None: + def _reserve_slots(self, slots: list[int], var_name: str) -> None: for slot in slots: self._reserve_slot(slot, var_name) @@ -70,12 +123,13 @@ def set_storage_slots_with_overrides( vyper_module: vy_ast.Module, storage_layout_overrides: StorageLayout ) -> StorageLayout: """ - Parse module-level Vyper AST to calculate the layout of storage variables. + Set storage layout given a layout override file. Returns the layout as a dict of variable name -> variable info + (Doesn't handle modules, or transient storage) """ - ret: Dict[str, Dict] = {} - reserved_slots = StorageAllocator() + ret: InsertableOnceDict[str, dict] = InsertableOnceDict() + reserved_slots = OverridingStorageAllocator() # Search through function definitions to find non-reentrant functions for node in vyper_module.get_children(vy_ast.FunctionDef): @@ -90,7 +144,7 @@ def set_storage_slots_with_overrides( # re-entrant key was already identified if variable_name in ret: _slot = ret[variable_name]["slot"] - type_.set_reentrancy_key_position(StorageSlot(_slot)) + type_.set_reentrancy_key_position(VarOffset(_slot)) continue # Expect to find this variable within the storage layout override @@ -100,7 +154,7 @@ def set_storage_slots_with_overrides( # from using the same slot reserved_slots.reserve_slot_range(reentrant_slot, 1, variable_name) - type_.set_reentrancy_key_position(StorageSlot(reentrant_slot)) + type_.set_reentrancy_key_position(VarOffset(reentrant_slot)) ret[variable_name] = {"type": "nonreentrant lock", "slot": reentrant_slot} else: @@ -125,7 +179,7 @@ def set_storage_slots_with_overrides( # Ensure that all required storage slots are reserved, and prevents other variables # from using these slots reserved_slots.reserve_slot_range(var_slot, storage_length, node.target.id) - varinfo.set_position(StorageSlot(var_slot)) + varinfo.set_position(VarOffset(var_slot)) ret[node.target.id] = {"type": str(varinfo.typ), "slot": var_slot} else: @@ -138,105 +192,108 @@ def set_storage_slots_with_overrides( return ret -class SimpleStorageAllocator: - def __init__(self, starting_slot: int = 0): - self._slot = starting_slot +def _get_allocatable(vyper_module: vy_ast.Module) -> list[vy_ast.VyperNode]: + allocable = (vy_ast.InitializesDecl, vy_ast.VariableDecl) + return [node for node in vyper_module.body if isinstance(node, allocable)] - def allocate_slot(self, n, var_name): - ret = self._slot - if self._slot + n >= 2**256: - raise StorageLayoutException( - f"Invalid storage slot for var {var_name}, tried to allocate" - f" slots {self._slot} through {self._slot + n}" - ) - self._slot += n - return ret +def get_reentrancy_key_location() -> DataLocation: + if version_check(begin="cancun"): + return DataLocation.TRANSIENT + return DataLocation.STORAGE -def set_storage_slots(vyper_module: vy_ast.Module) -> StorageLayout: + +_LAYOUT_KEYS = { + DataLocation.CODE: "code_layout", + DataLocation.TRANSIENT: "transient_storage_layout", + DataLocation.STORAGE: "storage_layout", +} + + +def _allocate_layout_r( + vyper_module: vy_ast.Module, allocators: Allocators = None, immutables_only=False +) -> StorageLayout: """ Parse module-level Vyper AST to calculate the layout of storage variables. Returns the layout as a dict of variable name -> variable info """ - # Allocate storage slots from 0 - # note storage is word-addressable, not byte-addressable - allocator = SimpleStorageAllocator() + if allocators is None: + allocators = Allocators() - ret: Dict[str, Dict] = {} + ret: defaultdict[str, InsertableOnceDict[str, dict]] = defaultdict(InsertableOnceDict) for node in vyper_module.get_children(vy_ast.FunctionDef): + if immutables_only: + break + type_ = node._metadata["func_type"] if type_.nonreentrant is None: continue variable_name = f"nonreentrant.{type_.nonreentrant}" + reentrancy_key_location = get_reentrancy_key_location() + layout_key = _LAYOUT_KEYS[reentrancy_key_location] # a nonreentrant key can appear many times in a module but it # only takes one slot. after the first time we see it, do not # increment the storage slot. - if variable_name in ret: - _slot = ret[variable_name]["slot"] - type_.set_reentrancy_key_position(StorageSlot(_slot)) + if variable_name in ret[layout_key]: + _slot = ret[layout_key][variable_name]["slot"] + type_.set_reentrancy_key_position(VarOffset(_slot)) continue # TODO use one byte - or bit - per reentrancy key # requires either an extra SLOAD or caching the value of the # location in memory at entrance - slot = allocator.allocate_slot(1, variable_name) + allocator = allocators.get_allocator(reentrancy_key_location) + slot = allocator.allocate_slot(1, variable_name, node) - type_.set_reentrancy_key_position(StorageSlot(slot)) + type_.set_reentrancy_key_position(VarOffset(slot)) # TODO this could have better typing but leave it untyped until # we nail down the format better - ret[variable_name] = {"type": "nonreentrant lock", "slot": slot} - - for node in vyper_module.get_children(vy_ast.VariableDecl): - # skip non-storage variables - if node.is_constant or node.is_immutable: + ret[layout_key][variable_name] = {"type": "nonreentrant lock", "slot": slot} + + for node in _get_allocatable(vyper_module): + if isinstance(node, vy_ast.InitializesDecl): + module_info = node._metadata["initializes_info"].module_info + module_layout = _allocate_layout_r(module_info.module_node, allocators) + module_alias = module_info.alias + for layout_key in module_layout.keys(): + assert layout_key in _LAYOUT_KEYS.values() + ret[layout_key][module_alias] = module_layout[layout_key] continue + assert isinstance(node, vy_ast.VariableDecl) + # skip non-storage variables varinfo = node.target._metadata["varinfo"] - type_ = varinfo.typ - - # CMC 2021-07-23 note that HashMaps get assigned a slot here. - # I'm not sure if it's safe to avoid allocating that slot - # for HashMaps because downstream code might use the slot - # ID as a salt. - n_slots = type_.storage_size_in_words - slot = allocator.allocate_slot(n_slots, node.target.id) - - varinfo.set_position(StorageSlot(slot)) - - # this could have better typing but leave it untyped until - # we understand the use case better - ret[node.target.id] = {"type": str(type_), "slot": slot} - - return ret - - -def set_calldata_offsets(fn_node: vy_ast.FunctionDef) -> None: - pass + if not varinfo.is_module_variable(): + continue + location = varinfo.location + if immutables_only and location != DataLocation.CODE: + continue -def set_memory_offsets(fn_node: vy_ast.FunctionDef) -> None: - pass + allocator = allocators.get_allocator(location) + size = varinfo.get_size() + # CMC 2021-07-23 note that HashMaps get assigned a slot here + # using the same allocator (even though there is not really + # any risk of physical overlap) + offset = allocator.allocate_slot(size, node.target.id, node) -def set_code_offsets(vyper_module: vy_ast.Module) -> Dict: - ret = {} - offset = 0 + varinfo.set_position(VarOffset(offset)) - for node in vyper_module.get_children(vy_ast.VariableDecl, filters={"is_immutable": True}): - varinfo = node.target._metadata["varinfo"] + layout_key = _LAYOUT_KEYS[location] type_ = varinfo.typ - varinfo.set_position(CodeOffset(offset)) - - len_ = ceil32(type_.size_in_bytes) - # this could have better typing but leave it untyped until # we understand the use case better - ret[node.target.id] = {"type": str(type_), "offset": offset, "length": len_} - - offset += len_ + if location == DataLocation.CODE: + item = {"type": str(type_), "length": size, "offset": offset} + elif location in (DataLocation.STORAGE, DataLocation.TRANSIENT): + item = {"type": str(type_), "slot": offset} + else: # pragma: nocover + raise CompilerPanic("unreachable") + ret[layout_key][node.target.id] = item return ret diff --git a/vyper/semantics/analysis/global_.py b/vyper/semantics/analysis/global_.py new file mode 100644 index 0000000000..92cdf35c5d --- /dev/null +++ b/vyper/semantics/analysis/global_.py @@ -0,0 +1,80 @@ +from collections import defaultdict + +from vyper.exceptions import ExceptionList, InitializerException +from vyper.semantics.analysis.base import InitializesInfo, UsesInfo +from vyper.semantics.analysis.import_graph import ImportGraph +from vyper.semantics.analysis.module import validate_module_semantics_r +from vyper.semantics.types.module import ModuleT + + +def validate_semantics(module_ast, input_bundle, is_interface=False) -> ModuleT: + ret = validate_module_semantics_r(module_ast, input_bundle, ImportGraph(), is_interface) + + _validate_global_initializes_constraint(ret) + + return ret + + +def _collect_used_modules_r(module_t): + ret: defaultdict[ModuleT, list[UsesInfo]] = defaultdict(list) + + for uses_decl in module_t.uses_decls: + for used_module in uses_decl._metadata["uses_info"].used_modules: + ret[used_module.module_t].append(uses_decl) + + # recurse + used_modules = _collect_used_modules_r(used_module.module_t) + for k, v in used_modules.items(): + ret[k].extend(v) + + # also recurse into modules used by initialized modules + for i in module_t.initialized_modules: + used_modules = _collect_used_modules_r(i.module_info.module_t) + for k, v in used_modules.items(): + ret[k].extend(v) + + return ret + + +def _collect_initialized_modules_r(module_t, seen=None): + seen: dict[ModuleT, InitializesInfo] = seen or {} + + # list of InitializedInfo + initialized_infos = module_t.initialized_modules + + for i in initialized_infos: + initialized_module_t = i.module_info.module_t + if initialized_module_t in seen: + seen_nodes = (i.node, seen[initialized_module_t].node) + raise InitializerException(f"`{i.module_info.alias}` initialized twice!", *seen_nodes) + seen[initialized_module_t] = i + + _collect_initialized_modules_r(initialized_module_t, seen) + + return seen + + +# validate that each module which is `used` in the import graph is +# `initialized`. +def _validate_global_initializes_constraint(module_t: ModuleT): + all_used_modules = _collect_used_modules_r(module_t) + all_initialized_modules = _collect_initialized_modules_r(module_t) + + err_list = ExceptionList() + + for u, uses in all_used_modules.items(): + if u not in all_initialized_modules: + found_module = module_t.find_module_info(u) + if found_module is not None: + hint = f"add `initializes: {found_module.alias}` to the top level of " + hint += "your main contract" + else: + # CMC 2024-02-06 is this actually reachable? + hint = f"ensure `{module_t}` is imported in your main contract!" + err_list.append( + InitializerException( + f"module `{u}` is used but never initialized!", *uses, hint=hint + ) + ) + + err_list.raise_if_not_empty() diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 91cc0ebdf8..d96215ede0 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -1,8 +1,11 @@ +# CMC 2024-02-03 TODO: split me into function.py and expr.py + from typing import Optional from vyper import ast as vy_ast from vyper.ast.validation import validate_call_args from vyper.exceptions import ( + CallViolation, ExceptionList, FunctionDeclarationException, ImmutableViolation, @@ -16,7 +19,7 @@ VariableDeclarationException, VyperException, ) -from vyper.semantics.analysis.base import Modifiability, VarInfo +from vyper.semantics.analysis.base import Modifiability, ModuleOwnership, VarInfo from vyper.semantics.analysis.common import VyperNodeVisitorBase from vyper.semantics.analysis.utils import ( get_common_types, @@ -54,13 +57,12 @@ def validate_functions(vy_module: vy_ast.Module) -> None: """Analyzes a vyper ast and validates the function bodies""" - err_list = ExceptionList() namespace = get_namespace() for node in vy_module.get_children(vy_ast.FunctionDef): with namespace.enter_scope(): try: - analyzer = FunctionNodeVisitor(vy_module, node, namespace) + analyzer = FunctionAnalyzer(vy_module, node, namespace) analyzer.analyze() except VyperException as e: err_list.append(e) @@ -181,7 +183,7 @@ def _validate_self_reference(node: vy_ast.Name) -> None: raise StateAccessViolation("not allowed to query self in pure functions", node) -class FunctionNodeVisitor(VyperNodeVisitorBase): +class FunctionAnalyzer(VyperNodeVisitorBase): ignored_types = (vy_ast.Pass,) scope_name = "function" @@ -192,7 +194,7 @@ def __init__( self.fn_node = fn_node self.namespace = namespace self.func = fn_node._metadata["func_type"] - self.expr_visitor = ExprVisitor(self.func) + self.expr_visitor = ExprVisitor(self) def analyze(self): # allow internal function params to be mutable @@ -270,21 +272,94 @@ def _assign_helper(self, node): raise StructureException("Right-hand side of assignment cannot be a tuple", node.value) target = get_expr_info(node.target) - if isinstance(target.typ, HashMapT): - raise StructureException( - "Left-hand side of assignment cannot be a HashMap without a key", node - ) - target.validate_modification(node, self.func.mutability) + # check mutability of the function + self._handle_modification(node.target) self.expr_visitor.visit(node.value, target.typ) self.expr_visitor.visit(node.target, target.typ) + def _handle_modification(self, target: vy_ast.ExprNode): + if isinstance(target, vy_ast.Tuple): + for item in target.elements: + self._handle_modification(item) + return + + # check a modification of `target`. validate the modification is + # valid, and log the modification in relevant data structures. + func_t = self.func + info = get_expr_info(target) + + if isinstance(info.typ, HashMapT): + raise StructureException( + "Left-hand side of assignment cannot be a HashMap without a key" + ) + + if ( + info.location in (DataLocation.STORAGE, DataLocation.TRANSIENT) + and func_t.mutability <= StateMutability.VIEW + ): + raise StateAccessViolation( + f"Cannot modify {info.location} variable in a {func_t.mutability} function" + ) + + if info.location == DataLocation.CALLDATA: + raise ImmutableViolation("Cannot write to calldata") + + if info.modifiability == Modifiability.RUNTIME_CONSTANT: + if info.location == DataLocation.CODE: + if not func_t.is_constructor: + raise ImmutableViolation("Immutable value cannot be written to") + + # handle immutables + if info.var_info is not None: # don't handle complex (struct,array) immutables + # special handling for immutable variables in the ctor + # TODO: maybe we want to remove this restriction. + if info.var_info._modification_count != 0: + raise ImmutableViolation( + "Immutable value cannot be modified after assignment" + ) + info.var_info._modification_count += 1 + else: + raise ImmutableViolation("Environment variable cannot be written to") + + if info.modifiability == Modifiability.CONSTANT: + raise ImmutableViolation("Constant value cannot be written to.") + + var_info = info.get_root_varinfo() + assert var_info is not None + + info._writes.add(var_info) + + def _check_module_use(self, target: vy_ast.ExprNode): + module_infos = [] + for t in get_expr_info(target).attribute_chain: + if t.module_info is not None: + module_infos.append(t.module_info) + + if len(module_infos) == 0: + return + + for module_info in module_infos: + if module_info.ownership < ModuleOwnership.USES: + msg = f"Cannot access `{module_info.alias}` state!" + hint = f"add `uses: {module_info.alias}` or " + hint += f"`initializes: {module_info.alias}` as " + hint += "a top-level statement to your contract" + raise ImmutableViolation(msg, hint=hint) + + # the leftmost- referenced module + root_module_info = module_infos[0] + + # log the access + self.func._used_modules.add(root_module_info) + def visit_Assign(self, node): self._assign_helper(node) def visit_AugAssign(self, node): self._assign_helper(node) + node.target._expr_info.typ.validate_numeric_op(node) def visit_Break(self, node): for_node = node.get_ancestor(vy_ast.For) @@ -309,35 +384,13 @@ def visit_Expr(self, node): raise StructureException("Expressions without assignment are disallowed", node) fn_type = get_exact_type_from_node(node.value.func) + if is_type_t(fn_type, EventT): raise StructureException("To call an event you must use the `log` statement", node) if is_type_t(fn_type, StructT): raise StructureException("Struct creation without assignment is disallowed", node) - if isinstance(fn_type, ContractFunctionT): - if ( - fn_type.mutability > StateMutability.VIEW - and self.func.mutability <= StateMutability.VIEW - ): - raise StateAccessViolation( - f"Cannot call a mutating function from a {self.func.mutability.value} function", - node, - ) - - if ( - self.func.mutability == StateMutability.PURE - and fn_type.mutability != StateMutability.PURE - ): - raise StateAccessViolation( - "Cannot call non-pure function from a pure function", node - ) - - if isinstance(fn_type, MemberFunctionT) and fn_type.is_modifying: - # it's a dotted function call like dynarray.pop() - expr_info = get_expr_info(node.value.func.value) - expr_info.validate_modification(node, self.func.mutability) - # NOTE: fetch_call_return validates call args. return_value = map_void(fn_type.fetch_call_return(node.value)) if ( @@ -457,7 +510,7 @@ def visit_Log(self, node): raise StructureException("Value is not an event", node.value) if self.func.mutability <= StateMutability.VIEW: raise StructureException( - f"Cannot emit logs from {self.func.mutability.value.lower()} functions", node + f"Cannot emit logs from {self.func.mutability} functions", node ) t = map_void(f.fetch_call_return(node.value)) # CMC 2024-02-05 annotate the event type for codegen usage @@ -493,10 +546,20 @@ def visit_Return(self, node): class ExprVisitor(VyperNodeVisitorBase): - scope_name = "function" + def __init__(self, function_analyzer: Optional[FunctionAnalyzer] = None): + self.function_analyzer = function_analyzer + + @property + def func(self): + if self.function_analyzer is None: + return None + return self.function_analyzer.func - def __init__(self, fn_node: Optional[ContractFunctionT] = None): - self.func = fn_node + @property + def scope_name(self): + if self.func is not None: + return "function" + return "module" def visit(self, node, typ): if typ is not VOID_TYPE and not isinstance(typ, TYPE_T): @@ -509,6 +572,24 @@ def visit(self, node, typ): # annotate node._metadata["type"] = typ + if not isinstance(typ, TYPE_T): + info = get_expr_info(node) # get_expr_info fills in node._expr_info + + # log variable accesses. + # (note writes will get logged as both read+write) + varinfo = info.var_info + if varinfo is not None: + info._reads.add(varinfo) + + if self.func: + variable_accesses = info._writes | info._reads + for s in variable_accesses: + if s.is_module_variable(): + self.function_analyzer._check_module_use(node) + + self.func._variable_writes.update(info._writes) + self.func._variable_reads.update(info._reads) + # validate and annotate folded value if node.has_folded_value: folded_node = node.get_folded_value() @@ -547,42 +628,77 @@ def visit_BoolOp(self, node: vy_ast.BoolOp, typ: VyperType) -> None: for value in node.values: self.visit(value, BoolT()) + def _check_call_mutability(self, call_mutability: StateMutability): + # note: payable can be called from nonpayable functions + ok = ( + call_mutability <= self.func.mutability + or self.func.mutability >= StateMutability.NONPAYABLE + ) + if not ok: + msg = f"Cannot call a {call_mutability} function from a {self.func.mutability} function" + raise StateAccessViolation(msg) + def visit_Call(self, node: vy_ast.Call, typ: VyperType) -> None: - call_type = get_exact_type_from_node(node.func) - self.visit(node.func, call_type) + func_info = get_expr_info(node.func, is_callable=True) + func_type = func_info.typ + self.visit(node.func, func_type) - if isinstance(call_type, ContractFunctionT): + if isinstance(func_type, ContractFunctionT): # function calls - if self.func and call_type.is_internal: - self.func.called_functions.add(call_type) - for arg, typ in zip(node.args, call_type.argument_types): + + func_info._writes.update(func_type._variable_writes) + func_info._reads.update(func_type._variable_reads) + + if self.function_analyzer: + if func_type.is_internal: + self.func.called_functions.add(func_type) + + self._check_call_mutability(func_type.mutability) + + # check that if the function accesses state, the defining + # module has been `used` or `initialized`. + for s in func_type._variable_accesses: + if s.is_module_variable(): + self.function_analyzer._check_module_use(node.func) + + if func_type.is_deploy and not self.func.is_deploy: + raise CallViolation( + f"Cannot call an @{func_type.visibility} function from " + f"an @{self.func.visibility} function!", + node, + ) + + for arg, typ in zip(node.args, func_type.argument_types): self.visit(arg, typ) for kwarg in node.keywords: # We should only see special kwargs - typ = call_type.call_site_kwargs[kwarg.arg].typ + typ = func_type.call_site_kwargs[kwarg.arg].typ self.visit(kwarg.value, typ) - elif is_type_t(call_type, EventT): + elif is_type_t(func_type, EventT): # events have no kwargs - expected_types = call_type.typedef.arguments.values() + expected_types = func_type.typedef.arguments.values() # type: ignore for arg, typ in zip(node.args, expected_types): self.visit(arg, typ) - elif is_type_t(call_type, StructT): + elif is_type_t(func_type, StructT): # struct ctors # ctors have no kwargs - expected_types = call_type.typedef.members.values() + expected_types = func_type.typedef.members.values() # type: ignore for value, arg_type in zip(node.args[0].values, expected_types): self.visit(value, arg_type) - elif isinstance(call_type, MemberFunctionT): - assert len(node.args) == len(call_type.arg_types) - for arg, arg_type in zip(node.args, call_type.arg_types): + elif isinstance(func_type, MemberFunctionT): + if func_type.is_modifying and self.function_analyzer is not None: + # TODO refactor this + self.function_analyzer._handle_modification(node.func) + assert len(node.args) == len(func_type.arg_types) + for arg, arg_type in zip(node.args, func_type.arg_types): self.visit(arg, arg_type) else: # builtin functions - arg_types = call_type.infer_arg_types(node, expected_return_typ=typ) + arg_types = func_type.infer_arg_types(node, expected_return_typ=typ) # type: ignore for arg, arg_type in zip(node.args, arg_types): self.visit(arg, arg_type) - kwarg_types = call_type.infer_kwarg_types(node) + kwarg_types = func_type.infer_kwarg_types(node) # type: ignore for kwarg in node.keywords: self.visit(kwarg.value, kwarg_types[kwarg.arg]) @@ -638,8 +754,10 @@ def visit_List(self, node: vy_ast.List, typ: VyperType) -> None: self.visit(element, typ.value_type) def visit_Name(self, node: vy_ast.Name, typ: VyperType) -> None: - if self.func and self.func.mutability == StateMutability.PURE: - _validate_self_reference(node) + if self.func: + # TODO: refactor to use expr_info mutability + if self.func.mutability == StateMutability.PURE: + _validate_self_reference(node) def visit_Subscript(self, node: vy_ast.Subscript, typ: VyperType) -> None: if isinstance(typ, TYPE_T): diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index a83c2f3b7d..e50c3e6d6f 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -8,38 +8,50 @@ from vyper.compiler.input_bundle import ABIInput, FileInput, FilesystemInputBundle, InputBundle from vyper.evm.opcodes import version_check from vyper.exceptions import ( + BorrowException, CallViolation, DuplicateImport, ExceptionList, + ImmutableViolation, + InitializerException, InvalidLiteral, InvalidType, ModuleNotFound, NamespaceCollision, StateAccessViolation, StructureException, - SyntaxException, + UndeclaredDefinition, VariableDeclarationException, VyperException, ) -from vyper.semantics.analysis.base import ImportInfo, Modifiability, ModuleInfo, VarInfo +from vyper.semantics.analysis.base import ( + ImportInfo, + InitializesInfo, + Modifiability, + ModuleInfo, + ModuleOwnership, + UsesInfo, + VarInfo, +) from vyper.semantics.analysis.common import VyperNodeVisitorBase from vyper.semantics.analysis.constant_folding import constant_fold from vyper.semantics.analysis.import_graph import ImportGraph from vyper.semantics.analysis.local import ExprVisitor, validate_functions -from vyper.semantics.analysis.utils import check_modifiability, get_exact_type_from_node +from vyper.semantics.analysis.utils import ( + check_modifiability, + get_exact_type_from_node, + get_expr_info, +) from vyper.semantics.data_locations import DataLocation from vyper.semantics.namespace import Namespace, get_namespace, override_global_namespace from vyper.semantics.types import EventT, FlagT, InterfaceT, StructT from vyper.semantics.types.function import ContractFunctionT from vyper.semantics.types.module import ModuleT from vyper.semantics.types.utils import type_from_annotation +from vyper.utils import OrderedSet -def validate_semantics(module_ast, input_bundle, is_interface=False) -> ModuleT: - return validate_semantics_r(module_ast, input_bundle, ImportGraph(), is_interface) - - -def validate_semantics_r( +def validate_module_semantics_r( module_ast: vy_ast.Module, input_bundle: InputBundle, import_graph: ImportGraph, @@ -49,6 +61,11 @@ def validate_semantics_r( Analyze a Vyper module AST node, add all module-level objects to the namespace, type-check/validate semantics and annotate with type and analysis info """ + if "type" in module_ast._metadata: + # we don't need to analyse again, skip out + assert isinstance(module_ast._metadata["type"], ModuleT) + return module_ast._metadata["type"] + validate_literal_nodes(module_ast) # validate semantics and annotate AST with type/semantics information @@ -64,6 +81,8 @@ def validate_semantics_r( # in `ContractFunction.from_vyi()` if not is_interface: validate_functions(module_ast) + analyzer.validate_initialized_modules() + analyzer.validate_used_modules() return ret @@ -121,11 +140,8 @@ def __init__( def analyze(self) -> ModuleT: # generate a `ModuleT` from the top-level node # note: also validates unique method ids - if "type" in self.ast._metadata: - assert isinstance(self.ast._metadata["type"], ModuleT) - # we don't need to analyse again, skip out - self.module_t = self.ast._metadata["type"] - return self.module_t + + assert "type" not in self.ast._metadata to_visit = self.ast.body.copy() @@ -138,6 +154,11 @@ def analyze(self) -> ModuleT: self.visit(node) to_visit.remove(node) + ownership_decls = self.ast.get_children((vy_ast.UsesDecl, vy_ast.InitializesDecl)) + for node in ownership_decls: + self.visit(node) + to_visit.remove(node) + # we can resolve constants after imports are handled. constant_fold(self.ast) @@ -179,6 +200,7 @@ def analyze(self) -> ModuleT: def analyze_call_graph(self): # get list of internal function calls made by each function + # CMC 2024-02-03 note: this could be cleaner in analysis/local.py function_defs = self.module_t.function_defs for func in function_defs: @@ -195,7 +217,9 @@ def analyze_call_graph(self): # we just want to be able to construct the call graph. continue - if isinstance(call_t, ContractFunctionT) and call_t.is_internal: + if isinstance(call_t, ContractFunctionT) and ( + call_t.is_internal or call_t.is_constructor + ): fn_t.called_functions.add(call_t) for func in function_defs: @@ -204,6 +228,106 @@ def analyze_call_graph(self): # compute reachable set and validate the call graph _compute_reachable_set(fn_t) + def validate_used_modules(self): + # check all `uses:` modules are actually used + should_use = {} + + module_t = self.ast._metadata["type"] + uses_decls = module_t.uses_decls + for decl in uses_decls: + info = decl._metadata["uses_info"] + for m in info.used_modules: + should_use[m.module_t] = (m, info) + + initialized_modules = {t.module_info.module_t: t for t in module_t.initialized_modules} + + all_used_modules = OrderedSet() + + for f in module_t.functions.values(): + for u in f._used_modules: + all_used_modules.add(u.module_t) + + for used_module in all_used_modules: + if used_module in initialized_modules: + continue + + if used_module in should_use: + del should_use[used_module] + + if len(should_use) > 0: + err_list = ExceptionList() + for used_module_info, uses_info in should_use.values(): + msg = f"`{used_module_info.alias}` is declared as used, but " + msg += f"it is not actually used in {module_t}!" + hint = f"delete `uses: {used_module_info.alias}`" + err_list.append(BorrowException(msg, uses_info.node, hint=hint)) + + err_list.raise_if_not_empty() + + def validate_initialized_modules(self): + # check all `initializes:` modules have `__init__()` called exactly once + module_t = self.ast._metadata["type"] + should_initialize = {t.module_info.module_t: t for t in module_t.initialized_modules} + # don't call `__init__()` for modules which don't have + # `__init__()` function + for m in should_initialize.copy(): + for f in m.functions.values(): + if f.is_constructor: + break + else: + del should_initialize[m] + + init_calls = [] + for f in self.ast.get_children(vy_ast.FunctionDef): + if f._metadata["func_type"].is_constructor: + init_calls = f.get_descendants(vy_ast.Call) + break + + seen_initializers = {} + for call_node in init_calls: + expr_info = call_node.func._expr_info + if expr_info is None: + # this can happen for range() calls; CMC 2024-02-05 try to + # refactor so that range() is properly tagged. + continue + + call_t = call_node.func._expr_info.typ + + if not isinstance(call_t, ContractFunctionT): + continue + + if not call_t.is_constructor: + continue + + # XXX: check this works as expected for nested attributes + initialized_module = call_node.func.value._expr_info.module_info + + if initialized_module.module_t in seen_initializers: + seen_location = seen_initializers[initialized_module.module_t] + msg = f"tried to initialize `{initialized_module.alias}`, " + msg += "but its __init__() function was already called!" + raise InitializerException(msg, call_node.func, seen_location) + + if initialized_module.module_t not in should_initialize: + msg = f"tried to initialize `{initialized_module.alias}`, " + msg += "but it is not in initializer list!" + hint = f"add `initializes: {initialized_module.alias}` " + hint += "as a top-level statement to your contract" + raise InitializerException(msg, call_node.func, hint=hint) + + del should_initialize[initialized_module.module_t] + seen_initializers[initialized_module.module_t] = call_node.func + + if len(should_initialize) > 0: + err_list = ExceptionList() + for s in should_initialize.values(): + msg = "not initialized!" + hint = f"add `{s.module_info.alias}.__init__()` to " + hint += "your `__init__()` function" + err_list.append(InitializerException(msg, s.node, hint=hint)) + + err_list.raise_if_not_empty() + def _ast_from_file(self, file: FileInput) -> vy_ast.Module: # cache ast if we have seen it before. # this gives us the additional property of object equality on @@ -218,10 +342,100 @@ def visit_ImplementsDecl(self, node): type_ = type_from_annotation(node.annotation) if not isinstance(type_, InterfaceT): - raise StructureException("Invalid interface name", node.annotation) + raise StructureException("not an interface!", node.annotation) type_.validate_implements(node) + def visit_UsesDecl(self, node): + # TODO: check duplicate uses declarations, e.g. + # uses: x + # ... + # uses: x + items = vy_ast.as_tuple(node.annotation) + + used_modules = [] + + for item in items: + module_info = get_expr_info(item).module_info + if module_info is None: + raise StructureException("not a valid module!", item) + + # note: try to refactor - not a huge fan of mutating the + # ModuleInfo after it's constructed + module_info.set_ownership(ModuleOwnership.USES, item) + + used_modules.append(module_info) + + node._metadata["uses_info"] = UsesInfo(used_modules, node) + + def visit_InitializesDecl(self, node): + module_ref = node.annotation + dependencies_ast = () + if isinstance(module_ref, vy_ast.Subscript): + dependencies_ast = vy_ast.as_tuple(module_ref.slice) + module_ref = module_ref.value + + # postcondition of InitializesDecl.validates() + assert isinstance(module_ref, (vy_ast.Name, vy_ast.Attribute)) + + module_info = get_expr_info(module_ref).module_info + if module_info is None: + raise StructureException("Not a module!", module_ref) + + used_modules = {i.module_t: i for i in module_info.module_t.used_modules} + + dependencies = [] + for named_expr in dependencies_ast: + assert isinstance(named_expr, vy_ast.NamedExpr) + + rhs_module = get_expr_info(named_expr.value).module_info + + with module_info.module_node.namespace(): + # lhs of the named_expr is evaluated in the namespace of the + # initialized module! + try: + lhs_module = get_expr_info(named_expr.target).module_info + except VyperException as e: + # try to report a common problem - user names the module in + # the current namespace instead of the initialized module + # namespace. + + # search for the module in the initialized module + found_module = module_info.module_t.find_module_info(rhs_module.module_t) + if found_module is not None: + msg = f"unknown module `{named_expr.target.id}`" + hint = f"did you mean `{found_module.alias} := {rhs_module.alias}`?" + raise UndeclaredDefinition(msg, named_expr.target, hint=hint) + + raise e from None + + if lhs_module.module_t != rhs_module.module_t: + raise StructureException( + f"{lhs_module.alias} is not {rhs_module.alias}!", named_expr + ) + dependencies.append(lhs_module) + + if lhs_module.module_t not in used_modules: + raise InitializerException( + f"`{module_info.alias}` is initialized with `{lhs_module.alias}`, " + f"but `{module_info.alias}` does not use `{lhs_module.alias}`!", + named_expr, + ) + + del used_modules[lhs_module.module_t] + + if len(used_modules) > 0: + item = next(iter(used_modules.values())) # just pick one + msg = f"`{module_info.alias}` uses `{item.alias}`, but it is not " + msg += f"initialized with `{item.alias}`" + hint = f"add `{item.alias}` to its initializer list" + raise InitializerException(msg, node, hint=hint) + + # note: try to refactor. not a huge fan of mutating the + # ModuleInfo after it's constructed + module_info.set_ownership(ModuleOwnership.INITIALIZES, node) + node._metadata["initializes_info"] = InitializesInfo(module_info, dependencies, node) + def visit_VariableDecl(self, node): name = node.get("target.id") if name is None: @@ -250,7 +464,7 @@ def visit_VariableDecl(self, node): if len(wrong_self_attribute) > 0 else "Immutable definition requires an assignment in the constructor" ) - raise SyntaxException(message, node.node_source_code, node.lineno, node.col_offset) + raise ImmutableViolation(message, node) data_loc = ( DataLocation.CODE @@ -364,11 +578,10 @@ def visit_Import(self, node): # don't handle things like `import x.y` if "." in alias: + msg = "import requires an accompanying `as` statement" suggested_alias = node.name[node.name.rfind(".") :] - suggestion = f"hint: try `import {node.name} as {suggested_alias}`" - raise StructureException( - f"import requires an accompanying `as` statement ({suggestion})", node - ) + hint = f"try `import {node.name} as {suggested_alias}`" + raise StructureException(msg, node, hint=hint) self._add_import(node, 0, node.name, alias) @@ -436,14 +649,14 @@ def _load_import_helper( module_ast = self._ast_from_file(file) with override_global_namespace(Namespace()): - module_t = validate_semantics_r( + module_t = validate_module_semantics_r( module_ast, self.input_bundle, import_graph=self._import_graph, is_interface=False, ) - return ModuleInfo(module_t) + return ModuleInfo(module_t, alias) except FileNotFoundError as e: # escape `e` from the block scope, it can make things @@ -456,7 +669,7 @@ def _load_import_helper( module_ast = self._ast_from_file(file) with override_global_namespace(Namespace()): - validate_semantics_r( + validate_module_semantics_r( module_ast, self.input_bundle, import_graph=self._import_graph, @@ -481,7 +694,7 @@ def _load_import_helper( raise ModuleNotFound(module_str, node) from err -def _parse_and_fold_ast(file: FileInput) -> vy_ast.VyperNode: +def _parse_and_fold_ast(file: FileInput) -> vy_ast.Module: ret = vy_ast.parse_to_ast( file.source_code, source_id=file.source_id, @@ -542,5 +755,7 @@ def _load_builtin_import(level: int, module_str: str) -> InterfaceT: interface_ast = _parse_and_fold_ast(file) with override_global_namespace(Namespace()): - module_t = validate_semantics(interface_ast, input_bundle, is_interface=True) + module_t = validate_module_semantics_r( + interface_ast, input_bundle, ImportGraph(), is_interface=True + ) return module_t.interface diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index abbf6a68cc..f1f0f48a86 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -61,8 +61,8 @@ class _ExprAnalyser: def __init__(self): self.namespace = get_namespace() - def get_expr_info(self, node: vy_ast.VyperNode) -> ExprInfo: - t = self.get_exact_type_from_node(node) + def get_expr_info(self, node: vy_ast.VyperNode, is_callable: bool = False) -> ExprInfo: + t = self.get_exact_type_from_node(node, include_type_exprs=is_callable) # if it's a Name, we have varinfo for it if isinstance(node, vy_ast.Name): @@ -74,7 +74,10 @@ def get_expr_info(self, node: vy_ast.VyperNode) -> ExprInfo: if isinstance(info, ModuleInfo): return ExprInfo.from_moduleinfo(info) - raise CompilerPanic("unreachable!", node) + if isinstance(info, VyperType): + return ExprInfo(TYPE_T(info)) + + raise CompilerPanic(f"unreachable! {info}", node) if isinstance(node, vy_ast.Attribute): # if it's an Attr, we check the parent exprinfo and @@ -82,30 +85,27 @@ def get_expr_info(self, node: vy_ast.VyperNode) -> ExprInfo: # note: Attribute(expr value, identifier attr) name = node.attr - info = self.get_expr_info(node.value) + info = self.get_expr_info(node.value, is_callable=is_callable) + + attribute_chain = info.attribute_chain + [info] t = info.typ.get_member(name, node) # it's a top-level variable if isinstance(t, VarInfo): - return ExprInfo.from_varinfo(t) + return ExprInfo.from_varinfo(t, attribute_chain=attribute_chain) - # it's something else, like my_struct.foo - return info.copy_with_type(t) + if isinstance(t, ModuleInfo): + return ExprInfo.from_moduleinfo(t, attribute_chain=attribute_chain) - if isinstance(node, vy_ast.Tuple): - # always use the most restrictive location re: modification - # kludge! for validate_modification in local analysis of Assign - types = [self.get_expr_info(n) for n in node.elements] - location = sorted((i.location for i in types), key=lambda k: k.value)[-1] - modifiability = sorted((i.modifiability for i in types), key=lambda k: k.value)[-1] - - return ExprInfo(t, location=location, modifiability=modifiability) + # it's something else, like my_struct.foo + return info.copy_with_type(t, attribute_chain=attribute_chain) # If it's a Subscript, propagate the subscriptable varinfo if isinstance(node, vy_ast.Subscript): info = self.get_expr_info(node.value) - return info.copy_with_type(t) + attribute_chain = info.attribute_chain + [info] + return info.copy_with_type(t, attribute_chain=attribute_chain) return ExprInfo(t) @@ -184,6 +184,7 @@ def _find_fn(self, node): def types_from_Attribute(self, node): is_self_reference = node.get("value.id") == "self" + # variable attribute, e.g. `foo.bar` t = self.get_exact_type_from_node(node.value, include_type_exprs=True) name = node.attr @@ -476,8 +477,10 @@ def get_exact_type_from_node(node): return _ExprAnalyser().get_exact_type_from_node(node, include_type_exprs=True) -def get_expr_info(node: vy_ast.VyperNode) -> ExprInfo: - return _ExprAnalyser().get_expr_info(node) +def get_expr_info(node: vy_ast.ExprNode, is_callable: bool = False) -> ExprInfo: + if node._expr_info is None: + node._expr_info = _ExprAnalyser().get_expr_info(node, is_callable) + return node._expr_info def get_common_types(*nodes: vy_ast.VyperNode, filter_fn: Callable = None) -> List: @@ -639,7 +642,7 @@ def validate_unique_method_ids(functions: List) -> None: seen.add(method_id) -def check_modifiability(node: vy_ast.VyperNode, modifiability: Modifiability) -> bool: +def check_modifiability(node: vy_ast.ExprNode, modifiability: Modifiability) -> bool: """ Check if the given node is not more modifiable than the given modifiability. """ @@ -665,5 +668,5 @@ def check_modifiability(node: vy_ast.VyperNode, modifiability: Modifiability) -> if hasattr(call_type, "check_modifiability_for_call"): return call_type.check_modifiability_for_call(node, modifiability) - value_type = get_expr_info(node) - return value_type.modifiability >= modifiability + info = get_expr_info(node) + return info.modifiability <= modifiability diff --git a/vyper/semantics/data_locations.py b/vyper/semantics/data_locations.py index cecea35a60..06245aa90d 100644 --- a/vyper/semantics/data_locations.py +++ b/vyper/semantics/data_locations.py @@ -1,10 +1,12 @@ import enum +from vyper.utils import StringEnum -class DataLocation(enum.Enum): - UNSET = 0 - MEMORY = 1 - STORAGE = 2 - CALLDATA = 3 - CODE = 4 - TRANSIENT = 5 + +class DataLocation(StringEnum): + UNSET = enum.auto() + MEMORY = enum.auto() + STORAGE = enum.auto() + CALLDATA = enum.auto() + CODE = enum.auto() + TRANSIENT = enum.auto() diff --git a/vyper/semantics/types/base.py b/vyper/semantics/types/base.py index d659276ee0..c5e10b52be 100644 --- a/vyper/semantics/types/base.py +++ b/vyper/semantics/types/base.py @@ -13,6 +13,7 @@ UnknownAttribute, ) from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions +from vyper.semantics.data_locations import DataLocation # Some fake type with an overridden `compare_type` which accepts any RHS @@ -25,7 +26,11 @@ def __init__(self, type_): self.type_ = type_ def compare_type(self, other): - return isinstance(other, self.type_) or self == other + if isinstance(other, self.type_): + return True + # compare two GenericTypeAcceptors -- they are the same if the base + # type is the same + return isinstance(other, self.__class__) and other.type_ == self.type_ class VyperType: @@ -91,6 +96,8 @@ def __hash__(self): return hash(self._get_equality_attrs()) def __eq__(self, other): + if self is other: + return True return ( type(self) is type(other) and self._get_equality_attrs() == other._get_equality_attrs() ) @@ -118,6 +125,16 @@ def abi_type(self) -> ABIType: """ raise CompilerPanic("Method must be implemented by the inherited class") + def get_size_in(self, location: DataLocation): + if location in (DataLocation.STORAGE, DataLocation.TRANSIENT): + return self.storage_size_in_words + if location == DataLocation.MEMORY: + return self.memory_bytes_required + if location == DataLocation.CODE: + return self.memory_bytes_required + + raise CompilerPanic("unreachable: invalid location {location}") # pragma: nocover + @property def memory_bytes_required(self) -> int: # alias for API compatibility with codegen @@ -341,8 +358,10 @@ def map_void(typ: Optional[VyperType]) -> VyperType: # A type type. Used internally for types which can live in expression # position, ex. constructors (events, interfaces and structs), and also # certain builtins which take types as parameters -class TYPE_T: +class TYPE_T(VyperType): def __init__(self, typedef): + super().__init__() + self.typedef = typedef def __repr__(self): diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index 2d92370b9d..62f9c60585 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -19,8 +19,10 @@ from vyper.semantics.analysis.base import ( FunctionVisibility, Modifiability, + ModuleInfo, StateMutability, - StorageSlot, + VarInfo, + VarOffset, ) from vyper.semantics.analysis.utils import ( check_modifiability, @@ -112,10 +114,27 @@ def __init__( # recursively reachable from this function self.reachable_internal_functions: OrderedSet[ContractFunctionT] = OrderedSet() + # writes to variables from this function + self._variable_writes: OrderedSet[VarInfo] = OrderedSet() + + # reads of variables from this function + self._variable_reads: OrderedSet[VarInfo] = OrderedSet() + + # list of modules used (accessed state) by this function + self._used_modules: OrderedSet[ModuleInfo] = OrderedSet() + # to be populated during codegen self._ir_info: Any = None self._function_id: Optional[int] = None + @property + def _variable_accesses(self): + return self._variable_reads | self._variable_writes + + @property + def modifiability(self): + return Modifiability.from_state_mutability(self.mutability) + @cached_property def call_site_kwargs(self): # special kwargs that are allowed in call site @@ -269,9 +288,11 @@ def from_vyi(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": if len(funcdef.body) != 1 or not isinstance(funcdef.body[0].get("value"), vy_ast.Ellipsis): raise FunctionDeclarationException( - "function body in an interface can only be ...!", funcdef + "function body in an interface can only be `...`!", funcdef ) + assert function_visibility is not None # mypy hint + return cls( funcdef.name, positional_args, @@ -314,13 +335,19 @@ def from_FunctionDef(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": "Default function may not receive any arguments", funcdef.args.args[0] ) + if function_visibility == FunctionVisibility.DEPLOY and funcdef.name != "__init__": + raise FunctionDeclarationException( + "Only constructors can be marked as `@deploy`!", funcdef + ) if funcdef.name == "__init__": - if ( - state_mutability in (StateMutability.PURE, StateMutability.VIEW) - or function_visibility == FunctionVisibility.INTERNAL - ): + if state_mutability in (StateMutability.PURE, StateMutability.VIEW): raise FunctionDeclarationException( - "Constructor cannot be marked as `@pure`, `@view` or `@internal`", funcdef + "Constructor cannot be marked as `@pure` or `@view`", funcdef + ) + if function_visibility != FunctionVisibility.DEPLOY: + raise FunctionDeclarationException( + f"Constructor must be marked as `@deploy`, not `@{function_visibility}`", + funcdef, ) if return_type is not None: raise FunctionDeclarationException( @@ -333,6 +360,9 @@ def from_FunctionDef(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": "Constructor may not use default arguments", funcdef.args.defaults[0] ) + # sanity check + assert function_visibility is not None + return cls( funcdef.name, positional_args, @@ -344,14 +374,11 @@ def from_FunctionDef(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": ast_def=funcdef, ) - def set_reentrancy_key_position(self, position: StorageSlot) -> None: + def set_reentrancy_key_position(self, position: VarOffset) -> None: if hasattr(self, "reentrancy_key_position"): raise CompilerPanic("Position was already assigned") if self.nonreentrant is None: raise CompilerPanic(f"No reentrant key {self}") - # sanity check even though implied by the type - if position._location != DataLocation.STORAGE: - raise CompilerPanic("Non-storage reentrant key") self.reentrancy_key_position = position @classmethod @@ -456,6 +483,14 @@ def is_external(self) -> bool: def is_internal(self) -> bool: return self.visibility == FunctionVisibility.INTERNAL + @property + def is_deploy(self) -> bool: + return self.visibility == FunctionVisibility.DEPLOY + + @property + def is_constructor(self) -> bool: + return self.name == "__init__" + @property def is_mutable(self) -> bool: return self.mutability > StateMutability.VIEW @@ -464,10 +499,6 @@ def is_mutable(self) -> bool: def is_payable(self) -> bool: return self.mutability == StateMutability.PAYABLE - @property - def is_constructor(self) -> bool: - return self.name == "__init__" - @property def is_fallback(self) -> bool: return self.name == "__default__" @@ -535,20 +566,14 @@ def fetch_call_return(self, node: vy_ast.Call) -> Optional[VyperType]: modified_line = re.sub( kwarg_pattern, kwarg.value.node_source_code, node.node_source_code ) - error_suggestion = ( - f"\n(hint: Try removing the kwarg: `{modified_line}`)" - if modified_line != node.node_source_code - else "" - ) - raise ArgumentException( - ( - "Usage of kwarg in Vyper is restricted to " - + ", ".join([f"{k}=" for k in self.call_site_kwargs.keys()]) - + f". {error_suggestion}" - ), - kwarg, - ) + msg = "Usage of kwarg in Vyper is restricted to " + msg += ", ".join([f"{k}=" for k in self.call_site_kwargs.keys()]) + + hint = None + if modified_line != node.node_source_code: + hint = f"Try removing the kwarg: `{modified_line}`" + raise ArgumentException(msg, kwarg, hint=hint) return self.return_type @@ -601,7 +626,7 @@ def _parse_return_type(funcdef: vy_ast.FunctionDef) -> Optional[VyperType]: def _parse_decorators( funcdef: vy_ast.FunctionDef, -) -> tuple[FunctionVisibility, StateMutability, Optional[str]]: +) -> tuple[Optional[FunctionVisibility], StateMutability, Optional[str]]: function_visibility = None state_mutability = None nonreentrant_key = None @@ -632,7 +657,9 @@ def _parse_decorators( if FunctionVisibility.is_valid_value(decorator.id): if function_visibility is not None: raise FunctionDeclarationException( - f"Visibility is already set to: {function_visibility}", funcdef + f"Visibility is already set to: {function_visibility}", + decorator, + hint="only one visibility decorator is allowed per function", ) function_visibility = FunctionVisibility(decorator.id) @@ -748,6 +775,10 @@ def __init__( self.return_type = return_type self.is_modifying = is_modifying + @property + def modifiability(self): + return Modifiability.MODIFIABLE if self.is_modifying else Modifiability.RUNTIME_CONSTANT + def __repr__(self): return f"{self.underlying_type._id} member function '{self.name}'" diff --git a/vyper/semantics/types/module.py b/vyper/semantics/types/module.py index ee1da22a87..86840f4f91 100644 --- a/vyper/semantics/types/module.py +++ b/vyper/semantics/types/module.py @@ -1,5 +1,5 @@ from functools import cached_property -from typing import Optional +from typing import TYPE_CHECKING, Optional from vyper import ast as vy_ast from vyper.abi_types import ABI_Address, ABIType @@ -16,12 +16,16 @@ validate_expected_type, validate_unique_method_ids, ) +from vyper.semantics.data_locations import DataLocation from vyper.semantics.namespace import get_namespace from vyper.semantics.types.base import TYPE_T, VyperType from vyper.semantics.types.function import ContractFunctionT from vyper.semantics.types.primitives import AddressT from vyper.semantics.types.user import EventT, StructT, _UserType +if TYPE_CHECKING: + from vyper.semantics.analysis.base import ModuleInfo + class InterfaceT(_UserType): _type_members = {"address": AddressT()} @@ -234,7 +238,7 @@ def from_ModuleT(cls, module_t: "ModuleT") -> "InterfaceT": for node in module_t.function_defs: func_t = node._metadata["func_type"] - if not func_t.is_external: + if not (func_t.is_external or func_t.is_constructor): continue funcs.append((node.name, func_t)) @@ -276,6 +280,12 @@ def from_InterfaceDef(cls, node: vy_ast.InterfaceDef) -> "InterfaceT": # Datatype to store all module information. class ModuleT(VyperType): _attribute_in_annotation = True + _invalid_locations = ( + DataLocation.CALLDATA, + DataLocation.CODE, + DataLocation.MEMORY, + DataLocation.TRANSIENT, + ) def __init__(self, module: vy_ast.Module, name: Optional[str] = None): super().__init__() @@ -307,7 +317,6 @@ def __init__(self, module: vy_ast.Module, name: Optional[str] = None): for i in self.interface_defs: # add the type of the interface so it can be used in call position self.add_member(i.name, TYPE_T(i._metadata["interface_type"])) # type: ignore - self._helper.add_member(i.name, TYPE_T(i._metadata["interface_type"])) # type: ignore for v in self.variable_decls: self.add_member(v.target.id, v.target._metadata["varinfo"]) @@ -316,6 +325,13 @@ def __init__(self, module: vy_ast.Module, name: Optional[str] = None): import_info = i._metadata["import_info"] self.add_member(import_info.alias, import_info.typ) + if hasattr(import_info.typ, "module_t"): + self._helper.add_member(import_info.alias, TYPE_T(import_info.typ)) + + for name, interface_t in self.interfaces.items(): + # can access interfaces in type position + self._helper.add_member(name, TYPE_T(interface_t)) + # __eq__ is very strict on ModuleT - object equality! this is because we # don't want to reason about where a module came from (i.e. input bundle, # search path, symlinked vs normalized path, etc.) @@ -345,27 +361,97 @@ def struct_defs(self): def interface_defs(self): return self._module.get_children(vy_ast.InterfaceDef) + @cached_property + def interfaces(self) -> dict[str, InterfaceT]: + ret = {} + for i in self.interface_defs: + assert i.name not in ret # precondition + ret[i.name] = i._metadata["interface_type"] + + for i in self.import_stmts: + import_info = i._metadata["import_info"] + if isinstance(import_info.typ, InterfaceT): + assert import_info.alias not in ret # precondition + ret[import_info.alias] = import_info.typ + + return ret + @property def import_stmts(self): return self._module.get_children((vy_ast.Import, vy_ast.ImportFrom)) + @cached_property + def imported_modules(self) -> dict[str, "ModuleInfo"]: + ret = {} + for s in self.import_stmts: + info = s._metadata["import_info"] + module_info = info.typ + if isinstance(module_info, InterfaceT): + continue + ret[info.alias] = module_info + return ret + + def find_module_info(self, needle: "ModuleT") -> Optional["ModuleInfo"]: + for s in self.imported_modules.values(): + if s.module_t == needle: + return s + return None + @property def variable_decls(self): return self._module.get_children(vy_ast.VariableDecl) + @property + def uses_decls(self): + return self._module.get_children(vy_ast.UsesDecl) + + @property + def initializes_decls(self): + return self._module.get_children(vy_ast.InitializesDecl) + + @cached_property + def used_modules(self): + # modules which are written to + ret = [] + for node in self.uses_decls: + for used_module in node._metadata["uses_info"].used_modules: + ret.append(used_module) + return ret + + @property + def initialized_modules(self): + # modules which are initialized to + ret = [] + for node in self.initializes_decls: + info = node._metadata["initializes_info"] + ret.append(info) + return ret + @cached_property def variables(self): # variables that this module defines, ex. # `x: uint256` is a private storage variable named x return {s.target.id: s.target._metadata["varinfo"] for s in self.variable_decls} + @cached_property + def functions(self): + return {f.name: f._metadata["func_type"] for f in self.function_defs} + @cached_property def immutables(self): return [t for t in self.variables.values() if t.is_immutable] @cached_property def immutable_section_bytes(self): - return sum([imm.typ.memory_bytes_required for imm in self.immutables]) + ret = 0 + for s in self.immutables: + ret += s.typ.memory_bytes_required + + for initializes_info in self.initialized_modules: + module_t = initializes_info.module_info.module_t + ret += module_t.immutable_section_bytes + + return ret @cached_property def interface(self): diff --git a/vyper/semantics/types/utils.py b/vyper/semantics/types/utils.py index 5564570536..c6a4531df8 100644 --- a/vyper/semantics/types/utils.py +++ b/vyper/semantics/types/utils.py @@ -117,16 +117,16 @@ def _type_from_annotation(node: vy_ast.VyperNode) -> VyperType: if isinstance(node, vy_ast.Attribute): # ex. SomeModule.SomeStruct - # sanity check - we only allow modules/interfaces to be - # imported as `Name`s currently. - if not isinstance(node.value, vy_ast.Name): + if isinstance(node.value, vy_ast.Attribute): + module_or_interface = _type_from_annotation(node.value) + elif isinstance(node.value, vy_ast.Name): + try: + module_or_interface = namespace[node.value.id] # type: ignore + except UndeclaredDefinition: + raise InvalidType(err_msg, node) from None + else: raise InvalidType(err_msg, node) - try: - module_or_interface = namespace[node.value.id] # type: ignore - except UndeclaredDefinition: - raise InvalidType(err_msg, node) from None - if hasattr(module_or_interface, "module_t"): # i.e., it's a ModuleInfo module_or_interface = module_or_interface.module_t diff --git a/vyper/utils.py b/vyper/utils.py index 2349731b97..b2284eaba0 100644 --- a/vyper/utils.py +++ b/vyper/utils.py @@ -1,6 +1,7 @@ import binascii import contextlib import decimal +import enum import functools import sys import time @@ -8,7 +9,7 @@ import warnings from typing import Generic, List, TypeVar, Union -from vyper.exceptions import DecimalOverrideException, InvalidLiteral +from vyper.exceptions import CompilerPanic, DecimalOverrideException, InvalidLiteral _T = TypeVar("_T") @@ -62,6 +63,59 @@ def copy(self): return self.__class__(super().copy()) +class StringEnum(enum.Enum): + # Must be first, or else won't work, specifies what .value is + def _generate_next_value_(name, start, count, last_values): + return name.lower() + + # Override ValueError with our own internal exception + @classmethod + def _missing_(cls, value): + raise CompilerPanic(f"{value} is not a valid {cls.__name__}") + + @classmethod + def is_valid_value(cls, value: str) -> bool: + return value in set(o.value for o in cls) + + @classmethod + def options(cls) -> List["StringEnum"]: + return list(cls) + + @classmethod + def values(cls) -> List[str]: + return [v.value for v in cls.options()] + + # Comparison operations + def __eq__(self, other: object) -> bool: + if not isinstance(other, self.__class__): + raise CompilerPanic(f"bad comparison: ({type(other)}, {type(self)})") + return self is other + + # Python normally does __ne__(other) ==> not self.__eq__(other) + + def __lt__(self, other: object) -> bool: + if not isinstance(other, self.__class__): + raise CompilerPanic(f"bad comparison: ({type(other)}, {type(self)})") + options = self.__class__.options() + return options.index(self) < options.index(other) # type: ignore + + def __le__(self, other: object) -> bool: + return self.__eq__(other) or self.__lt__(other) + + def __gt__(self, other: object) -> bool: + return not self.__le__(other) + + def __ge__(self, other: object) -> bool: + return not self.__lt__(other) + + def __str__(self) -> str: + return self.value + + def __hash__(self) -> int: + # let `dataclass` know that this class is not mutable + return super().__hash__() + + class DecimalContextOverride(decimal.Context): def __setattr__(self, name, value): if name == "prec": From 6b9fff2fcc032176e257e5e252c916c06b9cee3a Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Thu, 8 Feb 2024 23:14:09 -0500 Subject: [PATCH 02/12] rename validate_expected_type to infer_type and have it return a type it also tags the node with the inferred type --- .../builtins/folding/test_bitwise.py | 6 ++-- vyper/builtins/_signatures.py | 4 +-- vyper/builtins/functions.py | 11 +++--- vyper/semantics/analysis/local.py | 13 +++---- vyper/semantics/analysis/module.py | 2 +- vyper/semantics/analysis/utils.py | 35 +++++++++++-------- vyper/semantics/types/function.py | 10 +++--- vyper/semantics/types/module.py | 6 ++-- vyper/semantics/types/subscriptable.py | 8 ++--- vyper/semantics/types/user.py | 6 ++-- 10 files changed, 51 insertions(+), 50 deletions(-) diff --git a/tests/functional/builtins/folding/test_bitwise.py b/tests/functional/builtins/folding/test_bitwise.py index c1ff7674bb..892f0bcabc 100644 --- a/tests/functional/builtins/folding/test_bitwise.py +++ b/tests/functional/builtins/folding/test_bitwise.py @@ -4,7 +4,7 @@ from tests.utils import parse_and_fold from vyper.exceptions import InvalidType, OverflowException -from vyper.semantics.analysis.utils import validate_expected_type +from vyper.semantics.analysis.utils import infer_type from vyper.semantics.types.shortcuts import INT256_T, UINT256_T from vyper.utils import unsigned_to_signed @@ -55,7 +55,7 @@ def foo(a: uint256, b: uint256) -> uint256: # force bounds check, no-op because validate_numeric_bounds # already does this, but leave in for hygiene (in case # more types are added). - validate_expected_type(new_node, UINT256_T) + _ = infer_type(new_node, UINT256_T) # compile time behavior does not match runtime behavior. # compile-time will throw on OOB, runtime will wrap. except OverflowException: # here: check the wrapped value matches runtime @@ -81,7 +81,7 @@ def foo(a: int256, b: uint256) -> int256: vyper_ast = parse_and_fold(f"{a} {op} {b}") old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() - validate_expected_type(new_node, INT256_T) # force bounds check + _ = infer_type(new_node, INT256_T) # force bounds check # compile time behavior does not match runtime behavior. # compile-time will throw on OOB, runtime will wrap. except (InvalidType, OverflowException): diff --git a/vyper/builtins/_signatures.py b/vyper/builtins/_signatures.py index 6e6cf4c662..3d25b435da 100644 --- a/vyper/builtins/_signatures.py +++ b/vyper/builtins/_signatures.py @@ -10,7 +10,7 @@ from vyper.semantics.analysis.utils import ( check_modifiability, get_exact_type_from_node, - validate_expected_type, + infer_type, ) from vyper.semantics.types import TYPE_T, KwargSettings, VyperType from vyper.semantics.types.utils import type_from_annotation @@ -99,7 +99,7 @@ def _validate_single(self, arg: vy_ast.VyperNode, expected_type: VyperType) -> N # for its side effects (will throw if is not a type) type_from_annotation(arg) else: - validate_expected_type(arg, expected_type) + infer_type(arg, expected_type) def _validate_arg_types(self, node: vy_ast.Call) -> None: num_args = len(self._inputs) # the number of args the signature indicates diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index 7575f4d77e..345b59197a 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -54,7 +54,7 @@ get_common_types, get_exact_type_from_node, get_possible_types_from_node, - validate_expected_type, + infer_type, ) from vyper.semantics.types import ( TYPE_T, @@ -508,8 +508,7 @@ def infer_arg_types(self, node, expected_return_typ=None): ret = [] prev_typeclass = None for arg in node.args: - validate_expected_type(arg, (BytesT.any(), StringT.any(), BytesM_T.any())) - arg_t = get_possible_types_from_node(arg).pop() + arg_t = infer_type(arg, (BytesT.any(), StringT.any(), BytesM_T.any())) current_typeclass = "String" if isinstance(arg_t, StringT) else "Bytes" if prev_typeclass and current_typeclass != prev_typeclass: raise TypeMismatch( @@ -865,7 +864,7 @@ def infer_kwarg_types(self, node): "Output type must be one of integer, bytes32 or address", node.keywords[0].value ) output_typedef = TYPE_T(output_type) - node.keywords[0].value._metadata["type"] = output_typedef + #node.keywords[0].value._metadata["type"] = output_typedef else: output_typedef = TYPE_T(BYTES32_T) @@ -2376,8 +2375,8 @@ def infer_kwarg_types(self, node): ret = {} for kwarg in node.keywords: kwarg_name = kwarg.arg - validate_expected_type(kwarg.value, self._kwargs[kwarg_name].typ) - ret[kwarg_name] = get_exact_type_from_node(kwarg.value) + typ = infer_type(kwarg.value, self._kwargs[kwarg_name].typ) + ret[kwarg_name] = typ return ret def fetch_call_return(self, node): diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index d96215ede0..77fa57c074 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -26,7 +26,7 @@ get_exact_type_from_node, get_expr_info, get_possible_types_from_node, - validate_expected_type, + infer_type, ) from vyper.semantics.data_locations import DataLocation @@ -254,7 +254,7 @@ def _validate_revert_reason(self, msg_node: vy_ast.VyperNode) -> None: self.expr_visitor.visit(msg_node, get_exact_type_from_node(msg_node)) elif not (isinstance(msg_node, vy_ast.Name) and msg_node.id == "UNREACHABLE"): try: - validate_expected_type(msg_node, StringT(1024)) + _ = infer_type(msg_node, StringT(1024)) except TypeMismatch as e: raise InvalidType("revert reason must fit within String[1024]") from e self.expr_visitor.visit(msg_node, get_exact_type_from_node(msg_node)) @@ -563,15 +563,10 @@ def scope_name(self): def visit(self, node, typ): if typ is not VOID_TYPE and not isinstance(typ, TYPE_T): - validate_expected_type(node, typ) + infer_type(node, expected_type=typ) - # recurse and typecheck in case we are being fed the wrong type for - # some reason. super().visit(node, typ) - # annotate - node._metadata["type"] = typ - if not isinstance(typ, TYPE_T): info = get_expr_info(node) # get_expr_info fills in node._expr_info @@ -793,7 +788,7 @@ def visit_Tuple(self, node: vy_ast.Tuple, typ: VyperType) -> None: # don't recurse; can't annotate AST children of type definition return - # these guarantees should be provided by validate_expected_type + # these guarantees should be provided by infer_type assert isinstance(typ, TupleT) assert len(node.elements) == len(typ.member_types) diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index e50c3e6d6f..787ec82c15 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -526,7 +526,7 @@ def _validate_self_namespace(): if node.is_constant: assert node.value is not None # checked in VariableDecl.validate() - ExprVisitor().visit(node.value, type_) # performs validate_expected_type + ExprVisitor().visit(node.value, type_) # performs type validation if not check_modifiability(node.value, Modifiability.CONSTANT): raise StateAccessViolation("Value must be a literal", node.value) diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index f1f0f48a86..c889e6ab75 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -224,7 +224,7 @@ def types_from_BinOp(self, node): # can be different types types_list = get_possible_types_from_node(node.left) # check rhs is unsigned integer - validate_expected_type(node.right, IntegerT.unsigneds()) + _ = infer_type(node.right, IntegerT.unsigneds()) else: types_list = get_common_types(node.left, node.right) @@ -319,7 +319,7 @@ def types_from_Constant(self, node): raise InvalidLiteral(f"Could not determine type for literal value '{node.value}'", node) def types_from_IfExp(self, node): - validate_expected_type(node.test, BoolT()) + _ = infer_type(node.test, expected_type=BoolT()) types_list = get_common_types(node.body, node.orelse) if not types_list: @@ -529,14 +529,14 @@ def _validate_literal_array(node, expected): for item in node.elements: try: - validate_expected_type(item, expected.value_type) + _ = infer_type(item, expected.value_type) except (InvalidType, TypeMismatch): return False return True -def validate_expected_type(node, expected_type): +def infer_type(node, expected_type): """ Validate that the given node matches the expected type(s) @@ -551,8 +551,15 @@ def validate_expected_type(node, expected_type): Returns ------- - None + The inferred type. The inferred type must be a concrete type which + is compatible with the expected type (although the expected type may + be generic). """ + ret = _infer_type_helper(node, expected_type) + node._metadata["type"] = ret + return ret + +def _infer_type_helper(node, expected_type): if not isinstance(expected_type, tuple): expected_type = (expected_type,) @@ -561,15 +568,15 @@ def validate_expected_type(node, expected_type): for t in possible_tuple_types: if len(t.member_types) != len(node.elements): continue - for item_ast, item_type in zip(node.elements, t.member_types): + ret = [] + for item_ast, expected_item_type in zip(node.elements, t.member_types): try: - validate_expected_type(item_ast, item_type) - return + item_t = infer_type(item_ast, expected_type=expected_item_type) + ret.append(item_t) except VyperException: - pass - else: - # fail block - pass + break # go to fail block + else: + return TupleT(tuple(ret)) given_types = _ExprAnalyser().get_possible_types_from_node(node) @@ -579,11 +586,11 @@ def validate_expected_type(node, expected_type): if not isinstance(expected, (DArrayT, SArrayT)): continue if _validate_literal_array(node, expected): - return + return expected else: for given, expected in itertools.product(given_types, expected_type): if expected.compare_type(given): - return + return given # validation failed, prepare a meaningful error message if len(expected_type) > 1: diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index 62f9c60585..4f4fc82e5c 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -27,7 +27,7 @@ from vyper.semantics.analysis.utils import ( check_modifiability, get_exact_type_from_node, - validate_expected_type, + infer_type, ) from vyper.semantics.data_locations import DataLocation from vyper.semantics.types.base import KwargSettings, VyperType @@ -542,7 +542,7 @@ def fetch_call_return(self, node: vy_ast.Call) -> Optional[VyperType]: raise CallViolation("Cannot send ether to nonpayable function", kwarg_node) for arg, expected in zip(node.args, self.argument_types): - validate_expected_type(arg, expected) + infer_type(arg, expected) # TODO this should be moved to validate_call_args for kwarg in node.keywords: @@ -553,7 +553,7 @@ def fetch_call_return(self, node: vy_ast.Call) -> Optional[VyperType]: f"`{kwarg.arg}=` specified but {self.name}() does not return anything", kwarg.value, ) - validate_expected_type(kwarg.value, kwarg_settings.typ) + infer_type(kwarg.value, kwarg_settings.typ) if kwarg_settings.require_literal: if not isinstance(kwarg.value, vy_ast.Constant): raise InvalidType( @@ -730,7 +730,7 @@ def _parse_args( value = funcdef.args.defaults[i - n_positional_args] if not check_modifiability(value, Modifiability.RUNTIME_CONSTANT): raise StateAccessViolation("Value must be literal or environment variable", value) - validate_expected_type(value, type_) + infer_type(value, expected_type=type_) keyword_args.append(KeywordArg(argname, type_, value, ast_source=arg)) argnames.add(argname) @@ -788,7 +788,7 @@ def fetch_call_return(self, node: vy_ast.Call) -> Optional[VyperType]: assert len(node.args) == len(self.arg_types) # validate_call_args postcondition for arg, expected_type in zip(node.args, self.arg_types): # CMC 2022-04-01 this should probably be in the validation module - validate_expected_type(arg, expected_type) + infer_type(arg, expected_type=expected_type) return self.return_type diff --git a/vyper/semantics/types/module.py b/vyper/semantics/types/module.py index 86840f4f91..0ef052a3da 100644 --- a/vyper/semantics/types/module.py +++ b/vyper/semantics/types/module.py @@ -13,7 +13,7 @@ from vyper.semantics.analysis.base import Modifiability, VarInfo from vyper.semantics.analysis.utils import ( check_modifiability, - validate_expected_type, + infer_type, validate_unique_method_ids, ) from vyper.semantics.data_locations import DataLocation @@ -83,8 +83,8 @@ def _ctor_call_return(self, node: vy_ast.Call) -> "InterfaceT": def _ctor_arg_types(self, node): validate_call_args(node, 1) - validate_expected_type(node.args[0], AddressT()) - return [AddressT()] + typ = infer_type(node.args[0], AddressT()) + return [typ] def _ctor_kwarg_types(self, node): return {} diff --git a/vyper/semantics/types/subscriptable.py b/vyper/semantics/types/subscriptable.py index 635a1631a2..9dec62e136 100644 --- a/vyper/semantics/types/subscriptable.py +++ b/vyper/semantics/types/subscriptable.py @@ -35,9 +35,9 @@ def getter_signature(self) -> Tuple[Tuple, Optional[VyperType]]: def validate_index_type(self, node): # TODO: break this cycle - from vyper.semantics.analysis.utils import validate_expected_type + from vyper.semantics.analysis.utils import infer_type - validate_expected_type(node, self.key_type) + infer_type(node, self.key_type) class HashMapT(_SubscriptableT): @@ -125,7 +125,7 @@ def count(self): def validate_index_type(self, node): # TODO break this cycle - from vyper.semantics.analysis.utils import validate_expected_type + from vyper.semantics.analysis.utils import infer_type if isinstance(node, vy_ast.Int): if node.value < 0: @@ -133,7 +133,7 @@ def validate_index_type(self, node): if node.value >= self.length: raise ArrayIndexException("Index out of range", node) - validate_expected_type(node, IntegerT.any()) + infer_type(node, IntegerT.any()) def get_subscripted_type(self, node): return self.value_type diff --git a/vyper/semantics/types/user.py b/vyper/semantics/types/user.py index 92a455e3d8..c3f169ac8d 100644 --- a/vyper/semantics/types/user.py +++ b/vyper/semantics/types/user.py @@ -16,7 +16,7 @@ ) from vyper.semantics.analysis.base import Modifiability from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions -from vyper.semantics.analysis.utils import check_modifiability, validate_expected_type +from vyper.semantics.analysis.utils import check_modifiability, infer_type from vyper.semantics.data_locations import DataLocation from vyper.semantics.types.base import VyperType from vyper.semantics.types.subscriptable import HashMapT @@ -270,7 +270,7 @@ def from_EventDef(cls, base_node: vy_ast.EventDef) -> "EventT": def _ctor_call_return(self, node: vy_ast.Call) -> None: validate_call_args(node, len(self.arguments)) for arg, expected in zip(node.args, self.arguments.values()): - validate_expected_type(arg, expected) + infer_type(arg, expected) def to_toplevel_abi_dict(self) -> list[dict]: return [ @@ -412,7 +412,7 @@ def _ctor_call_return(self, node: vy_ast.Call) -> "StructT": key, ) - validate_expected_type(value, members.pop(key.id)) + infer_type(value, members.pop(key.id)) if members: raise VariableDeclarationException( From cc7c19885b217539c0045b9bd26fed2e1fe76e5e Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sun, 11 Feb 2024 07:08:25 -0800 Subject: [PATCH 03/12] fix: fuzz test not updated to use TypeMismatch (#3768) this is a regression introduced in c6b29c7f06a; the exception thrown by `validate_expected_type()` was updated to be `TypeMismatch`, but this test was not correspondingly updated. --- tests/functional/builtins/folding/test_bitwise.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/functional/builtins/folding/test_bitwise.py b/tests/functional/builtins/folding/test_bitwise.py index c1ff7674bb..f63ef8484a 100644 --- a/tests/functional/builtins/folding/test_bitwise.py +++ b/tests/functional/builtins/folding/test_bitwise.py @@ -1,9 +1,9 @@ import pytest -from hypothesis import given, settings +from hypothesis import example, given, settings from hypothesis import strategies as st from tests.utils import parse_and_fold -from vyper.exceptions import InvalidType, OverflowException +from vyper.exceptions import OverflowException, TypeMismatch from vyper.semantics.analysis.utils import validate_expected_type from vyper.semantics.types.shortcuts import INT256_T, UINT256_T from vyper.utils import unsigned_to_signed @@ -66,9 +66,10 @@ def foo(a: uint256, b: uint256) -> uint256: @pytest.mark.fuzzing -@settings(max_examples=50) +@settings(max_examples=51) @pytest.mark.parametrize("op", ["<<", ">>"]) @given(a=st_sint256, b=st.integers(min_value=0, max_value=256)) +@example(a=128, b=248) # throws TypeMismatch def test_bitwise_shift_signed(get_contract, a, b, op): source = f""" @external @@ -84,7 +85,7 @@ def foo(a: int256, b: uint256) -> int256: validate_expected_type(new_node, INT256_T) # force bounds check # compile time behavior does not match runtime behavior. # compile-time will throw on OOB, runtime will wrap. - except (InvalidType, OverflowException): + except (TypeMismatch, OverflowException): # check the wrapped value matches runtime assert op == "<<" assert contract.foo(a, b) == unsigned_to_signed((a << b) % (2**256), 256) From 37ef8f4b54375a458e8b708cf3c41877b5f1655e Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sun, 11 Feb 2024 07:08:53 -0800 Subject: [PATCH 04/12] chore: run mypy as part of lint rule in Makefile (#3771) and remove the separate mypy rule. this makes the development workflow a bit faster --- Makefile | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/Makefile b/Makefile index 645b800e79..649b381012 100644 --- a/Makefile +++ b/Makefile @@ -17,11 +17,8 @@ dev-init: test: pytest -mypy: - tox -e mypy - lint: - tox -e lint + tox -e lint,mypy docs: rm -f docs/vyper.rst From 261e3d9349cd8acc202b3c63f16c73ef45035c1b Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Mon, 12 Feb 2024 14:24:10 -0800 Subject: [PATCH 05/12] fix: `StringEnum._generate_next_value_ signature` (#3770) per the documentation, `_generate_next_value_` should be a staticmethod. reference: https://docs.python.org/3/library/enum.html#enum.Enum._generate_next_value_ --- vyper/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vyper/utils.py b/vyper/utils.py index b2284eaba0..ab4d789aa4 100644 --- a/vyper/utils.py +++ b/vyper/utils.py @@ -65,6 +65,7 @@ def copy(self): class StringEnum(enum.Enum): # Must be first, or else won't work, specifies what .value is + @staticmethod def _generate_next_value_(name, start, count, last_values): return name.lower() From a2eb60c713ee538ace46dde5c8ffbe625c1daa86 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micka=C3=ABl=20Schoentgen?= Date: Mon, 12 Feb 2024 23:25:33 +0100 Subject: [PATCH 06/12] docs: adopt a new theme: `shibuya` (#3754) --- .readthedocs.yaml | 13 +- Dockerfile | 2 +- README.md | 12 +- docs/_static/css/dark.css | 215 --------------------------- docs/_static/css/toggle.css | 77 ---------- docs/_static/js/toggle.js | 26 ---- docs/built-in-functions.rst | 166 ++++++++++----------- docs/compiler-exceptions.rst | 14 +- docs/compiling-a-contract.rst | 28 ++-- docs/conf.py | 141 +++--------------- docs/constants-and-vars.rst | 6 +- docs/contributing.rst | 2 +- docs/control-structures.rst | 32 ++-- docs/event-logging.rst | 8 +- docs/index.rst | 2 +- docs/interfaces.rst | 24 +-- docs/logo.svg | 4 + docs/natspec.rst | 10 +- docs/scoping-and-declarations.rst | 32 ++-- docs/statements.rst | 16 +- docs/structure-of-a-contract.rst | 24 +-- docs/testing-contracts-brownie.rst | 9 +- docs/testing-contracts-ethtester.rst | 11 +- docs/types.rst | 20 +-- docs/vyper-by-example.rst | 78 +++++----- docs/vyper-logo-transparent.svg | 11 -- examples/tokens/ERC20.vy | 2 +- requirements-docs.txt | 4 +- tox.ini | 4 +- 29 files changed, 286 insertions(+), 707 deletions(-) delete mode 100644 docs/_static/css/dark.css delete mode 100644 docs/_static/css/toggle.css delete mode 100644 docs/_static/js/toggle.js create mode 100644 docs/logo.svg delete mode 100644 docs/vyper-logo-transparent.svg diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 1ad9000f53..e7f5fa079a 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -1,23 +1,20 @@ -# File: .readthedocs.yaml - version: 2 -# Set the version of Python and other tools you might need build: # TODO: update to `latest` once supported # https://github.com/readthedocs/readthedocs.org/issues/8861 os: ubuntu-22.04 tools: - python: "3.10" + python: "3.11" -# Build from the docs/ directory with Sphinx sphinx: configuration: docs/conf.py -formats: all - +# We can't use "all" because "htmlzip" format is broken for now +formats: + - epub + - pdf -# Optionally declare the Python requirements required to build your docs python: install: - requirements: requirements-docs.txt diff --git a/Dockerfile b/Dockerfile index bc5bb607d6..b4bfa6d3a4 100644 --- a/Dockerfile +++ b/Dockerfile @@ -6,7 +6,7 @@ ARG VCS_REF LABEL org.label-schema.build-date=$BUILD_DATE \ org.label-schema.name="Vyper" \ org.label-schema.description="Vyper is an experimental programming language" \ - org.label-schema.url="https://vyper.readthedocs.io/en/latest/" \ + org.label-schema.url="https://docs.vyperlang.org/en/latest/" \ org.label-schema.vcs-ref=$VCS_REF \ org.label-schema.vcs-url="https://github.com/vyperlang/vyper" \ org.label-schema.vendor="Vyper Team" \ diff --git a/README.md b/README.md index 33c4557cc8..b14b7eaaf0 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ [![Build Status](https://github.com/vyperlang/vyper/workflows/Test/badge.svg)](https://github.com/vyperlang/vyper/actions/workflows/test.yml) -[![Documentation Status](https://readthedocs.org/projects/vyper/badge/?version=latest)](http://vyper.readthedocs.io/en/latest/?badge=latest "ReadTheDocs") +[![Documentation Status](https://readthedocs.org/projects/vyper/badge/?version=latest)](http://docs.vyperlang.org/en/latest/?badge=latest "ReadTheDocs") [![Discord](https://img.shields.io/discord/969926564286459934.svg?label=%23vyper)](https://discord.gg/6tw7PTM7C2) [![PyPI](https://badge.fury.io/py/vyper.svg)](https://pypi.org/project/vyper "PyPI") @@ -13,9 +13,9 @@ [![Language grade: Python](https://github.com/vyperlang/vyper/workflows/CodeQL/badge.svg)](https://github.com/vyperlang/vyper/actions/workflows/codeql.yml) # Getting Started -See [Installing Vyper](http://vyper.readthedocs.io/en/latest/installing-vyper.html) to install vyper. +See [Installing Vyper](http://docs.vyperlang.org/en/latest/installing-vyper.html) to install vyper. See [Tools and Resources](https://github.com/vyperlang/vyper/wiki/Vyper-tools-and-resources) for an additional list of framework and tools with vyper support. -See [Documentation](http://vyper.readthedocs.io/en/latest/index.html) for the documentation and overall design goals of the Vyper language. +See [Documentation](http://docs.vyperlang.org/en/latest/index.html) for the documentation and overall design goals of the Vyper language. See [Learn.Vyperlang.org](https://learn.vyperlang.org/) for **learning Vyper by building a Pokémon game**. See [try.vyperlang.org](https://try.vyperlang.org/) to use Vyper in a hosted jupyter environment! @@ -23,7 +23,7 @@ See [try.vyperlang.org](https://try.vyperlang.org/) to use Vyper in a hosted jup **Note: Vyper is beta software, use with care** # Installation -See the [Vyper documentation](https://vyper.readthedocs.io/en/latest/installing-vyper.html) +See the [Vyper documentation](https://docs.vyperlang.org/en/latest/installing-vyper.html) for build instructions. # Compiling a contract @@ -47,7 +47,7 @@ be a bit behind the latest version found in the master branch of this repository ## Testing (using pytest) -(Complete [installation steps](https://vyper.readthedocs.io/en/latest/installing-vyper.html) first.) +(Complete [installation steps](https://docs.vyperlang.org/en/latest/installing-vyper.html) first.) ```bash make dev-init @@ -75,4 +75,4 @@ To get a call graph from a python profile, https://stackoverflow.com/a/23164271/ * See Issues tab, and feel free to submit your own issues * Add PRs if you discover a solution to an existing issue * For further discussions and questions, post in [Discussions](https://github.com/vyperlang/vyper/discussions) or talk to us on [Discord](https://discord.gg/6tw7PTM7C2) -* For more information, see [Contributing](http://vyper.readthedocs.io/en/latest/contributing.html) +* For more information, see [Contributing](http://docs.vyperlang.org/en/latest/contributing.html) diff --git a/docs/_static/css/dark.css b/docs/_static/css/dark.css deleted file mode 100644 index 158f08e0fc..0000000000 --- a/docs/_static/css/dark.css +++ /dev/null @@ -1,215 +0,0 @@ -/* links */ - -a, -a:visited { - color: #aaddff; -} - - -/* code directives */ - -.method dt, -.class dt, -.data dt, -.attribute dt, -.function dt, -.classmethod dt, -.exception dt, -.descclassname, -.descname { - background-color: #2d2d2d !important; -} - -.descname { - color: inherit !important; -} - -.rst-content dl:not(.docutils) dt { - color: #aaddff; - border-top: solid 3px #525252; - border-left: solid 3px #525252; -} - -em.property { - color: #888888; -} - - -/* tables */ - -.rst-content table.docutils thead { - color: #ddd; -} - -.rst-content table.docutils td { - border: 0px; -} - -.rst-content table.docutils:not(.field-list) tr:nth-child(2n-1) td { - background-color: #5a5a5a; -} - - -/* inlined code highlights */ - -.xref, -.py-meth, -.rst-content a code { - color: #aaddff !important; - font-weight: normal !important; -} - -.rst-content code { - color: #eee !important; - font-weight: normal !important; -} - -code.literal { - background-color: #2d2d2d !important; - border: 1px solid #6d6d6d !important; -} - -code.docutils.literal.notranslate { - color: #ddd; -} - - -/* code examples */ - -pre { - background: #222; - color: #ddd; - font-size: 150%; - border-color: #333 !important; -} - -.copybutton { - color: #666 !important; - border-color: #333 !important; -} - -.highlight .go, -.highlight .nb, -.highlight .kn { - /* text */ - color: #ddd; - font-weight: normal; -} - -.highlight .o, -.highlight .p { - /* comparators, parentheses */ - color: #bbb; -} - -.highlight .c1 { - /* comments */ - color: #888; -} - -.highlight .bp { - /* self */ - color: #fc3; -} - -.highlight .mf, -.highlight .mi, -.highlight .kc { - /* numbers, booleans */ - color: #c90; -} - -.highlight .gt, -.highlight .nf, -.highlight .fm { - /* functions */ - color: #7cf; -} - -.highlight .nd { - /* decorators */ - color: #f66; -} - -.highlight .k, -.highlight .ow { - /* statements */ - color: #A7F; - font-weight: normal; -} - -.highlight .s2, -.highlight .s1, -.highlight .nt { - /* strings */ - color: #5d6; -} - - -/* notes, warnings, hints */ - -.hint .admonition-title { - background: #2aa87c !important; -} - -.warning .admonition-title { - background: #cc4444 !important; -} - -.admonition-title { - background: #3a7ca8 !important; -} - -.admonition, -.note { - background-color: #2d2d2d !important; -} - - -/* table of contents */ - -.wy-body-for-nav { - background-color: rgb(26, 28, 29); -} - -.wy-nav-content-wrap { - background-color: rgba(0, 0, 0, 0.6) !important; -} - -.sidebar { - background-color: #191919 !important; -} - -.sidebar-title { - background-color: #2b2b2b !important; -} - -.wy-menu-vertical a { - color: #ddd; -} - -.wy-menu-vertical code.docutils.literal.notranslate { - color: #404040; - background: none !important; - border: none !important; -} - -.wy-nav-content { - background: #3c3c3c; - color: #dddddd; -} - -.wy-menu-vertical li.on a, -.wy-menu-vertical li.current>a { - background: #a3a3a3; - border-bottom: 0px !important; - border-top: 0px !important; -} - -.wy-menu-vertical li.current { - background: #b3b3b3; -} - -.toc-backref { - color: grey !important; -} \ No newline at end of file diff --git a/docs/_static/css/toggle.css b/docs/_static/css/toggle.css deleted file mode 100644 index ebbd0658a1..0000000000 --- a/docs/_static/css/toggle.css +++ /dev/null @@ -1,77 +0,0 @@ -input[type=checkbox] { - visibility: hidden; - height: 0; - width: 0; - margin: 0; -} - -.rst-versions .rst-current-version { - padding: 10px; - display: flex; - justify-content: space-between; -} - -.rst-versions .rst-current-version .fa-book, -.rst-versions .rst-current-version .fa-v, -.rst-versions .rst-current-version .fa-caret-down { - height: 24px; - line-height: 24px; - vertical-align: middle; -} - -.rst-versions .rst-current-version .fa-element { - width: 80px; - text-align: center; -} - -.rst-versions .rst-current-version .fa-book { - text-align: left; -} - -.rst-versions .rst-current-version .fa-v { - color: #27AE60; - text-align: right; -} - -label { - margin: 0 auto; - display: inline-block; - justify-content: center; - align-items: right; - border-radius: 100px; - position: relative; - cursor: pointer; - text-indent: -9999px; - width: 50px; - height: 21px; - background: #000; -} - -label:after { - border-radius: 50%; - position: absolute; - content: ''; - background: #fff; - width: 15px; - height: 15px; - top: 3px; - left: 3px; - transition: ease-in-out 200ms; -} - -input:checked+label { - background: #3a7ca8; -} - -input:checked+label:after { - left: calc(100% - 5px); - transform: translateX(-100%); -} - -html.transition, -html.transition *, -html.transition *:before, -html.transition *:after { - transition: ease-in-out 200ms !important; - transition-delay: 0 !important; -} \ No newline at end of file diff --git a/docs/_static/js/toggle.js b/docs/_static/js/toggle.js deleted file mode 100644 index df131042b5..0000000000 --- a/docs/_static/js/toggle.js +++ /dev/null @@ -1,26 +0,0 @@ -document.addEventListener('DOMContentLoaded', function() { - - var checkbox = document.querySelector('input[name=mode]'); - - function toggleCssMode(isDay) { - var mode = (isDay ? "Day" : "Night"); - localStorage.setItem("css-mode", mode); - - var darksheet = $('link[href="_static/css/dark.css"]')[0].sheet; - darksheet.disabled = isDay; - } - - if (localStorage.getItem("css-mode") == "Day") { - toggleCssMode(true); - checkbox.setAttribute('checked', true); - } - - checkbox.addEventListener('change', function() { - document.documentElement.classList.add('transition'); - window.setTimeout(() => { - document.documentElement.classList.remove('transition'); - }, 1000) - toggleCssMode(this.checked); - }) - -}); \ No newline at end of file diff --git a/docs/built-in-functions.rst b/docs/built-in-functions.rst index 45cf9ec8c2..afb64e71ca 100644 --- a/docs/built-in-functions.rst +++ b/docs/built-in-functions.rst @@ -14,14 +14,14 @@ Bitwise Operations Perform a "bitwise and" operation. Each bit of the output is 1 if the corresponding bit of ``x`` AND of ``y`` is 1, otherwise it is 0. - .. code-block:: python + .. code-block:: vyper @external @view def foo(x: uint256, y: uint256) -> uint256: return bitwise_and(x, y) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo(31337, 8008135) 12353 @@ -34,14 +34,14 @@ Bitwise Operations Return the bitwise complement of ``x`` - the number you get by switching each 1 for a 0 and each 0 for a 1. - .. code-block:: python + .. code-block:: vyper @external @view def foo(x: uint256) -> uint256: return bitwise_not(x) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo(0) 115792089237316195423570985008687907853269984665640564039457584007913129639935 @@ -54,14 +54,14 @@ Bitwise Operations Perform a "bitwise or" operation. Each bit of the output is 0 if the corresponding bit of ``x`` AND of ``y`` is 0, otherwise it is 1. - .. code-block:: python + .. code-block:: vyper @external @view def foo(x: uint256, y: uint256) -> uint256: return bitwise_or(x, y) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo(31337, 8008135) 8027119 @@ -74,14 +74,14 @@ Bitwise Operations Perform a "bitwise exclusive or" operation. Each bit of the output is the same as the corresponding bit in ``x`` if that bit in ``y`` is 0, and it is the complement of the bit in ``x`` if that bit in ``y`` is 1. - .. code-block:: python + .. code-block:: vyper @external @view def foo(x: uint256, y: uint256) -> uint256: return bitwise_xor(x, y) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo(31337, 8008135) 8014766 @@ -94,14 +94,14 @@ Bitwise Operations Return ``x`` with the bits shifted ``_shift`` places. A positive ``_shift`` value equals a left shift, a negative value is a right shift. - .. code-block:: python + .. code-block:: vyper @external @view def foo(x: uint256, y: int128) -> uint256: return shift(x, y) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo(2, 8) 512 @@ -144,7 +144,7 @@ Vyper has three built-ins for contract creation; all three contract creation bui Returns the address of the newly created proxy contract. If the create operation fails (for instance, in the case of a ``CREATE2`` collision), execution will revert. - .. code-block:: python + .. code-block:: vyper @external def foo(target: address) -> address: @@ -173,7 +173,7 @@ Vyper has three built-ins for contract creation; all three contract creation bui Returns the address of the created contract. If the create operation fails (for instance, in the case of a ``CREATE2`` collision), execution will revert. If there is no code at ``target``, execution will revert. - .. code-block:: python + .. code-block:: vyper @external def foo(target: address) -> address: @@ -197,7 +197,7 @@ Vyper has three built-ins for contract creation; all three contract creation bui Returns the address of the created contract. If the create operation fails (for instance, in the case of a ``CREATE2`` collision), execution will revert. If ``code_offset >= target.codesize`` (ex. if there is no code at ``target``), execution will revert. - .. code-block:: python + .. code-block:: vyper @external def foo(blueprint: address) -> address: @@ -213,7 +213,7 @@ Vyper has three built-ins for contract creation; all three contract creation bui It is recommended to deploy blueprints with the ERC-5202 preamble ``0xFE7100`` to guard them from being called as regular contracts. This is particularly important for factories where the constructor has side effects (including ``SELFDESTRUCT``!), as those could get executed by *anybody* calling the blueprint contract directly. The ``code_offset=`` kwarg is provided to enable this pattern: - .. code-block:: python + .. code-block:: vyper @external def foo(blueprint: address) -> address: @@ -241,7 +241,7 @@ Vyper has three built-ins for contract creation; all three contract creation bui Returns ``success`` in a tuple with return value if ``revert_on_failure`` is set to ``False``. - .. code-block:: python + .. code-block:: vyper @external @payable @@ -276,7 +276,7 @@ Vyper has three built-ins for contract creation; all three contract creation bui * ``topics``: List of ``bytes32`` log topics. The length of this array determines which opcode is used. * ``data``: Unindexed event data to include in the log. May be given as ``Bytes`` or ``bytes32``. - .. code-block:: python + .. code-block:: vyper @external def foo(_topic: bytes32, _data: Bytes[100]): @@ -288,7 +288,7 @@ Vyper has three built-ins for contract creation; all three contract creation bui * ``data``: Data representing the error message causing the revert. - .. code-block:: python + .. code-block:: vyper @external def foo(_data: Bytes[100]): @@ -308,7 +308,7 @@ Vyper has three built-ins for contract creation; all three contract creation bui This function has been deprecated from version 0.3.8 onwards. The underlying opcode will eventually undergo breaking changes, and its use is not recommended. - .. code-block:: python + .. code-block:: vyper @external def do_the_needful(): @@ -326,7 +326,7 @@ Vyper has three built-ins for contract creation; all three contract creation bui The amount to send is always specified in ``wei``. - .. code-block:: python + .. code-block:: vyper @external def foo(_receiver: address, _amount: uint256, gas: uint256): @@ -339,14 +339,14 @@ Cryptography Take two points on the Alt-BN128 curve and add them together. - .. code-block:: python + .. code-block:: vyper @external @view def foo(x: uint256[2], y: uint256[2]) -> uint256[2]: return ecadd(x, y) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo([1, 2], [1, 2]) [ @@ -361,14 +361,14 @@ Cryptography * ``point``: Point to be multiplied * ``scalar``: Scalar value - .. code-block:: python + .. code-block:: vyper @external @view def foo(point: uint256[2], scalar: uint256) -> uint256[2]: return ecmul(point, scalar) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo([1, 2], 3) [ @@ -390,7 +390,7 @@ Cryptography Prior to Vyper ``0.3.10``, the ``ecrecover`` function could return an undefined (possibly nonzero) value for invalid inputs to ``ecrecover``. For more information, please see `GHSA-f5x6-7qgp-jhf3 `_. - .. code-block:: python + .. code-block:: vyper @external @view @@ -402,7 +402,7 @@ Cryptography @view def foo(hash: bytes32, v: uint256, r:uint256, s:uint256) -> address: return ecrecover(hash, v, r, s) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo('0x6c9c5e133b8aafb2ea74f524a5263495e7ae5701c7248805f7b511d973dc7055', 28, @@ -417,14 +417,14 @@ Cryptography * ``_value``: Value to hash. Can be a ``String``, ``Bytes``, or ``bytes32``. - .. code-block:: python + .. code-block:: vyper @external @view def foo(_value: Bytes[100]) -> bytes32 return keccak256(_value) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo(b"potato") 0x9e159dfcfe557cc1ca6c716e87af98fdcb94cd8c832386d0429b2b7bec02754f @@ -435,14 +435,14 @@ Cryptography * ``_value``: Value to hash. Can be a ``String``, ``Bytes``, or ``bytes32``. - .. code-block:: python + .. code-block:: vyper @external @view def foo(_value: Bytes[100]) -> bytes32 return sha256(_value) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo(b"potato") 0xe91c254ad58860a02c788dfb5c1a65d6a8846ab1dc649631c7db16fef4af2dec @@ -456,14 +456,14 @@ Data Manipulation If the input arguments are ``String`` the return type is ``String``. Otherwise the return type is ``Bytes``. - .. code-block:: python + .. code-block:: vyper @external @view def foo(a: String[5], b: String[5], c: String[5]) -> String[100]: return concat(a, " ", b, " ", c, "!") - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo("why","hello","there") "why hello there!" @@ -487,14 +487,14 @@ Data Manipulation Returns the string representation of ``value``. - .. code-block:: python + .. code-block:: vyper @external @view def foo(b: uint256) -> String[78]: return uint2str(b) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo(420) "420" @@ -509,14 +509,14 @@ Data Manipulation Returns a value of the type specified by ``output_type``. - .. code-block:: python + .. code-block:: vyper @external @view def foo(b: Bytes[32]) -> address: return extract32(b, 0, output_type=address) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo("0x0000000000000000000000009f8F72aA9304c8B593d555F12eF6589cC3A579A2") "0x9f8F72aA9304c8B593d555F12eF6589cC3A579A2" @@ -531,14 +531,14 @@ Data Manipulation If the value being sliced is a ``Bytes`` or ``bytes32``, the return type is ``Bytes``. If it is a ``String``, the return type is ``String``. - .. code-block:: python + .. code-block:: vyper @external @view def foo(s: String[32]) -> String[5]: return slice(s, 4, 5) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo("why hello! how are you?") "hello" @@ -552,14 +552,14 @@ Math * ``value``: Integer to return the absolute value of - .. code-block:: python + .. code-block:: vyper @external @view def foo(value: int256) -> int256: return abs(value) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo(-31337) 31337 @@ -570,14 +570,14 @@ Math * ``value``: Decimal value to round up - .. code-block:: python + .. code-block:: vyper @external @view def foo(x: decimal) -> int256: return ceil(x) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo(3.1337) 4 @@ -588,14 +588,14 @@ Math * ``typename``: Name of the decimal type (currently only ``decimal``) - .. code-block:: python + .. code-block:: vyper @external @view def foo() -> decimal: return epsilon(decimal) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo() Decimal('1E-10') @@ -606,14 +606,14 @@ Math * ``value``: Decimal value to round down - .. code-block:: python + .. code-block:: vyper @external @view def foo(x: decimal) -> int256: return floor(x) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo(3.1337) 3 @@ -622,14 +622,14 @@ Math Return the greater value of ``a`` and ``b``. The input values may be any numeric type as long as they are both of the same type. The output value is of the same type as the input values. - .. code-block:: python + .. code-block:: vyper @external @view def foo(a: uint256, b: uint256) -> uint256: return max(a, b) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo(23, 42) 42 @@ -638,14 +638,14 @@ Math Returns the maximum value of the numeric type specified by ``type_`` (e.g., ``int128``, ``uint256``, ``decimal``). - .. code-block:: python + .. code-block:: vyper @external @view def foo() -> int256: return max_value(int256) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo() 57896044618658097711785492504343953926634992332820282019728792003956564819967 @@ -654,14 +654,14 @@ Math Returns the lesser value of ``a`` and ``b``. The input values may be any numeric type as long as they are both of the same type. The output value is of the same type as the input values. - .. code-block:: python + .. code-block:: vyper @external @view def foo(a: uint256, b: uint256) -> uint256: return min(a, b) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo(23, 42) 23 @@ -670,14 +670,14 @@ Math Returns the minimum value of the numeric type specified by ``type_`` (e.g., ``int128``, ``uint256``, ``decimal``). - .. code-block:: python + .. code-block:: vyper @external @view def foo() -> int256: return min_value(int256) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo() -57896044618658097711785492504343953926634992332820282019728792003956564819968 @@ -688,14 +688,14 @@ Math This method is used to perform exponentiation without overflow checks. - .. code-block:: python + .. code-block:: vyper @external @view def foo(a: uint256, b: uint256) -> uint256: return pow_mod256(a, b) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo(2, 3) 8 @@ -706,14 +706,14 @@ Math Return the square root of the provided decimal number, using the Babylonian square root algorithm. - .. code-block:: python + .. code-block:: vyper @external @view def foo(d: decimal) -> decimal: return sqrt(d) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo(9.0) 3.0 @@ -722,14 +722,14 @@ Math Return the (integer) square root of the provided integer number, using the Babylonian square root algorithm. The rounding mode is to round down to the nearest integer. For instance, ``isqrt(101) == 10``. - .. code-block:: python + .. code-block:: vyper @external @view def foo(x: uint256) -> uint256: return isqrt(x) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo(101) 10 @@ -738,14 +738,14 @@ Math Return the modulo of ``(a + b) % c``. Reverts if ``c == 0``. As this built-in function is intended to provides access to the underlying ``ADDMOD`` opcode, all intermediate calculations of this operation are not subject to the ``2 ** 256`` modulo according to the EVM specifications. - .. code-block:: python + .. code-block:: vyper @external @view def foo(a: uint256, b: uint256, c: uint256) -> uint256: return uint256_addmod(a, b, c) - .. code-block:: python + .. code-block:: vyper >>> (6 + 13) % 8 3 @@ -756,14 +756,14 @@ Math Return the modulo from ``(a * b) % c``. Reverts if ``c == 0``. As this built-in function is intended to provides access to the underlying ``MULMOD`` opcode, all intermediate calculations of this operation are not subject to the ``2 ** 256`` modulo according to the EVM specifications. - .. code-block:: python + .. code-block:: vyper @external @view def foo(a: uint256, b: uint256, c: uint256) -> uint256: return uint256_mulmod(a, b, c) - .. code-block:: python + .. code-block:: vyper >>> (11 * 2) % 5 2 @@ -774,7 +774,7 @@ Math Add ``x`` and ``y``, without checking for overflow. ``x`` and ``y`` must both be integers of the same type. If the result exceeds the bounds of the input type, it will be wrapped. - .. code-block:: python + .. code-block:: vyper @external @view @@ -787,7 +787,7 @@ Math return unsafe_add(x, y) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo(1, 1) 2 @@ -805,7 +805,7 @@ Math Subtract ``x`` and ``y``, without checking for overflow. ``x`` and ``y`` must both be integers of the same type. If the result underflows the bounds of the input type, it will be wrapped. - .. code-block:: python + .. code-block:: vyper @external @view @@ -818,7 +818,7 @@ Math return unsafe_sub(x, y) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo(4, 3) 1 @@ -837,7 +837,7 @@ Math Multiply ``x`` and ``y``, without checking for overflow. ``x`` and ``y`` must both be integers of the same type. If the result exceeds the bounds of the input type, it will be wrapped. - .. code-block:: python + .. code-block:: vyper @external @view @@ -850,7 +850,7 @@ Math return unsafe_mul(x, y) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo(1, 1) 1 @@ -872,7 +872,7 @@ Math Divide ``x`` and ``y``, without checking for division-by-zero. ``x`` and ``y`` must both be integers of the same type. If the denominator is zero, the result will (following EVM semantics) be zero. - .. code-block:: python + .. code-block:: vyper @external @view @@ -885,7 +885,7 @@ Math return unsafe_div(x, y) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo(1, 1) 1 @@ -910,14 +910,14 @@ Utilities * ``_value``: Value for the ether unit. Any numeric type may be used, however the value cannot be negative. * ``unit``: Ether unit name (e.g. ``"wei"``, ``"ether"``, ``"gwei"``, etc.) indicating the denomination of ``_value``. Must be given as a literal string. - .. code-block:: python + .. code-block:: vyper @external @view def foo(s: String[32]) -> uint256: return as_wei_value(1.337, "ether") - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo(1) 1337000000000000000 @@ -930,14 +930,14 @@ Utilities The EVM only provides access to the most recent 256 blocks. This function reverts if the block number is greater than or equal to the current block number or more than 256 blocks behind the current block. - .. code-block:: python + .. code-block:: vyper @external @view def foo() -> bytes32: return blockhash(block.number - 16) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo() 0xf3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855 @@ -948,7 +948,7 @@ Utilities * ``typename``: Name of the type, except ``HashMap[_KeyType, _ValueType]`` - .. code-block:: python + .. code-block:: vyper @external @view @@ -959,14 +959,14 @@ Utilities Return the length of a given ``Bytes``, ``String`` or ``DynArray[_Type, _Integer]``. - .. code-block:: python + .. code-block:: vyper @external @view def foo(s: String[32]) -> uint256: return len(s) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo("hello") 5 @@ -980,14 +980,14 @@ Utilities Returns a value of the type specified by ``output_type``. - .. code-block:: python + .. code-block:: vyper @external @view def foo() -> Bytes[4]: return method_id('transfer(address,uint256)', output_type=Bytes[4]) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo() 0xa9059cbb @@ -1003,7 +1003,7 @@ Utilities Returns a bytestring whose max length is determined by the arguments. For example, encoding a ``Bytes[32]`` results in a ``Bytes[64]`` (first word is the length of the bytestring variable). - .. code-block:: python + .. code-block:: vyper @external @view @@ -1012,7 +1012,7 @@ Utilities y: Bytes[32] = b"234" return _abi_encode(x, y, method_id=method_id("foo()")) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo().hex() "c2985578" @@ -1033,7 +1033,7 @@ Utilities Returns the decoded value(s), with type as specified by `output_type`. - .. code-block:: python + .. code-block:: vyper @external @view diff --git a/docs/compiler-exceptions.rst b/docs/compiler-exceptions.rst index 395ce448ed..29b8b5c96e 100644 --- a/docs/compiler-exceptions.rst +++ b/docs/compiler-exceptions.rst @@ -58,7 +58,7 @@ of the error within the code: Raises when no valid type can be found for a literal value. - .. code-block:: python + .. code-block:: vyper @external def foo(): @@ -70,7 +70,7 @@ of the error within the code: Raises when using an invalid operator for a given type. - .. code-block:: python + .. code-block:: vyper @external def foo(): @@ -82,7 +82,7 @@ of the error within the code: Raises on an invalid reference to an existing definition. - .. code-block:: python + .. code-block:: vyper baz: int128 @@ -96,7 +96,7 @@ of the error within the code: Raises when using an invalid literal value for the given type. - .. code-block:: python + .. code-block:: vyper @external def foo(): @@ -132,7 +132,7 @@ of the error within the code: Raises when attempting to access ``msg.value`` from within a function that has not been marked as ``@payable``. - .. code-block:: python + .. code-block:: vyper @public def _foo(): @@ -174,7 +174,7 @@ of the error within the code: Raises when attempting to perform an action between two or more objects with known, dislike types. - .. code-block:: python + .. code-block:: vyper @external def foo(: @@ -215,7 +215,7 @@ CompilerPanic .. py:exception:: CompilerPanic - :: + .. code:: shell $ vyper v.vy Error compiling: v.vy diff --git a/docs/compiling-a-contract.rst b/docs/compiling-a-contract.rst index b529d1efb1..2b069c2add 100644 --- a/docs/compiling-a-contract.rst +++ b/docs/compiling-a-contract.rst @@ -20,20 +20,20 @@ vyper To compile a contract: -:: +.. code:: shell $ vyper yourFileName.vy Include the ``-f`` flag to specify which output formats to return. Use ``vyper --help`` for a full list of output options. -:: +.. code:: shell $ vyper -f abi,bytecode,bytecode_runtime,ir,asm,source_map,method_identifiers yourFileName.vy The ``-p`` flag allows you to set a root path that is used when searching for interface files to import. If none is given, it will default to the current working directory. See :ref:`searching_for_imports` for more information. -:: +.. code:: shell $ vyper -p yourProject yourProject/yourFileName.vy @@ -45,7 +45,7 @@ Storage Layout To display the default storage layout for a contract: -:: +.. code:: shell $ vyper -f layout yourFileName.vy @@ -53,7 +53,7 @@ This outputs a JSON object detailing the locations for all state variables as de To override the default storage layout for a contract: -:: +.. code:: shell $ vyper --storage-layout-file storageLayout.json yourFileName.vy @@ -69,19 +69,19 @@ vyper-json To compile from JSON supplied via ``stdin``: -:: +.. code:: shell $ vyper-json To compile from a JSON file: -:: +.. code:: shell $ vyper-json yourProject.json By default, the output is sent to ``stdout``. To redirect to a file, use the ``-o`` flag: -:: +.. code:: shell $ vyper-json -o compiled.json @@ -143,7 +143,7 @@ When you compile your contract code, you can specify the target Ethereum Virtual For instance, the adding the following pragma to a contract indicates that it should be compiled for the "shanghai" fork of the EVM. -.. code-block:: python +.. code-block:: vyper #pragma evm-version shanghai @@ -153,13 +153,13 @@ For instance, the adding the following pragma to a contract indicates that it sh When compiling via the ``vyper`` CLI, you can specify the EVM version option using the ``--evm-version`` flag: -:: +.. code:: shell $ vyper --evm-version [VERSION] When using the JSON interface, you can include the ``"evmVersion"`` key within the ``"settings"`` field: -.. code-block:: javascript +.. code-block:: json { "settings": { @@ -200,8 +200,6 @@ The following is a list of supported EVM versions, and changes in the compiler i - The ``MCOPY`` opcode will be generated automatically by the compiler for most memory operations. - - Compiler Input and Output JSON Description ========================================== @@ -216,7 +214,7 @@ Input JSON Description The following example describes the expected input format of ``vyper-json``. Comments are of course not permitted and used here *only for explanatory purposes*. -.. code-block:: javascript +.. code-block:: json { // Required: Source code language. Must be set to "Vyper". @@ -294,7 +292,7 @@ Output JSON Description The following example describes the output format of ``vyper-json``. Comments are of course not permitted and used here *only for explanatory purposes*. -.. code-block:: javascript +.. code-block:: json { // The compiler version used to generate the JSON diff --git a/docs/conf.py b/docs/conf.py index 5dc1eee8f5..99ffe35a63 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,57 +1,12 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# # Vyper documentation build configuration file, created by # sphinx-quickstart on Wed Jul 26 11:18:29 2017. -# -# This file is execfile()d with the current directory set to its -# containing dir. -# -# Note that not all possible configuration values are present in this -# autogenerated file. -# -# All configuration values have a default; values that are commented out -# serve to show the default. - -# If extensions (or modules to document with autodoc) are in another directory, -# add these directories to sys.path here. If the directory is relative to the -# documentation root, use os.path.abspath to make it absolute, like shown here. -# -# import os -# import sys -# sys.path.insert(0, os.path.abspath('.')) -from recommonmark.parser import CommonMarkParser - -# TO DO - Create and Implement Vyper Lexer -# def setup(sphinx): -# sys.path.insert(0, os.path.abspath('./utils')) -# from SolidityLexer import SolidityLexer -# sphinx.add_lexer('Python', SolidityLexer()) - - -# -- General configuration ------------------------------------------------ - -# If your documentation needs a minimal Sphinx version, state it here. -# -# needs_sphinx = '1.0' -# Add any Sphinx extension module names here, as strings. They can be -# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom -# ones. extensions = [ - "sphinx.ext.autodoc", + "sphinx_copybutton", "sphinx.ext.intersphinx", ] -# Add any paths that contain templates here, relative to this directory. templates_path = ["_templates"] - -# The suffix(es) of source filenames. -# You can specify multiple suffix as a list of string: -# -source_suffix = [".rst", ".md"] - -# The master toctree document. master_doc = "toctree" # General information about the project. @@ -59,68 +14,31 @@ copyright = "2017-2024 CC-BY-4.0 Vyper Team" author = "Vyper Team (originally created by Vitalik Buterin)" -# The version info for the project you're documenting, acts as replacement for -# |version| and |release|, also used in various other places throughout the -# built documents. -# -# The short X.Y version. -version = "" -# The full version, including alpha/beta/rc tags. -release = "" - # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. # # This is also used if you do content translation via gettext catalogs. # Usually you set "language" from the command line for these cases. -language = "python" - -# List of patterns, relative to source directory, that match files and -# directories to ignore when looking for source files. -# This patterns also effect to html_static_path and html_extra_path -exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] - -# The name of the Pygments (syntax highlighting) style to use. -pygments_style = "sphinx" - -# If true, `todo` and `todoList` produce output, else they produce nothing. -todo_include_todos = False - +language = "vyper" # -- Options for HTML output ---------------------------------------------- - -# The theme to use for HTML and HTML Help pages. See the documentation for -# a list of builtin themes. -# -html_theme = "sphinx_rtd_theme" - -# Theme options are theme-specific and customize the look and feel of a theme -# further. For a list of options available for each theme, see the -# documentation. -# -# html_theme_options = {} - -# Add any paths that contain custom static files (such as style sheets) here, -# relative to this directory. They are copied after the builtin static files, -# so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ["_static"] - -html_css_files = ["css/toggle.css", "css/dark.css"] - -html_js_files = ["js/toggle.js"] - -html_logo = "vyper-logo-transparent.svg" - -# Custom sidebar templates, must be a dictionary that maps document names -# to template names. -# -# The default sidebars (for documents that don't match any pattern) are -# defined by theme itself. Builtin themes are using these templates by -# default: ``['localtoc.html', 'relations.html', 'sourcelink.html', -# 'searchbox.html']``. -# -# html_sidebars = {} - +html_theme = "shibuya" +html_theme_options = { + "accent_color": "purple", + "twitter_creator": "vyperlang", + "twitter_site": "vyperlang", + "twitter_url": "https://twitter.com/vyperlang", + "github_url": "https://github.com/vyperlang", +} +html_favicon = "logo.svg" +html_logo = "logo.svg" + +# For the "Edit this page ->" link +html_context = { + "source_type": "github", + "source_user": "vyperlang", + "source_repo": "vyper", +} # -- Options for HTMLHelp output ------------------------------------------ @@ -130,21 +48,6 @@ # -- Options for LaTeX output --------------------------------------------- -latex_elements: dict = { - # The paper size ('letterpaper' or 'a4paper'). - # - # 'papersize': 'letterpaper', - # The font size ('10pt', '11pt' or '12pt'). - # - # 'pointsize': '10pt', - # Additional stuff for the LaTeX preamble. - # - # 'preamble': '', - # Latex figure (float) alignment - # - # 'figure_align': 'htbp', -} - # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). @@ -153,7 +56,7 @@ master_doc, "Vyper.tex", "Vyper Documentation", - "Vyper Team (originally created by Vitalik Buterin)", + author, "manual", ), ] @@ -183,10 +86,6 @@ ), ] -source_parsers = { - ".md": CommonMarkParser, -} - intersphinx_mapping = { "brownie": ("https://eth-brownie.readthedocs.io/en/stable", None), "pytest": ("https://docs.pytest.org/en/latest/", None), diff --git a/docs/constants-and-vars.rst b/docs/constants-and-vars.rst index 7f9c1408c5..00ce7a8ccc 100644 --- a/docs/constants-and-vars.rst +++ b/docs/constants-and-vars.rst @@ -56,7 +56,7 @@ Accessing State Variables ``self`` is used to access a contract's :ref:`state variables`, as shown in the following example: -.. code-block:: python +.. code-block:: vyper state_var: uint256 @@ -76,7 +76,7 @@ Calling Internal Functions ``self`` is also used to call :ref:`internal functions` within a contract: -.. code-block:: python +.. code-block:: vyper @internal def _times_two(amount: uint256) -> uint256: @@ -93,7 +93,7 @@ Custom Constants Custom constants can be defined at a global level in Vyper. To define a constant, make use of the ``constant`` keyword. -.. code-block:: python +.. code-block:: vyper TOTAL_SUPPLY: constant(uint256) = 10000000 total_supply: public(uint256) diff --git a/docs/contributing.rst b/docs/contributing.rst index 221600f930..55b2694424 100644 --- a/docs/contributing.rst +++ b/docs/contributing.rst @@ -5,7 +5,7 @@ Contributing Help is always appreciated! -To get started, you can try `installing Vyper `_ in order to familiarize +To get started, you can try `installing Vyper `_ in order to familiarize yourself with the components of Vyper and the build process. Also, it may be useful to become well-versed at writing smart-contracts in Vyper. diff --git a/docs/control-structures.rst b/docs/control-structures.rst index 14202cbae7..a0aa927261 100644 --- a/docs/control-structures.rst +++ b/docs/control-structures.rst @@ -10,7 +10,7 @@ Functions Functions are executable units of code within a contract. Functions may only be declared within a contract's :ref:`module scope `. -.. code-block:: python +.. code-block:: vyper @external def bid(): @@ -30,7 +30,7 @@ External Functions External functions (marked with the ``@external`` decorator) are a part of the contract interface and may only be called via transactions or from other contracts. -.. code-block:: python +.. code-block:: vyper @external def add_seven(a: int128) -> int128: @@ -52,7 +52,7 @@ Internal Functions Internal functions (marked with the ``@internal`` decorator) are only accessible from other functions within the same contract. They are called via the :ref:`self` object: -.. code-block:: python +.. code-block:: vyper @internal def _times_two(amount: uint256, two: uint256 = 2) -> uint256: @@ -77,7 +77,7 @@ You can optionally declare a function's mutability by using a :ref:`decorator )`` decorator places a lock on a function, and all functions with the same ```` value. An attempt by an external contract to call back into any of these functions causes the transaction to revert. -.. code-block:: python +.. code-block:: vyper @external @nonreentrant("lock") @@ -133,7 +133,7 @@ This function is always named ``__default__``. It must be annotated with ``@exte If the function is annotated as ``@payable``, this function is executed whenever the contract is sent Ether (without data). This is why the default function cannot accept arguments - it is a design decision of Ethereum to make no differentiation between sending ether to a contract or a user address. -.. code-block:: python +.. code-block:: vyper event Payment: amount: uint256 @@ -169,7 +169,7 @@ The ``__init__`` Function ``__init__`` is a special initialization function that may only be called at the time of deploying a contract. It can be used to set initial values for storage variables. A common use case is to set an ``owner`` variable with the creator the contract: -.. code-block:: python +.. code-block:: vyper owner: address @@ -202,7 +202,7 @@ Decorator Description The ``if`` statement is a control flow construct used for conditional execution: -.. code-block:: python +.. code-block:: vyper if CONDITION: ... @@ -213,7 +213,7 @@ Note that unlike Python, Vyper does not allow implicit conversion from non-boole You can also include ``elif`` and ``else`` statements, to add more conditional statements and a body that executes when the conditionals are false: -.. code-block:: python +.. code-block:: vyper if CONDITION: ... @@ -227,7 +227,7 @@ You can also include ``elif`` and ``else`` statements, to add more conditional s The ``for`` statement is a control flow construct used to iterate over a value: -.. code-block:: python +.. code-block:: vyper for i in : ... @@ -239,7 +239,7 @@ Array Iteration You can use ``for`` to iterate through the values of any array variable: -.. code-block:: python +.. code-block:: vyper foo: int128[3] = [4, 23, 42] for i in foo: @@ -249,7 +249,7 @@ In the above, example, the loop executes three times with ``i`` assigned the val You can also iterate over a literal array, as long as a common type can be determined for each item in the array: -.. code-block:: python +.. code-block:: vyper for i in [4, 23, 42]: ... @@ -264,14 +264,14 @@ Range Iteration Ranges are created using the ``range`` function. The following examples are valid uses of ``range``: -.. code-block:: python +.. code-block:: vyper for i in range(STOP): ... ``STOP`` is a literal integer greater than zero. ``i`` begins as zero and increments by one until it is equal to ``STOP``. -.. code-block:: python +.. code-block:: vyper for i in range(stop, bound=N): ... @@ -280,7 +280,7 @@ Here, ``stop`` can be a variable with integer type, greater than zero. ``N`` mus Another use of range can be with ``START`` and ``STOP`` bounds. -.. code-block:: python +.. code-block:: vyper for i in range(START, STOP): ... @@ -291,7 +291,7 @@ Finally, it is possible to use ``range`` with runtime `start` and `stop` values In this case, Vyper checks at runtime that `end - start <= bound`. ``N`` must be a compile-time constant. -.. code-block:: python +.. code-block:: vyper for i in range(start, end, bound=N): ... diff --git a/docs/event-logging.rst b/docs/event-logging.rst index 904b179e70..4f350d6459 100644 --- a/docs/event-logging.rst +++ b/docs/event-logging.rst @@ -10,7 +10,7 @@ Example of Logging This example is taken from the `sample ERC20 contract `_ and shows the basic flow of event logging: -.. code-block:: python +.. code-block:: vyper # Events of the token. event Transfer: @@ -59,7 +59,7 @@ Declaring Events Let's look at an event declaration in more detail. -.. code-block:: python +.. code-block:: vyper event Transfer: sender: indexed(address) @@ -81,7 +81,7 @@ Event declarations look similar to struct declarations, containing one or more a Note that the first topic of a log record consists of the signature of the name of the event that occurred, including the types of its parameters. It is also possible to create an event with no arguments. In this case, use the ``pass`` statement: -.. code-block:: python +.. code-block:: vyper event Foo: pass @@ -92,7 +92,7 @@ Once an event is declared, you can log (send) events. You can send events as man Logging events is done using the ``log`` statement: -.. code-block:: python +.. code-block:: vyper log Transfer(msg.sender, _to, _amount) diff --git a/docs/index.rst b/docs/index.rst index 69d818cd69..8ee48cdb83 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,4 +1,4 @@ -.. image:: vyper-logo-transparent.svg +.. image:: logo.svg :width: 140px :alt: Vyper logo :align: center diff --git a/docs/interfaces.rst b/docs/interfaces.rst index ab220272d8..803b9daf18 100644 --- a/docs/interfaces.rst +++ b/docs/interfaces.rst @@ -12,7 +12,7 @@ Interfaces can be added to contracts either through inline definition, or by imp The ``interface`` keyword is used to define an inline external interface: -.. code-block:: python +.. code-block:: vyper interface FooBar: def calculate() -> uint256: view @@ -20,7 +20,7 @@ The ``interface`` keyword is used to define an inline external interface: The defined interface can then be used to make external calls, given a contract address: -.. code-block:: python +.. code-block:: vyper @external def test(foobar: FooBar): @@ -28,7 +28,7 @@ The defined interface can then be used to make external calls, given a contract The interface name can also be used as a type annotation for storage variables. You then assign an address value to the variable to access that interface. Note that casting an address to an interface is possible, e.g. ``FooBar()``: -.. code-block:: python +.. code-block:: vyper foobar_contract: FooBar @@ -42,7 +42,7 @@ The interface name can also be used as a type annotation for storage variables. Specifying ``payable`` or ``nonpayable`` annotation indicates that the call made to the external contract will be able to alter storage, whereas the ``view`` ``pure`` call will use a ``STATICCALL`` ensuring no storage can be altered during execution. Additionally, ``payable`` allows non-zero value to be sent along with the call. -.. code-block:: python +.. code-block:: vyper interface FooBar: def calculate() -> uint256: pure @@ -70,7 +70,7 @@ Keyword Description The ``default_return_value`` parameter can be used to handle ERC20 tokens affected by the missing return value bug in a way similar to OpenZeppelin's ``safeTransfer`` for Solidity: -.. code-block:: python +.. code-block:: vyper ERC20(USDT).transfer(msg.sender, 1, default_return_value=True) # returns True ERC20(USDT).transfer(msg.sender, 1) # reverts because nothing returned @@ -86,7 +86,7 @@ Interfaces are imported with ``import`` or ``from ... import`` statements. Imported interfaces are written using standard Vyper syntax. The body of each function is ignored when the interface is imported. If you are defining a standalone interface, it is normally specified by using a ``pass`` statement: -.. code-block:: python +.. code-block:: vyper @external def test1(): @@ -98,7 +98,7 @@ Imported interfaces are written using standard Vyper syntax. The body of each fu You can also import a fully implemented contract and Vyper will automatically convert it to an interface. It is even possible for a contract to import itself to gain access to its own interface. -.. code-block:: python +.. code-block:: vyper import greeter as Greeter @@ -118,7 +118,7 @@ Imports via ``import`` With absolute ``import`` statements, you **must** include an alias as a name for the imported package. In the following example, failing to include ``as Foo`` will raise a compile error: -.. code-block:: python +.. code-block:: vyper import contract.foo as Foo @@ -127,7 +127,7 @@ Imports via ``from ... import`` Using ``from`` you can perform both absolute and relative imports. You may optionally include an alias - if you do not, the name of the interface will be the same as the file. -.. code-block:: python +.. code-block:: vyper # without an alias from contract import foo @@ -137,7 +137,7 @@ Using ``from`` you can perform both absolute and relative imports. You may optio Relative imports are possible by prepending dots to the contract name. A single leading dot indicates a relative import starting with the current package. Two leading dots indicate a relative import from the parent of the current package: -.. code-block:: python +.. code-block:: vyper from . import foo from ..interfaces import baz @@ -162,7 +162,7 @@ Built-in Interfaces Vyper includes common built-in interfaces such as `ERC20 `_ and `ERC721 `_. These are imported from ``ethereum.ercs``: -.. code-block:: python +.. code-block:: vyper from ethereum.ercs import ERC20 @@ -175,7 +175,7 @@ Implementing an Interface You can define an interface for your contract with the ``implements`` statement: -.. code-block:: python +.. code-block:: vyper import an_interface as FooBarInterface diff --git a/docs/logo.svg b/docs/logo.svg new file mode 100644 index 0000000000..d2c666074a --- /dev/null +++ b/docs/logo.svg @@ -0,0 +1,4 @@ + + + + diff --git a/docs/natspec.rst b/docs/natspec.rst index a6c2d932e4..90ad5d39b4 100644 --- a/docs/natspec.rst +++ b/docs/natspec.rst @@ -17,7 +17,7 @@ Vyper supports structured documentation for contracts and external functions usi The compiler does not parse docstrings of internal functions. You are welcome to NatSpec in comments for internal functions, however they are not processed or included in the compiler output. -.. code-block:: python +.. code-block:: vyper """ @title A simulator for Bug Bunny, the most famous Rabbit @@ -72,16 +72,16 @@ When parsed by the compiler, documentation such as the one from the above exampl If the above contract is saved as ``carrots.vy`` then you can generate the documentation using: -.. code:: +.. code:: shell - vyper -f userdoc,devdoc carrots.vy + $ vyper -f userdoc,devdoc carrots.vy User Documentation ------------------ The above documentation will produce the following user documentation JSON as output: -.. code-block:: javascript +.. code-block:: json { "methods": { @@ -102,7 +102,7 @@ Developer Documentation Apart from the user documentation file, a developer documentation JSON file should also be produced and should look like this: -.. code-block:: javascript +.. code-block:: json { "author": "Warned Bros", diff --git a/docs/scoping-and-declarations.rst b/docs/scoping-and-declarations.rst index 7165ec6e4d..838720c25b 100644 --- a/docs/scoping-and-declarations.rst +++ b/docs/scoping-and-declarations.rst @@ -8,7 +8,7 @@ Variable Declaration The first time a variable is referenced you must declare its :ref:`type `: -.. code-block:: python +.. code-block:: vyper data: int128 @@ -25,7 +25,7 @@ Declaring Public Variables Storage variables can be marked as ``public`` during declaration: -.. code-block:: python +.. code-block:: vyper data: public(int128) @@ -38,7 +38,7 @@ Declaring Immutable Variables Variables can be marked as ``immutable`` during declaration: -.. code-block:: python +.. code-block:: vyper DATA: immutable(uint256) @@ -55,7 +55,7 @@ Tuple Assignment You cannot directly declare tuple types. However, in certain cases you can use literal tuples during assignment. For example, when a function returns multiple values: -.. code-block:: python +.. code-block:: vyper @internal def foo() -> (int128, int128): @@ -84,13 +84,13 @@ This can be performed when compiling via ``vyper`` by including the ``--storage For example, consider upgrading the following contract: -.. code-block:: python +.. code-block:: vyper # old_contract.vy owner: public(address) balanceOf: public(HashMap[address, uint256]) -.. code-block:: python +.. code-block:: vyper # new_contract.vy owner: public(address) @@ -101,7 +101,7 @@ This would cause an issue when upgrading, as the ``balanceOf`` mapping would be This issue can be avoided by allocating ``balanceOf`` to ``slot1`` using the storage layout overrides. The contract can be compiled with ``vyper new_contract.vy --storage-layout-file new_contract_storage.json`` where ``new_contract_storage.json`` contains the following: -.. code-block:: javascript +.. code-block:: json { "owner": {"type": "address", "slot": 0}, @@ -130,7 +130,7 @@ Accessing Module Scope from Functions Values that are declared in the module scope of a contract, such as storage variables and functions, are accessed via the ``self`` object: -.. code-block:: python +.. code-block:: vyper a: int128 @@ -148,7 +148,7 @@ Name Shadowing It is not permitted for a memory or calldata variable to shadow the name of an immutable or constant value. The following examples will not compile: -.. code-block:: python +.. code-block:: vyper a: constant(bool) = True @@ -157,7 +157,7 @@ It is not permitted for a memory or calldata variable to shadow the name of an i # memory variable cannot have the same name as a constant or immutable variable a: bool = False return a -.. code-block:: python +.. code-block:: vyper a: immutable(bool) @@ -174,7 +174,7 @@ Function Scope Variables that are declared within a function, or given as function input arguments, are visible within the body of that function. For example, the following contract is valid because each declaration of ``a`` only exists within one function's body. -.. code-block:: python +.. code-block:: vyper @external def foo(a: int128): @@ -190,14 +190,14 @@ Variables that are declared within a function, or given as function input argume The following examples will not compile: -.. code-block:: python +.. code-block:: vyper @external def foo(a: int128): # `a` has already been declared as an input argument a: int128 = 21 -.. code-block:: python +.. code-block:: vyper @external def foo(a: int128): @@ -215,7 +215,7 @@ Block Scopes Logical blocks created by ``for`` and ``if`` statements have their own scope. For example, the following contract is valid because ``x`` only exists within the block scopes for each branch of the ``if`` statement: -.. code-block:: python +.. code-block:: vyper @external def foo(a: bool) -> int128: @@ -226,7 +226,7 @@ Logical blocks created by ``for`` and ``if`` statements have their own scope. Fo In a ``for`` statement, the target variable exists within the scope of the loop. For example, the following contract is valid because ``i`` is no longer available upon exiting the loop: -.. code-block:: python +.. code-block:: vyper @external def foo(a: bool) -> int128: @@ -236,7 +236,7 @@ In a ``for`` statement, the target variable exists within the scope of the loop. The following contract fails to compile because ``a`` has not been declared outside of the loop. -.. code-block:: python +.. code-block:: vyper @external def foo(a: bool) -> int128: diff --git a/docs/statements.rst b/docs/statements.rst index 02854adffd..34f15828a1 100644 --- a/docs/statements.rst +++ b/docs/statements.rst @@ -13,7 +13,7 @@ break The ``break`` statement terminates the nearest enclosing ``for`` loop. -.. code-block:: python +.. code-block:: vyper for i in [1, 2, 3, 4, 5]: if i == a: @@ -26,7 +26,7 @@ continue The ``continue`` statement begins the next cycle of the nearest enclosing ``for`` loop. -.. code-block:: python +.. code-block:: vyper for i in [1, 2, 3, 4, 5]: if i != a: @@ -40,7 +40,7 @@ pass ``pass`` is a null operation — when it is executed, nothing happens. It is useful as a placeholder when a statement is required syntactically, but no code needs to be executed: -.. code-block:: python +.. code-block:: vyper # this function does nothing (yet!) @@ -53,7 +53,7 @@ return ``return`` leaves the current function call with the expression list (or None) as a return value. -.. code-block:: python +.. code-block:: vyper return RETURN_VALUE @@ -69,7 +69,7 @@ log The ``log`` statement is used to log an event: -.. code-block:: python +.. code-block:: vyper log MyEvent(...) @@ -89,7 +89,7 @@ raise The ``raise`` statement triggers an exception and reverts the current call. -.. code-block:: python +.. code-block:: vyper raise "something went wrong" @@ -100,7 +100,7 @@ assert The ``assert`` statement makes an assertion about a given condition. If the condition evaluates falsely, the transaction is reverted. -.. code-block:: python +.. code-block:: vyper assert x > 5, "value too low" @@ -108,7 +108,7 @@ The error string is not required. If it is provided, it is limited to 1024 bytes This method's behavior is equivalent to: -.. code-block:: python +.. code-block:: vyper if not cond: raise "reason" diff --git a/docs/structure-of-a-contract.rst b/docs/structure-of-a-contract.rst index 3861bf4380..561f3000dd 100644 --- a/docs/structure-of-a-contract.rst +++ b/docs/structure-of-a-contract.rst @@ -10,7 +10,7 @@ This section provides a quick overview of the types of data present within a con .. _structure-versions: Pragmas -============== +======= Vyper supports several source code directives to control compiler modes and help with build reproducibility. @@ -21,7 +21,7 @@ The version pragma ensures that a contract is only compiled by the intended comp As of 0.3.10, the recommended way to specify the version pragma is as follows: -.. code-block:: python +.. code-block:: vyper #pragma version ^0.3.0 @@ -31,7 +31,7 @@ As of 0.3.10, the recommended way to specify the version pragma is as follows: The following declaration is equivalent, and, prior to 0.3.10, was the only supported method to specify the compiler version: -.. code-block:: python +.. code-block:: vyper # @version ^0.3.0 @@ -43,7 +43,7 @@ Optimization Mode The optimization mode can be one of ``"none"``, ``"codesize"``, or ``"gas"`` (default). For example, adding the following line to a contract will cause it to try to optimize for codesize: -.. code-block:: python +.. code-block:: vyper #pragma optimize codesize @@ -62,13 +62,13 @@ State Variables State variables are values which are permanently stored in contract storage. They are declared outside of the body of any functions, and initially contain the :ref:`default value` for their type. -.. code-block:: python +.. code-block:: vyper storedData: int128 State variables are accessed via the :ref:`self` object. -.. code-block:: python +.. code-block:: vyper self.storedData = 123 @@ -81,7 +81,7 @@ Functions Functions are executable units of code within a contract. -.. code-block:: python +.. code-block:: vyper @external def bid(): @@ -96,7 +96,7 @@ Events Events provide an interface for the EVM's logging facilities. Events may be logged with specially indexed data structures that allow clients, including light clients, to efficiently search for them. -.. code-block:: python +.. code-block:: vyper event Payment: amount: int128 @@ -119,19 +119,19 @@ An interface is a set of function definitions used to enable calls between smart Interfaces can be added to contracts either through inline definition, or by importing them from a separate file. -.. code-block:: python +.. code-block:: vyper interface FooBar: def calculate() -> uint256: view def test1(): nonpayable -.. code-block:: python +.. code-block:: vyper from foo import FooBar Once defined, an interface can then be used to make external calls to a given address: -.. code-block:: python +.. code-block:: vyper @external def test(some_address: address): @@ -144,7 +144,7 @@ Structs A struct is a custom defined type that allows you to group several variables together: -.. code-block:: python +.. code-block:: vyper struct MyStruct: value1: int128 diff --git a/docs/testing-contracts-brownie.rst b/docs/testing-contracts-brownie.rst index bff871d38a..46d8df6ea6 100644 --- a/docs/testing-contracts-brownie.rst +++ b/docs/testing-contracts-brownie.rst @@ -12,7 +12,7 @@ Getting Started In order to use Brownie for testing you must first `initialize a new project `_. Create a new directory for the project, and from within that directory type: -:: +.. code:: shell $ brownie init @@ -24,12 +24,14 @@ Writing a Basic Test Assume the following simple contract ``Storage.vy``. It has a single integer variable and a function to set that value. .. literalinclude:: ../examples/storage/storage.vy - :language: python + :caption: storage.vy + :language: vyper :linenos: We create a test file ``tests/test_storage.py`` where we write our tests in pytest style. .. code-block:: python + :caption: test_storage.py :linenos: import pytest @@ -70,9 +72,10 @@ In this example we are using two fixtures which are provided by Brownie: Testing Events ============== -For the remaining examples, we expand our simple storage contract to include an event and two conditions for a failed transaction: ``AdvancedStorage.vy`` +For the remaining examples, we expand our simple storage contract to include an event and two conditions for a failed transaction: ``advanced_storage.vy`` .. literalinclude:: ../examples/storage/advanced_storage.vy + :caption: advanced_storage.vy :linenos: :language: python diff --git a/docs/testing-contracts-ethtester.rst b/docs/testing-contracts-ethtester.rst index 27e67831de..92522a1eca 100644 --- a/docs/testing-contracts-ethtester.rst +++ b/docs/testing-contracts-ethtester.rst @@ -17,6 +17,7 @@ Prior to testing, the Vyper specific contract conversion and the blockchain rela Since the testing is done in the pytest framework, you can make use of `pytest.ini, tox.ini and setup.cfg `_ and you can use most IDEs' pytest plugins. .. literalinclude:: ../tests/conftest.py + :caption: conftest.py :language: python :linenos: @@ -30,12 +31,14 @@ Writing a Basic Test Assume the following simple contract ``storage.vy``. It has a single integer variable and a function to set that value. .. literalinclude:: ../examples/storage/storage.vy + :caption: storage.vy :linenos: - :language: python + :language: vyper We create a test file ``test_storage.py`` where we write our tests in pytest style. .. literalinclude:: ../tests/functional/examples/storage/test_storage.py + :caption: test_storage.py :linenos: :language: python @@ -50,18 +53,21 @@ Events and Failed Transactions To test events and failed transactions we expand our simple storage contract to include an event and two conditions for a failed transaction: ``advanced_storage.vy`` .. literalinclude:: ../examples/storage/advanced_storage.vy + :caption: advanced_storage.vy :linenos: - :language: python + :language: vyper Next, we take a look at the two fixtures that will allow us to read the event logs and to check for failed transactions. .. literalinclude:: ../tests/conftest.py + :caption: conftest.py :language: python :pyobject: tx_failed The fixture to assert failed transactions defaults to check for a ``TransactionFailed`` exception, but can be used to check for different exceptions too, as shown below. Also note that the chain gets reverted to the state before the failed transaction. .. literalinclude:: ../tests/conftest.py + :caption: conftest.py :language: python :pyobject: get_logs @@ -70,5 +76,6 @@ This fixture will return a tuple with all the logs for a certain event and trans Finally, we create a new file ``test_advanced_storage.py`` where we use the new fixtures to test failed transactions and events. .. literalinclude:: ../tests/functional/examples/storage/test_advanced_storage.py + :caption: test_advanced_storage.py :linenos: :language: python diff --git a/docs/types.rst b/docs/types.rst index 0f5bfe7b04..38779c2a4b 100644 --- a/docs/types.rst +++ b/docs/types.rst @@ -358,7 +358,7 @@ On the ABI level the Fixed-size bytes array is annotated as ``bytes``. Bytes literals may be given as bytes strings. -.. code-block:: python +.. code-block:: vyper bytes_string: Bytes[100] = b"\x01" @@ -372,7 +372,7 @@ Strings Fixed-size strings can hold strings with equal or fewer characters than the maximum length of the string. On the ABI level the Fixed-size bytes array is annotated as ``string``. -.. code-block:: python +.. code-block:: vyper example_str: String[100] = "Test String" @@ -384,7 +384,7 @@ Flags Flags are custom defined types. A flag must have at least one member, and can hold up to a maximum of 256 members. The members are represented by ``uint256`` values in the form of 2\ :sup:`n` where ``n`` is the index of the member in the range ``0 <= n <= 255``. -.. code-block:: python +.. code-block:: vyper # Defining a flag with two members flag Roles: @@ -430,7 +430,7 @@ Flag members can be combined using the above bitwise operators. While flag membe The ``in`` and ``not in`` operators can be used in conjunction with flag member combinations to check for membership. -.. code-block:: python +.. code-block:: vyper flag Roles: MANAGER @@ -491,7 +491,7 @@ Fixed-size lists hold a finite number of elements which belong to a specified ty Lists can be declared with ``_name: _ValueType[_Integer]``, except ``Bytes[N]``, ``String[N]`` and flags. -.. code-block:: python +.. code-block:: vyper # Defining a list exampleList: int128[3] @@ -507,7 +507,7 @@ Multidimensional lists are also possible. The notation for the declaration is re A two dimensional list can be declared with ``_name: _ValueType[inner_size][outer_size]``. Elements can be accessed with ``_name[outer_index][inner_index]``. -.. code-block:: python +.. code-block:: vyper # Defining a list with 2 rows and 5 columns and set all values to 0 exampleList2D: int128[5][2] = empty(int128[5][2]) @@ -531,7 +531,7 @@ Dynamic Arrays Dynamic arrays represent bounded arrays whose length can be modified at runtime, up to a bound specified in the type. They can be declared with ``_name: DynArray[_Type, _Integer]``, where ``_Type`` can be of value type or reference type (except mappings). -.. code-block:: python +.. code-block:: vyper # Defining a list exampleList: DynArray[int128, 3] @@ -558,7 +558,7 @@ Dynamic arrays represent bounded arrays whose length can be modified at runtime, .. note:: To keep code easy to reason about, modifying an array while using it as an iterator is disallowed by the language. For instance, the following usage is not allowed: - .. code-block:: python + .. code-block:: vyper for item in self.my_array: self.my_array[0] = item @@ -580,7 +580,7 @@ Struct types can be used inside mappings and arrays. Structs can contain arrays Struct members can be accessed via ``struct.argname``. -.. code-block:: python +.. code-block:: vyper # Defining a struct struct MyStruct: @@ -610,7 +610,7 @@ Mapping types are declared as ``HashMap[_KeyType, _ValueType]``. .. note:: Mappings are only allowed as state variables. -.. code-block:: python +.. code-block:: vyper # Defining a mapping exampleMapping: HashMap[int128, decimal] diff --git a/docs/vyper-by-example.rst b/docs/vyper-by-example.rst index b07842cd25..61b5e51c41 100644 --- a/docs/vyper-by-example.rst +++ b/docs/vyper-by-example.rst @@ -19,7 +19,7 @@ period ends, a predetermined beneficiary will receive the amount of the highest bid. .. literalinclude:: ../examples/auctions/simple_open_auction.vy - :language: python + :language: vyper :linenos: As you can see, this example only has a constructor, two methods to call, and @@ -29,7 +29,7 @@ need for a basic implementation of an auction smart contract. Let's get started! .. literalinclude:: ../examples/auctions/simple_open_auction.vy - :language: python + :language: vyper :lineno-start: 3 :lines: 3-17 @@ -54,7 +54,7 @@ within the same contract. The ``public`` function additionally creates a Now, the constructor. .. literalinclude:: ../examples/auctions/simple_open_auction.vy - :language: python + :language: vyper :lineno-start: 22 :lines: 22-27 @@ -72,7 +72,7 @@ caller as we will soon see. With initial setup out of the way, lets look at how our users can make bids. .. literalinclude:: ../examples/auctions/simple_open_auction.vy - :language: python + :language: vyper :lineno-start: 33 :lines: 33-46 @@ -95,7 +95,7 @@ We will send back the previous ``highestBid`` to the previous ``highestBidder`` our new ``highestBid`` and ``highestBidder``. .. literalinclude:: ../examples/auctions/simple_open_auction.vy - :language: python + :language: vyper :lineno-start: 60 :lines: 60-85 @@ -141,13 +141,13 @@ Solidity, this blind auction allows for an auction where there is no time pressu .. _counterpart: https://solidity.readthedocs.io/en/v0.5.0/solidity-by-example.html#id2 .. literalinclude:: ../examples/auctions/blind_auction.vy - :language: python + :language: vyper :linenos: While this blind auction is almost functionally identical to the blind auction implemented in Solidity, the differences in their implementations help illustrate the differences between Solidity and Vyper. .. literalinclude:: ../examples/auctions/blind_auction.vy - :language: python + :language: vyper :lineno-start: 28 :lines: 28-30 @@ -184,14 +184,14 @@ we want to explore one way how an escrow system can be implemented trustlessly. Let's go! .. literalinclude:: ../examples/safe_remote_purchase/safe_remote_purchase.vy - :language: python + :language: vyper :linenos: This is also a moderately short contract, however a little more complex in logic. Let's break down this contract bit by bit. .. literalinclude:: ../examples/safe_remote_purchase/safe_remote_purchase.vy - :language: python + :language: vyper :lineno-start: 16 :lines: 16-19 @@ -200,7 +200,7 @@ their respective data types. Remember that the ``public`` function allows the variables to be *readable* by an external caller, but not *writeable*. .. literalinclude:: ../examples/safe_remote_purchase/safe_remote_purchase.vy - :language: python + :language: vyper :lineno-start: 22 :lines: 22-29 @@ -215,7 +215,7 @@ in the contract variable ``self.value`` and saves the contract creator into ``True``. .. literalinclude:: ../examples/safe_remote_purchase/safe_remote_purchase.vy - :language: python + :language: vyper :lineno-start: 31 :lines: 31-36 @@ -231,7 +231,7 @@ contract will call the ``selfdestruct()`` function and refunds the seller and subsequently destroys the contract. .. literalinclude:: ../examples/safe_remote_purchase/safe_remote_purchase.vy - :language: python + :language: vyper :lineno-start: 38 :lines: 38-45 @@ -244,7 +244,7 @@ contract has a balance equal to 4 times the item value and the seller must send the item to the buyer. .. literalinclude:: ../examples/safe_remote_purchase/safe_remote_purchase.vy - :language: python + :language: vyper :lineno-start: 47 :lines: 47-61 @@ -276,14 +276,14 @@ Participants will be refunded their respective contributions if the total funding does not reach its target goal. .. literalinclude:: ../examples/crowdfund.vy - :language: python + :language: vyper :linenos: Most of this code should be relatively straightforward after going through our previous examples. Let's dive right in. .. literalinclude:: ../examples/crowdfund.vy - :language: python + :language: vyper :lineno-start: 3 :lines: 3-13 @@ -304,7 +304,7 @@ once the crowdfunding period is over—as determined by the ``deadline`` and of all participants. .. literalinclude:: ../examples/crowdfund.vy - :language: python + :language: vyper :lineno-start: 9 :lines: 9-15 @@ -317,7 +317,7 @@ a definitive end time for the crowdfunding period. Now lets take a look at how a person can participate in the crowdfund. .. literalinclude:: ../examples/crowdfund.vy - :language: python + :language: vyper :lineno-start: 17 :lines: 17-23 @@ -331,7 +331,7 @@ mapping, ``self.nextFunderIndex`` increments appropriately to properly index each participant. .. literalinclude:: ../examples/crowdfund.vy - :language: python + :language: vyper :lineno-start: 25 :lines: 25-31 @@ -352,7 +352,7 @@ crowdfunding campaign isn't successful? We're going to need a way to refund all the participants. .. literalinclude:: ../examples/crowdfund.vy - :language: python + :language: vyper :lineno-start: 33 :lines: 33-42 @@ -374,14 +374,14 @@ determined upon calling the ``winningProposals()`` method, which iterates throug all the proposals and returns the one with the greatest number of votes. .. literalinclude:: ../examples/voting/ballot.vy - :language: python + :language: vyper :linenos: As we can see, this is the contract of moderate length which we will dissect section by section. Let’s begin! .. literalinclude:: ../examples/voting/ballot.vy - :language: python + :language: vyper :lineno-start: 3 :lines: 3-25 @@ -402,7 +402,7 @@ their respective datatypes. Let’s move onto the constructor. .. literalinclude:: ../examples/voting/ballot.vy - :language: python + :language: vyper :lineno-start: 53 :lines: 53-62 @@ -421,7 +421,7 @@ their respective index in the original array as its key. Now that the initial setup is done, lets take a look at the functionality. .. literalinclude:: ../examples/voting/ballot.vy - :language: python + :language: vyper :lineno-start: 66 :lines: 66-75 @@ -437,7 +437,7 @@ voting power, we will set their ``weight`` to ``1`` and we will keep track of th total number of voters by incrementing ``voterCount``. .. literalinclude:: ../examples/voting/ballot.vy - :language: python + :language: vyper :lineno-start: 120 :lines: 120-135 @@ -452,7 +452,7 @@ the delegate had already voted or increase the delegate’s vote ``weight`` if the delegate has not yet voted. .. literalinclude:: ../examples/voting/ballot.vy - :language: python + :language: vyper :lineno-start: 139 :lines: 139-151 @@ -472,7 +472,7 @@ costs gas. By having the ``@view`` decorator, we let the EVM know that this is a read-only function and we benefit by saving gas fees. .. literalinclude:: ../examples/voting/ballot.vy - :language: python + :language: vyper :lineno-start: 153 :lines: 153-170 @@ -484,7 +484,7 @@ respectively by looping through all the proposals. ``winningProposal()`` is an external function allowing access to ``_winningProposal()``. .. literalinclude:: ../examples/voting/ballot.vy - :language: python + :language: vyper :lineno-start: 175 :lines: 175-178 @@ -515,7 +515,7 @@ contract, holds all shares of the company at first but can sell them all. Let's get started. .. literalinclude:: ../examples/stock/company.vy - :language: python + :language: vyper :linenos: .. note:: Throughout this contract, we use a pattern where ``@external`` functions return data from ``@internal`` functions that have the same name prepended with an underscore. This is because Vyper does not allow calls between external functions within the same contract. The internal function handles the logic, while the external function acts as a getter to allow viewing. @@ -526,7 +526,7 @@ that the contract logs. We then declare our global variables, followed by function definitions. .. literalinclude:: ../examples/stock/company.vy - :language: python + :language: vyper :lineno-start: 3 :lines: 3-27 @@ -537,7 +537,7 @@ represents the wei value of a share and ``holdings`` is a mapping that maps an address to the number of shares the address owns. .. literalinclude:: ../examples/stock/company.vy - :language: python + :language: vyper :lineno-start: 29 :lines: 29-40 @@ -548,7 +548,7 @@ company's address is initialized to hold all shares of the company in the ``holdings`` mapping. .. literalinclude:: ../examples/stock/company.vy - :language: python + :language: vyper :lineno-start: 42 :lines: 42-46 @@ -567,7 +567,7 @@ Now, lets take a look at a method that lets a person buy stock from the company's holding. .. literalinclude:: ../examples/stock/company.vy - :language: python + :language: vyper :lineno-start: 51 :lines: 51-64 @@ -579,7 +579,7 @@ and transferred to the sender's in the ``holdings`` mapping. Now that people can buy shares, how do we check someone's holdings? .. literalinclude:: ../examples/stock/company.vy - :language: python + :language: vyper :lineno-start: 66 :lines: 66-71 @@ -588,7 +588,7 @@ and returns its corresponding stock holdings by keying into ``self.holdings``. Again, an external function ``getHolding()`` is included to allow access. .. literalinclude:: ../examples/stock/company.vy - :language: python + :language: vyper :lineno-start: 72 :lines: 72-76 @@ -596,7 +596,7 @@ To check the ether balance of the company, we can simply call the getter method ``cash()``. .. literalinclude:: ../examples/stock/company.vy - :language: python + :language: vyper :lineno-start: 78 :lines: 78-95 @@ -609,7 +609,7 @@ ether to complete the sale. If all conditions are met, the holdings are deducted from the seller and given to the company. The ethers are then sent to the seller. .. literalinclude:: ../examples/stock/company.vy - :language: python + :language: vyper :lineno-start: 97 :lines: 97-110 @@ -620,7 +620,7 @@ than ``0`` and ``asserts`` whether the sender has enough stocks to send. If both conditions are satisfied, the transfer is made. .. literalinclude:: ../examples/stock/company.vy - :language: python + :language: vyper :lineno-start: 112 :lines: 112-124 @@ -632,7 +632,7 @@ enough funds to pay the amount. If both conditions satisfy, the contract sends its ether to an address. .. literalinclude:: ../examples/stock/company.vy - :language: python + :language: vyper :lineno-start: 126 :lines: 126-130 @@ -641,7 +641,7 @@ shares the company has sold and the price of each share. Internally, we get this value by calling the ``_debt()`` method. Externally it is accessed via ``debt()``. .. literalinclude:: ../examples/stock/company.vy - :language: python + :language: vyper :lineno-start: 132 :lines: 132-138 diff --git a/docs/vyper-logo-transparent.svg b/docs/vyper-logo-transparent.svg deleted file mode 100644 index 18bf3c25e2..0000000000 --- a/docs/vyper-logo-transparent.svg +++ /dev/null @@ -1,11 +0,0 @@ - diff --git a/examples/tokens/ERC20.vy b/examples/tokens/ERC20.vy index 0e94b32b9d..a9d41cbf69 100644 --- a/examples/tokens/ERC20.vy +++ b/examples/tokens/ERC20.vy @@ -31,7 +31,7 @@ decimals: public(uint8) # NOTE: By declaring `balanceOf` as public, vyper automatically generates a 'balanceOf()' getter # method to allow access to account balances. # The _KeyType will become a required parameter for the getter and it will return _ValueType. -# See: https://vyper.readthedocs.io/en/v0.1.0-beta.8/types.html?highlight=getter#mappings +# See: https://docs.vyperlang.org/en/v0.1.0-beta.8/types.html?highlight=getter#mappings balanceOf: public(HashMap[address, uint256]) # By declaring `allowance` as public, vyper automatically generates the `allowance()` getter allowance: public(HashMap[address, HashMap[address, uint256]]) diff --git a/requirements-docs.txt b/requirements-docs.txt index 5906384fc7..5c19ca7cfd 100644 --- a/requirements-docs.txt +++ b/requirements-docs.txt @@ -1,3 +1,3 @@ +shibuya==2024.1.17 sphinx==7.2.6 -recommonmark==0.7.1 -sphinx_rtd_theme==2.0.0 +sphinx-copybutton==0.5.2 diff --git a/tox.ini b/tox.ini index f9d4c3b60b..b42a13a0ab 100644 --- a/tox.ini +++ b/tox.ini @@ -19,9 +19,9 @@ whitelist_externals = make [testenv:docs] basepython=python3 deps = + shibuya sphinx - sphinx_rtd_theme - recommonmark + sphinx-copybutton commands = sphinx-build {posargs:-E} -b html docs dist/docs -n -q --color From a8c6ea284e85348d76be91e1ac53d92180fcf7b0 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Mon, 12 Feb 2024 16:42:45 -0800 Subject: [PATCH 07/12] chore: improve some error messages (#3775) * fix error message for `implements: module` currently, the compiler will panic when it encounters this case. add a suggestion to rename the interface file to `.vyi`. also catch all invalid types with a compiler panic. * add a helpful hint for imports from `vyper.interfaces` hint to try `ethereum.ercs` --- vyper/semantics/analysis/module.py | 14 ++++++++++++-- vyper/semantics/types/utils.py | 7 +++++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index e50c3e6d6f..9304eb3ded 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -342,7 +342,14 @@ def visit_ImplementsDecl(self, node): type_ = type_from_annotation(node.annotation) if not isinstance(type_, InterfaceT): - raise StructureException("not an interface!", node.annotation) + msg = "Not an interface!" + hint = None + if isinstance(type_, ModuleT): + path = type_._module.path + msg += " (Since vyper v0.4.0, interface files are required" + msg += " to have a .vyi suffix.)" + hint = f"try renaming `{path}` to `{path}i`" + raise StructureException(msg, node.annotation, hint=hint) type_.validate_implements(node) @@ -627,6 +634,9 @@ def _load_import(self, node: vy_ast.VyperNode, level: int, module_str: str, alia def _load_import_helper( self, node: vy_ast.VyperNode, level: int, module_str: str, alias: str ) -> Any: + if module_str.startswith("vyper.interfaces"): + hint = "try renaming `vyper.interfaces` to `ethereum.ercs`" + raise ModuleNotFound(module_str, hint=hint) if _is_builtin(module_str): return _load_builtin_import(level, module_str) @@ -724,7 +734,7 @@ def _is_builtin(module_str): def _load_builtin_import(level: int, module_str: str) -> InterfaceT: if not _is_builtin(module_str): - raise ModuleNotFoundError(f"Not a builtin: {module_str}") from None + raise ModuleNotFoundError(f"Not a builtin: {module_str}") builtins_path = vyper.builtins.interfaces.__path__[0] # hygiene: convert to relpath to avoid leaking user directory info diff --git a/vyper/semantics/types/utils.py b/vyper/semantics/types/utils.py index c6a4531df8..96c661021f 100644 --- a/vyper/semantics/types/utils.py +++ b/vyper/semantics/types/utils.py @@ -3,6 +3,7 @@ from vyper import ast as vy_ast from vyper.exceptions import ( ArrayIndexException, + CompilerPanic, InstantiationException, InvalidType, StructureException, @@ -158,6 +159,12 @@ def _type_from_annotation(node: vy_ast.VyperNode) -> VyperType: # call from_annotation to produce a better error message. typ_.from_annotation(node) + if hasattr(typ_, "module_t"): # it's a ModuleInfo + typ_ = typ_.module_t + + if not isinstance(typ_, VyperType): + raise CompilerPanic("Not a type: {typ_}", node) + return typ_ From a3bc3eb50ea10788a688ea79d74d294cd9a418d6 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Mon, 12 Feb 2024 17:07:21 -0800 Subject: [PATCH 08/12] feat: add python `sys.path` to vyper path (#3763) this makes it easier to install vyper packages from pip and import them using a regular python workflow. misc: - improve how paths appear in error messages; try hard to make them relative paths. - add `chdir_tmp_path` fixture which chdirs to the `tmp_path` fixture for the duration of the test. --- tests/conftest.py | 7 ++++ .../syntax/modules/test_initializers.py | 6 ++-- .../cli/vyper_compile/test_compile_files.py | 33 +++++++++++++++++++ vyper/cli/vyper_compile.py | 14 ++++++-- vyper/semantics/analysis/module.py | 14 +++++++- 5 files changed, 67 insertions(+), 7 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index e673f17b35..6eb34a3e0a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,6 +15,7 @@ from web3.contract import Contract from web3.providers.eth_tester import EthereumTesterProvider +from tests.utils import working_directory from vyper import compiler from vyper.ast.grammar import parse_vyper_source from vyper.codegen.ir_node import IRnode @@ -79,6 +80,12 @@ def debug(pytestconfig): _set_debug_mode(debug) +@pytest.fixture +def chdir_tmp_path(tmp_path): + with working_directory(tmp_path): + yield + + @pytest.fixture def keccak(): return Web3.keccak diff --git a/tests/functional/syntax/modules/test_initializers.py b/tests/functional/syntax/modules/test_initializers.py index a12f5f57ea..d0523153c8 100644 --- a/tests/functional/syntax/modules/test_initializers.py +++ b/tests/functional/syntax/modules/test_initializers.py @@ -326,7 +326,7 @@ def foo(): assert e.value._hint == "did you mean `m := lib1`?" -def test_global_initializer_constraint(make_input_bundle): +def test_global_initializer_constraint(make_input_bundle, chdir_tmp_path): lib1 = """ counter: uint256 """ @@ -818,7 +818,7 @@ def foo(new_value: uint256): assert e.value._hint == expected_hint -def test_invalid_uses(make_input_bundle): +def test_invalid_uses(make_input_bundle, chdir_tmp_path): lib1 = """ counter: uint256 """ @@ -848,7 +848,7 @@ def foo(): assert e.value._hint == "delete `uses: lib1`" -def test_invalid_uses2(make_input_bundle): +def test_invalid_uses2(make_input_bundle, chdir_tmp_path): # test a more complicated invalid uses lib1 = """ counter: uint256 diff --git a/tests/unit/cli/vyper_compile/test_compile_files.py b/tests/unit/cli/vyper_compile/test_compile_files.py index 2a65d66835..6adee24db6 100644 --- a/tests/unit/cli/vyper_compile/test_compile_files.py +++ b/tests/unit/cli/vyper_compile/test_compile_files.py @@ -1,3 +1,5 @@ +import contextlib +import sys from pathlib import Path import pytest @@ -257,3 +259,34 @@ def foo() -> uint256: contract_file = make_file("contract.vy", contract_source) assert compile_files([contract_file], ["combined_json"], paths=[tmp_path]) is not None + + +@contextlib.contextmanager +def mock_sys_path(path): + try: + sys.path.append(path) + yield + finally: + sys.path.pop() + + +def test_import_sys_path(tmp_path_factory, make_file): + library_source = """ +@internal +def foo() -> uint256: + return block.number + 1 + """ + contract_source = """ +import lib + +@external +def foo() -> uint256: + return lib.foo() + """ + tmpdir = tmp_path_factory.mktemp("test-sys-path") + with open(tmpdir / "lib.vy", "w") as f: + f.write(library_source) + + contract_file = make_file("contract.vy", contract_source) + with mock_sys_path(tmpdir): + assert compile_files([contract_file], ["combined_json"]) is not None diff --git a/vyper/cli/vyper_compile.py b/vyper/cli/vyper_compile.py index d6ba9e180a..ac69cf3310 100755 --- a/vyper/cli/vyper_compile.py +++ b/vyper/cli/vyper_compile.py @@ -238,10 +238,18 @@ def compile_files( storage_layout_paths: list[str] = None, no_bytecode_metadata: bool = False, ) -> dict: - paths = paths or [] + # lowest precedence search path is always sys path + search_paths = [Path(p) for p in sys.path] + + # python sys path uses opposite resolution order from us + # (first in list is highest precedence; we give highest precedence + # to the last in the list) + search_paths.reverse() - # lowest precedence search path is always `.` - search_paths = [Path(".")] + if Path(".") not in search_paths: + search_paths.append(Path(".")) + + paths = paths or [] for p in paths: path = Path(p).resolve(strict=True) diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index 9304eb3ded..43b11497ec 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -705,10 +705,22 @@ def _load_import_helper( def _parse_and_fold_ast(file: FileInput) -> vy_ast.Module: + module_path = file.resolved_path # for error messages + try: + # try to get a relative path, to simplify the error message + cwd = Path(".") + if module_path.is_absolute(): + cwd = cwd.resolve() + module_path = module_path.relative_to(cwd) + except ValueError: + # we couldn't get a relative path (cf. docs for Path.relative_to), + # use the resolved path given to us by the InputBundle + pass + ret = vy_ast.parse_to_ast( file.source_code, source_id=file.source_id, - module_path=str(file.path), + module_path=str(module_path), resolved_path=str(file.resolved_path), ) return ret From 7bdebbf12798ccda4285653f630c1a6b1d4af5b8 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 13 Feb 2024 06:14:24 -0800 Subject: [PATCH 09/12] fix: iterator modification analysis (#3764) this commit fixes several bugs with analysis of iterator modification in loops. to do so, it refactors the analysis code to track reads/writes more accurately, and uses analysis machinery instead of AST queries to perform the check. it enriches ExprInfo with an `attr` attribute, so this can be used to detect if an ExprInfo is derived from an `Attribute`. ExprInfo could be further enriched with `Subscript` info so that the Attribute/Subscript chain can be reliably recovered just from ExprInfos, especially in the future if other functions rely on being able to recover the attribute chain. this commit also modifies `validate_functions` so that it validates the functions in dependency (call graph traversal) order rather than the order they appear in the AST. refactors: - add `enter_for_loop()` context manager for convenience+clarity - remove `ExprInfo.attribute_chain`, it was too confusing - hide `ContractFunctionT` member variables (`_variable_reads`, `_variable_writes`, `_used_modules`) behind public-facing API - remove `get_root_varinfo()` in favor of a helper `_get_variable_access()` function which detects access on variable sub-members (e.g., structs). --- .../features/iteration/test_for_in_list.py | 56 ++- .../syntax/modules/test_initializers.py | 42 +++ .../unit/semantics/analysis/test_for_loop.py | 105 ++++++ vyper/ast/nodes.pyi | 8 +- vyper/codegen/expr.py | 58 ++-- vyper/semantics/analysis/base.py | 50 ++- vyper/semantics/analysis/local.py | 326 ++++++++++-------- vyper/semantics/analysis/module.py | 2 +- vyper/semantics/analysis/utils.py | 16 +- vyper/semantics/environment.py | 4 +- vyper/semantics/types/__init__.py | 2 +- vyper/semantics/types/function.py | 42 ++- vyper/semantics/types/primitives.py | 10 + 13 files changed, 505 insertions(+), 216 deletions(-) diff --git a/tests/functional/codegen/features/iteration/test_for_in_list.py b/tests/functional/codegen/features/iteration/test_for_in_list.py index 36252701c4..e1bd8f313d 100644 --- a/tests/functional/codegen/features/iteration/test_for_in_list.py +++ b/tests/functional/codegen/features/iteration/test_for_in_list.py @@ -3,6 +3,7 @@ import pytest +from vyper.compiler import compile_code from vyper.exceptions import ( ArgumentException, ImmutableViolation, @@ -841,6 +842,59 @@ def foo(): ] +# TODO: move these to tests/functional/syntax @pytest.mark.parametrize("code,err", BAD_CODE, ids=bad_code_names) def test_bad_code(assert_compile_failed, get_contract, code, err): - assert_compile_failed(lambda: get_contract(code), err) + with pytest.raises(err): + compile_code(code) + + +def test_iterator_modification_module_attribute(make_input_bundle): + # test modifying iterator via attribute + lib1 = """ +queue: DynArray[uint256, 5] + """ + main = """ +import lib1 + +initializes: lib1 + +@external +def foo(): + for i: uint256 in lib1.queue: + lib1.queue.pop() + """ + + input_bundle = make_input_bundle({"lib1.vy": lib1}) + + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "Cannot modify loop variable `queue`" + + +def test_iterator_modification_module_function_call(make_input_bundle): + lib1 = """ +queue: DynArray[uint256, 5] + +@internal +def popqueue(): + self.queue.pop() + """ + main = """ +import lib1 + +initializes: lib1 + +@external +def foo(): + for i: uint256 in lib1.queue: + lib1.popqueue() + """ + + input_bundle = make_input_bundle({"lib1.vy": lib1}) + + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "Cannot modify loop variable `queue`" diff --git a/tests/functional/syntax/modules/test_initializers.py b/tests/functional/syntax/modules/test_initializers.py index d0523153c8..d0965ae61d 100644 --- a/tests/functional/syntax/modules/test_initializers.py +++ b/tests/functional/syntax/modules/test_initializers.py @@ -741,6 +741,48 @@ def foo(new_value: uint256): assert e.value._hint == expected_hint +def test_missing_uses_subscript(make_input_bundle): + # test missing uses through nested subscript/attribute access + lib1 = """ +struct Foo: + array: uint256[5] + +foos: Foo[5] + """ + lib2 = """ +import lib1 + +counter: uint256 + +@internal +def foo(): + pass + """ + main = """ +import lib1 +import lib2 + +initializes: lib1 + +# did not `use` or `initialize` lib2! + +@external +def foo(new_value: uint256): + # cannot access lib1 state through lib2 + lib2.lib1.foos[0].array[1] = new_value + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "Cannot access `lib2` state!" + + expected_hint = "add `uses: lib2` or `initializes: lib2` as a " + expected_hint += "top-level statement to your contract" + assert e.value._hint == expected_hint + + def test_missing_uses_nested_attribute_function_call(make_input_bundle): # test missing uses through nested attribute access lib1 = """ diff --git a/tests/unit/semantics/analysis/test_for_loop.py b/tests/unit/semantics/analysis/test_for_loop.py index 607587cc28..c97c9c095e 100644 --- a/tests/unit/semantics/analysis/test_for_loop.py +++ b/tests/unit/semantics/analysis/test_for_loop.py @@ -134,6 +134,111 @@ def baz(): validate_semantics(vyper_module, dummy_input_bundle) +def test_modify_iterator_recursive_function_call_topsort(dummy_input_bundle): + # test the analysis works no matter the order of functions + code = """ +a: uint256[3] + +@internal +def baz(): + for i: uint256 in self.a: + self.bar() + +@internal +def bar(): + self.foo() + +@internal +def foo(): + self.a[0] = 1 + """ + vyper_module = parse_to_ast(code) + with pytest.raises(ImmutableViolation) as e: + validate_semantics(vyper_module, dummy_input_bundle) + + assert e.value._message == "Cannot modify loop variable `a`" + + +def test_modify_iterator_through_struct(dummy_input_bundle): + # GH issue 3429 + code = """ +struct A: + iter: DynArray[uint256, 5] + +a: A + +@external +def foo(): + self.a.iter = [1, 2, 3] + for i: uint256 in self.a.iter: + self.a = A({iter: [1, 2, 3, 4]}) + """ + vyper_module = parse_to_ast(code) + with pytest.raises(ImmutableViolation) as e: + validate_semantics(vyper_module, dummy_input_bundle) + + assert e.value._message == "Cannot modify loop variable `a`" + + +def test_modify_iterator_complex_expr(dummy_input_bundle): + # GH issue 3429 + # avoid false positive! + code = """ +a: DynArray[uint256, 5] +b: uint256[10] + +@external +def foo(): + self.a = [1, 2, 3] + for i: uint256 in self.a: + self.b[self.a[1]] = i + """ + vyper_module = parse_to_ast(code) + validate_semantics(vyper_module, dummy_input_bundle) + + +def test_modify_iterator_siblings(dummy_input_bundle): + # test we can modify siblings in an access tree + code = """ +struct Foo: + a: uint256[2] + b: uint256 + +f: Foo + +@external +def foo(): + for i: uint256 in self.f.a: + self.f.b += i + """ + vyper_module = parse_to_ast(code) + validate_semantics(vyper_module, dummy_input_bundle) + + +def test_modify_subscript_barrier(dummy_input_bundle): + # test that Subscript nodes are a barrier for analysis + code = """ +struct Foo: + x: uint256[2] + y: uint256 + +struct Bar: + f: Foo[2] + +b: Bar + +@external +def foo(): + for i: uint256 in self.b.f[1].x: + self.b.f[0].y += i + """ + vyper_module = parse_to_ast(code) + with pytest.raises(ImmutableViolation) as e: + validate_semantics(vyper_module, dummy_input_bundle) + + assert e.value._message == "Cannot modify loop variable `b`" + + iterator_inference_codes = [ """ @external diff --git a/vyper/ast/nodes.pyi b/vyper/ast/nodes.pyi index 7f863a8db9..342c84876a 100644 --- a/vyper/ast/nodes.pyi +++ b/vyper/ast/nodes.pyi @@ -200,13 +200,13 @@ class Call(ExprNode): class keyword(VyperNode): ... -class Attribute(VyperNode): +class Attribute(ExprNode): attr: str = ... value: ExprNode = ... -class Subscript(VyperNode): - slice: VyperNode = ... - value: VyperNode = ... +class Subscript(ExprNode): + slice: ExprNode = ... + value: ExprNode = ... class Assign(VyperNode): ... diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index 335cfefb87..9c7f11dcb3 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -263,24 +263,6 @@ def parse_Attribute(self): if addr.value == "address": # for `self.code` return IRnode.from_list(["~selfcode"], typ=BytesT(0)) return IRnode.from_list(["~extcode", addr], typ=BytesT(0)) - # self.x: global attribute - elif (varinfo := self.expr._expr_info.var_info) is not None: - if varinfo.is_constant: - return Expr.parse_value_expr(varinfo.decl_node.value, self.context) - - location = data_location_to_address_space( - varinfo.location, self.context.is_ctor_context - ) - - ret = IRnode.from_list( - varinfo.position.position, - typ=varinfo.typ, - location=location, - annotation="self." + self.expr.attr, - ) - ret._referenced_variables = {varinfo} - - return ret # Reserved keywords elif ( @@ -336,17 +318,37 @@ def parse_Attribute(self): "chain.id is unavailable prior to istanbul ruleset", self.expr ) return IRnode.from_list(["chainid"], typ=UINT256_T) + # Other variables - else: - sub = Expr(self.expr.value, self.context).ir_node - # contract type - if isinstance(sub.typ, InterfaceT): - # MyInterface.address - assert self.expr.attr == "address" - sub.typ = typ - return sub - if isinstance(sub.typ, StructT) and self.expr.attr in sub.typ.member_types: - return get_element_ptr(sub, self.expr.attr) + + # self.x: global attribute + if (varinfo := self.expr._expr_info.var_info) is not None: + if varinfo.is_constant: + return Expr.parse_value_expr(varinfo.decl_node.value, self.context) + + location = data_location_to_address_space( + varinfo.location, self.context.is_ctor_context + ) + + ret = IRnode.from_list( + varinfo.position.position, + typ=varinfo.typ, + location=location, + annotation="self." + self.expr.attr, + ) + ret._referenced_variables = {varinfo} + + return ret + + sub = Expr(self.expr.value, self.context).ir_node + # contract type + if isinstance(sub.typ, InterfaceT): + # MyInterface.address + assert self.expr.attr == "address" + sub.typ = typ + return sub + if isinstance(sub.typ, StructT) and self.expr.attr in sub.typ.member_types: + return get_element_ptr(sub, self.expr.attr) def parse_Subscript(self): sub = Expr(self.expr.value, self.context).ir_node diff --git a/vyper/semantics/analysis/base.py b/vyper/semantics/analysis/base.py index 2086e5f9da..49b867aae5 100644 --- a/vyper/semantics/analysis/base.py +++ b/vyper/semantics/analysis/base.py @@ -1,5 +1,5 @@ import enum -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import TYPE_CHECKING, Dict, Optional, Union from vyper import ast as vy_ast @@ -193,6 +193,17 @@ def is_constant(self): return res +@dataclass(frozen=True) +class VarAccess: + variable: VarInfo + attrs: tuple[str, ...] + + def contains(self, other): + # VarAccess("v", ("a")) `contains` VarAccess("v", ("a", "b", "c")) + sub_attrs = other.attrs[: len(self.attrs)] + return self.variable == other.variable and sub_attrs == self.attrs + + @dataclass class ExprInfo: """ @@ -204,9 +215,7 @@ class ExprInfo: module_info: Optional[ModuleInfo] = None location: DataLocation = DataLocation.UNSET modifiability: Modifiability = Modifiability.MODIFIABLE - - # the chain of attribute parents for this expr - attribute_chain: list["ExprInfo"] = field(default_factory=list) + attr: Optional[str] = None def __post_init__(self): should_match = ("typ", "location", "modifiability") @@ -215,48 +224,35 @@ def __post_init__(self): if getattr(self.var_info, attr) != getattr(self, attr): raise CompilerPanic("Bad analysis: non-matching {attr}: {self}") - self._writes: OrderedSet[VarInfo] = OrderedSet() - self._reads: OrderedSet[VarInfo] = OrderedSet() - - # find exprinfo in the attribute chain which has a varinfo - # e.x. `x` will return varinfo for `x` - # `module.foo` will return varinfo for `module.foo` - # `self.my_struct.x.y` will return varinfo for `self.my_struct` - def get_root_varinfo(self) -> Optional[VarInfo]: - for expr_info in self.attribute_chain + [self]: - if expr_info.var_info is not None: - return expr_info.var_info - return None + self._writes: OrderedSet[VarAccess] = OrderedSet() + self._reads: OrderedSet[VarAccess] = OrderedSet() @classmethod - def from_varinfo(cls, var_info: VarInfo, attribute_chain=None) -> "ExprInfo": + def from_varinfo(cls, var_info: VarInfo, **kwargs) -> "ExprInfo": return cls( var_info.typ, var_info=var_info, location=var_info.location, modifiability=var_info.modifiability, - attribute_chain=attribute_chain or [], + **kwargs, ) @classmethod - def from_moduleinfo(cls, module_info: ModuleInfo, attribute_chain=None) -> "ExprInfo": + def from_moduleinfo(cls, module_info: ModuleInfo, **kwargs) -> "ExprInfo": modifiability = Modifiability.RUNTIME_CONSTANT if module_info.ownership >= ModuleOwnership.USES: modifiability = Modifiability.MODIFIABLE return cls( - module_info.module_t, - module_info=module_info, - modifiability=modifiability, - attribute_chain=attribute_chain or [], + module_info.module_t, module_info=module_info, modifiability=modifiability, **kwargs ) - def copy_with_type(self, typ: VyperType, attribute_chain=None) -> "ExprInfo": + def copy_with_type(self, typ: VyperType, **kwargs) -> "ExprInfo": """ Return a copy of the ExprInfo but with the type set to something else """ to_copy = ("location", "modifiability") fields = {k: getattr(self, k) for k in to_copy} - if attribute_chain is not None: - fields["attribute_chain"] = attribute_chain - return self.__class__(typ=typ, **fields) + for t in to_copy: + assert t not in kwargs + return self.__class__(typ=typ, **fields, **kwargs) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index d96215ede0..39a1c59290 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -1,5 +1,6 @@ # CMC 2024-02-03 TODO: split me into function.py and expr.py +import contextlib from typing import Optional from vyper import ast as vy_ast @@ -19,7 +20,13 @@ VariableDeclarationException, VyperException, ) -from vyper.semantics.analysis.base import Modifiability, ModuleOwnership, VarInfo +from vyper.semantics.analysis.base import ( + Modifiability, + ModuleInfo, + ModuleOwnership, + VarAccess, + VarInfo, +) from vyper.semantics.analysis.common import VyperNodeVisitorBase from vyper.semantics.analysis.utils import ( get_common_types, @@ -58,18 +65,33 @@ def validate_functions(vy_module: vy_ast.Module) -> None: """Analyzes a vyper ast and validates the function bodies""" err_list = ExceptionList() - namespace = get_namespace() + for node in vy_module.get_children(vy_ast.FunctionDef): - with namespace.enter_scope(): - try: - analyzer = FunctionAnalyzer(vy_module, node, namespace) - analyzer.analyze() - except VyperException as e: - err_list.append(e) + _validate_function_r(vy_module, node, err_list) err_list.raise_if_not_empty() +def _validate_function_r( + vy_module: vy_ast.Module, node: vy_ast.FunctionDef, err_list: ExceptionList +): + func_t = node._metadata["func_type"] + + for call_t in func_t.called_functions: + if isinstance(call_t, ContractFunctionT): + assert isinstance(call_t.ast_def, vy_ast.FunctionDef) # help mypy + _validate_function_r(vy_module, call_t.ast_def, err_list) + + namespace = get_namespace() + + try: + with namespace.enter_scope(): + analyzer = FunctionAnalyzer(vy_module, node, namespace) + analyzer.analyze() + except VyperException as e: + err_list.append(e) + + # finds the terminus node for a list of nodes. # raises an exception if any nodes are unreachable def find_terminating_node(node_list: list) -> Optional[vy_ast.VyperNode]: @@ -99,36 +121,6 @@ def find_terminating_node(node_list: list) -> Optional[vy_ast.VyperNode]: return ret -def _check_iterator_modification( - target_node: vy_ast.VyperNode, search_node: vy_ast.VyperNode -) -> Optional[vy_ast.VyperNode]: - similar_nodes = [ - n - for n in search_node.get_descendants(type(target_node)) - if vy_ast.compare_nodes(target_node, n) - ] - - for node in similar_nodes: - # raise if the node is the target of an assignment statement - assign_node = node.get_ancestor((vy_ast.Assign, vy_ast.AugAssign)) - # note the use of get_descendants() blocks statements like - # self.my_array[i] = x - if assign_node and node in assign_node.target.get_descendants(include_self=True): - return node - - attr_node = node.get_ancestor(vy_ast.Attribute) - # note the use of get_descendants() blocks statements like - # self.my_array[i].append(x) - if ( - attr_node is not None - and node in attr_node.value.get_descendants(include_self=True) - and attr_node.attr in ("append", "pop", "extend") - ): - return node - - return None - - # helpers def _validate_address_code(node: vy_ast.Attribute, value_type: VyperType) -> None: if isinstance(value_type, AddressT) and node.attr == "code": @@ -183,6 +175,62 @@ def _validate_self_reference(node: vy_ast.Name) -> None: raise StateAccessViolation("not allowed to query self in pure functions", node) +# analyse the variable access for the attribute chain for a node +# e.x. `x` will return varinfo for `x` +# `module.foo` will return VarAccess for `module.foo` +# `self.my_struct.x.y` will return VarAccess for `self.my_struct.x.y` +def _get_variable_access(node: vy_ast.ExprNode) -> Optional[VarAccess]: + attrs: list[str] = [] + info = get_expr_info(node) + + while info.var_info is None: + if not isinstance(node, (vy_ast.Subscript, vy_ast.Attribute)): + # it's something like a literal + return None + + if isinstance(node, vy_ast.Subscript): + # Subscript is an analysis barrier + # we cannot analyse if `x.y[ix1].z` overlaps with `x.y[ix2].z`. + attrs.clear() + + if (attr := info.attr) is not None: + attrs.append(attr) + + assert isinstance(node, (vy_ast.Subscript, vy_ast.Attribute)) # help mypy + node = node.value + info = get_expr_info(node) + + # ignore `self.` as it interferes with VarAccess comparison across modules + if len(attrs) > 0 and attrs[-1] == "self": + attrs.pop() + attrs.reverse() + + return VarAccess(info.var_info, tuple(attrs)) + + +# get the chain of modules, e.g. +# mod1.mod2.x.y -> [ModuleInfo(mod1), ModuleInfo(mod2)] +# CMC 2024-02-12 note that the Attribute/Subscript traversal in this and +# _get_variable_access() are a bit gross and could probably +# be refactored into data on ExprInfo. +def _get_module_chain(node: vy_ast.ExprNode) -> list[ModuleInfo]: + ret: list[ModuleInfo] = [] + info = get_expr_info(node) + + while True: + if info.module_info is not None: + ret.append(info.module_info) + + if not isinstance(node, (vy_ast.Subscript, vy_ast.Attribute)): + break + + node = node.value + info = get_expr_info(node) + + ret.reverse() + return ret + + class FunctionAnalyzer(VyperNodeVisitorBase): ignored_types = (vy_ast.Pass,) scope_name = "function" @@ -196,7 +244,16 @@ def __init__( self.func = fn_node._metadata["func_type"] self.expr_visitor = ExprVisitor(self) + self.loop_variables: list[Optional[VarAccess]] = [] + def analyze(self): + if self.func.analysed: + return + + # mark seen before analysing, if analysis throws an exception which + # gets caught, we don't want to analyse again. + self.func.mark_analysed() + # allow internal function params to be mutable if self.func.is_internal: location, modifiability = (DataLocation.MEMORY, Modifiability.MODIFIABLE) @@ -225,6 +282,14 @@ def analyze(self): for kwarg in self.func.keyword_args: self.expr_visitor.visit(kwarg.default_value, kwarg.typ) + @contextlib.contextmanager + def enter_for_loop(self, varaccess: Optional[VarAccess]): + self.loop_variables.append(varaccess) + try: + yield + finally: + self.loop_variables.pop() + def visit(self, node): super().visit(node) @@ -326,16 +391,13 @@ def _handle_modification(self, target: vy_ast.ExprNode): if info.modifiability == Modifiability.CONSTANT: raise ImmutableViolation("Constant value cannot be written to.") - var_info = info.get_root_varinfo() - assert var_info is not None + var_access = _get_variable_access(target) + assert var_access is not None - info._writes.add(var_info) + info._writes.add(var_access) def _check_module_use(self, target: vy_ast.ExprNode): - module_infos = [] - for t in get_expr_info(target).attribute_chain: - if t.module_info is not None: - module_infos.append(t.module_info) + module_infos = _get_module_chain(target) if len(module_infos) == 0: return @@ -352,7 +414,7 @@ def _check_module_use(self, target: vy_ast.ExprNode): root_module_info = module_infos[0] # log the access - self.func._used_modules.add(root_module_info) + self.func.mark_used_module(root_module_info) def visit_Assign(self, node): self._assign_helper(node) @@ -403,96 +465,68 @@ def visit_Expr(self, node): ) self.expr_visitor.visit(node.value, return_value) + def _analyse_range_iter(self, iter_node, target_type): + # iteration via range() + if iter_node.get("func.id") != "range": + raise IteratorException("Cannot iterate over the result of a function call", iter_node) + _validate_range_call(iter_node) + + args = iter_node.args + kwargs = [s.value for s in iter_node.keywords] + for arg in (*args, *kwargs): + self.expr_visitor.visit(arg, target_type) + + def _analyse_list_iter(self, iter_node, target_type): + # iteration over a variable or literal list + iter_val = iter_node + if iter_val.has_folded_value: + iter_val = iter_val.get_folded_value() + + if isinstance(iter_val, vy_ast.List): + len_ = len(iter_val.elements) + if len_ == 0: + raise StructureException("For loop must have at least 1 iteration", iter_node) + iter_type = SArrayT(target_type, len_) + else: + try: + iter_type = get_exact_type_from_node(iter_node) + except (InvalidType, StructureException): + raise InvalidType("Not an iterable type", iter_node) + + # CMC 2024-02-09 TODO: use validate_expected_type once we have DArrays + # with generic length. + if not isinstance(iter_type, (DArrayT, SArrayT)): + raise InvalidType("Not an iterable type", iter_node) + + self.expr_visitor.visit(iter_node, iter_type) + + # get the root varinfo from iter_val in case we need to peer + # through folded constants + return _get_variable_access(iter_val) + def visit_For(self, node): if not isinstance(node.target.target, vy_ast.Name): raise StructureException("Invalid syntax for loop iterator", node.target.target) target_type = type_from_annotation(node.target.annotation, DataLocation.MEMORY) + iter_var = None if isinstance(node.iter, vy_ast.Call): - # iteration via range() - if node.iter.get("func.id") != "range": - raise IteratorException( - "Cannot iterate over the result of a function call", node.iter - ) - _validate_range_call(node.iter) - + self._analyse_range_iter(node.iter, target_type) else: - # iteration over a variable or literal list - iter_val = node.iter.get_folded_value() if node.iter.has_folded_value else node.iter - if isinstance(iter_val, vy_ast.List) and len(iter_val.elements) == 0: - raise StructureException("For loop must have at least 1 iteration", node.iter) - - if not any( - isinstance(i, (DArrayT, SArrayT)) for i in get_possible_types_from_node(node.iter) - ): - raise InvalidType("Not an iterable type", node.iter) - - if isinstance(node.iter, (vy_ast.Name, vy_ast.Attribute)): - # check for references to the iterated value within the body of the loop - assign = _check_iterator_modification(node.iter, node) - if assign: - raise ImmutableViolation("Cannot modify array during iteration", assign) - - # Check if `iter` is a storage variable. get_descendants` is used to check for - # nested `self` (e.g. structs) - # NOTE: this analysis will be borked once stateful modules are allowed! - iter_is_storage_var = ( - isinstance(node.iter, vy_ast.Attribute) - and len(node.iter.get_descendants(vy_ast.Name, {"id": "self"})) > 0 - ) - - if iter_is_storage_var: - # check if iterated value may be modified by function calls inside the loop - iter_name = node.iter.attr - for call_node in node.get_descendants(vy_ast.Call, {"func.value.id": "self"}): - fn_name = call_node.func.attr - - fn_node = self.vyper_module.get_children(vy_ast.FunctionDef, {"name": fn_name})[0] - if _check_iterator_modification(node.iter, fn_node): - # check for direct modification - raise ImmutableViolation( - f"Cannot call '{fn_name}' inside for loop, it potentially " - f"modifies iterated storage variable '{iter_name}'", - call_node, - ) + iter_var = self._analyse_list_iter(node.iter, target_type) - for reachable_t in ( - self.namespace["self"].typ.members[fn_name].reachable_internal_functions - ): - # check for indirect modification - name = reachable_t.name - fn_node = self.vyper_module.get_children(vy_ast.FunctionDef, {"name": name})[0] - if _check_iterator_modification(node.iter, fn_node): - raise ImmutableViolation( - f"Cannot call '{fn_name}' inside for loop, it may call to '{name}' " - f"which potentially modifies iterated storage variable '{iter_name}'", - call_node, - ) - - target_name = node.target.target.id - with self.namespace.enter_scope(): + with self.namespace.enter_scope(), self.enter_for_loop(iter_var): + target_name = node.target.target.id + # maybe we should introduce a new Modifiability: LOOP_VARIABLE self.namespace[target_name] = VarInfo( target_type, modifiability=Modifiability.RUNTIME_CONSTANT ) + self.expr_visitor.visit(node.target.target, target_type) for stmt in node.body: self.visit(stmt) - self.expr_visitor.visit(node.target.target, target_type) - - if isinstance(node.iter, vy_ast.List): - len_ = len(node.iter.elements) - self.expr_visitor.visit(node.iter, SArrayT(target_type, len_)) - elif isinstance(node.iter, vy_ast.Call) and node.iter.func.id == "range": - args = node.iter.args - kwargs = [s.value for s in node.iter.keywords] - for arg in (*args, *kwargs): - self.expr_visitor.visit(arg, target_type) - else: - iter_type = get_exact_type_from_node(node.iter) - self.expr_visitor.visit(node.iter, iter_type) - def visit_If(self, node): self.expr_visitor.visit(node.test, BoolT()) with self.namespace.enter_scope(): @@ -577,18 +611,32 @@ def visit(self, node, typ): # log variable accesses. # (note writes will get logged as both read+write) - varinfo = info.var_info - if varinfo is not None: - info._reads.add(varinfo) + var_access = _get_variable_access(node) + if var_access is not None: + info._reads.add(var_access) + + if self.function_analyzer: + for s in self.function_analyzer.loop_variables: + if s is None: + continue + + for v in info._writes: + if not v.contains(s): + continue + + msg = "Cannot modify loop variable" + var = s.variable + if var.decl_node is not None: + msg += f" `{var.decl_node.target.id}`" + raise ImmutableViolation(msg, var.decl_node, node) - if self.func: variable_accesses = info._writes | info._reads for s in variable_accesses: - if s.is_module_variable(): + if s.variable.is_module_variable(): self.function_analyzer._check_module_use(node) - self.func._variable_writes.update(info._writes) - self.func._variable_reads.update(info._reads) + self.func.mark_variable_writes(info._writes) + self.func.mark_variable_reads(info._reads) # validate and annotate folded value if node.has_folded_value: @@ -641,24 +689,23 @@ def _check_call_mutability(self, call_mutability: StateMutability): def visit_Call(self, node: vy_ast.Call, typ: VyperType) -> None: func_info = get_expr_info(node.func, is_callable=True) func_type = func_info.typ - self.visit(node.func, func_type) if isinstance(func_type, ContractFunctionT): # function calls - func_info._writes.update(func_type._variable_writes) - func_info._reads.update(func_type._variable_reads) + if not func_type.from_interface: + for s in func_type.get_variable_writes(): + if s.variable.is_module_variable(): + func_info._writes.add(s) + for s in func_type.get_variable_reads(): + if s.variable.is_module_variable(): + func_info._reads.add(s) if self.function_analyzer: - if func_type.is_internal: - self.func.called_functions.add(func_type) - self._check_call_mutability(func_type.mutability) - # check that if the function accesses state, the defining - # module has been `used` or `initialized`. - for s in func_type._variable_accesses: - if s.is_module_variable(): + for s in func_type.get_variable_accesses(): + if s.variable.is_module_variable(): self.function_analyzer._check_module_use(node.func) if func_type.is_deploy and not self.func.is_deploy: @@ -689,7 +736,8 @@ def visit_Call(self, node: vy_ast.Call, typ: VyperType) -> None: elif isinstance(func_type, MemberFunctionT): if func_type.is_modifying and self.function_analyzer is not None: # TODO refactor this - self.function_analyzer._handle_modification(node.func) + assert isinstance(node.func, vy_ast.Attribute) # help mypy + self.function_analyzer._handle_modification(node.func.value) assert len(node.args) == len(func_type.arg_types) for arg, arg_type in zip(node.args, func_type.arg_types): self.visit(arg, arg_type) @@ -702,6 +750,8 @@ def visit_Call(self, node: vy_ast.Call, typ: VyperType) -> None: for kwarg in node.keywords: self.visit(kwarg.value, kwarg_types[kwarg.arg]) + self.visit(node.func, func_type) + def visit_Compare(self, node: vy_ast.Compare, typ: VyperType) -> None: if isinstance(node.op, (vy_ast.In, vy_ast.NotIn)): # membership in list literal - `x in [a, b, c]` diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index 43b11497ec..10acef59da 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -244,7 +244,7 @@ def validate_used_modules(self): all_used_modules = OrderedSet() for f in module_t.functions.values(): - for u in f._used_modules: + for u in f.get_used_modules(): all_used_modules.add(u.module_t) for used_module in all_used_modules: diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index f1f0f48a86..034cd8c46e 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -84,28 +84,24 @@ def get_expr_info(self, node: vy_ast.VyperNode, is_callable: bool = False) -> Ex # propagate the parent exprinfo members down into the new expr # note: Attribute(expr value, identifier attr) - name = node.attr info = self.get_expr_info(node.value, is_callable=is_callable) + attr = node.attr - attribute_chain = info.attribute_chain + [info] - - t = info.typ.get_member(name, node) + t = info.typ.get_member(attr, node) # it's a top-level variable if isinstance(t, VarInfo): - return ExprInfo.from_varinfo(t, attribute_chain=attribute_chain) + return ExprInfo.from_varinfo(t, attr=attr) if isinstance(t, ModuleInfo): - return ExprInfo.from_moduleinfo(t, attribute_chain=attribute_chain) + return ExprInfo.from_moduleinfo(t, attr=attr) - # it's something else, like my_struct.foo - return info.copy_with_type(t, attribute_chain=attribute_chain) + return info.copy_with_type(t, attr=attr) # If it's a Subscript, propagate the subscriptable varinfo if isinstance(node, vy_ast.Subscript): info = self.get_expr_info(node.value) - attribute_chain = info.attribute_chain + [info] - return info.copy_with_type(t, attribute_chain=attribute_chain) + return info.copy_with_type(t) return ExprInfo(t) diff --git a/vyper/semantics/environment.py b/vyper/semantics/environment.py index 38bac0a63d..94a26157af 100644 --- a/vyper/semantics/environment.py +++ b/vyper/semantics/environment.py @@ -1,7 +1,7 @@ from typing import Dict from vyper.semantics.analysis.base import Modifiability, VarInfo -from vyper.semantics.types import AddressT, BytesT, VyperType +from vyper.semantics.types import AddressT, BytesT, SelfT, VyperType from vyper.semantics.types.shortcuts import BYTES32_T, UINT256_T @@ -57,7 +57,7 @@ def get_constant_vars() -> Dict: return result -MUTABLE_ENVIRONMENT_VARS: Dict[str, type] = {"self": AddressT} +MUTABLE_ENVIRONMENT_VARS: Dict[str, type] = {"self": SelfT} def get_mutable_vars() -> Dict: diff --git a/vyper/semantics/types/__init__.py b/vyper/semantics/types/__init__.py index a04632b96f..59a20dd99f 100644 --- a/vyper/semantics/types/__init__.py +++ b/vyper/semantics/types/__init__.py @@ -3,7 +3,7 @@ from .bytestrings import BytesT, StringT, _BytestringT from .function import MemberFunctionT from .module import InterfaceT -from .primitives import AddressT, BoolT, BytesM_T, DecimalT, IntegerT +from .primitives import AddressT, BoolT, BytesM_T, DecimalT, IntegerT, SelfT from .subscriptable import DArrayT, HashMapT, SArrayT, TupleT from .user import EventT, FlagT, StructT diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index 62f9c60585..705470a798 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -21,7 +21,7 @@ Modifiability, ModuleInfo, StateMutability, - VarInfo, + VarAccess, VarOffset, ) from vyper.semantics.analysis.utils import ( @@ -92,6 +92,7 @@ def __init__( return_type: Optional[VyperType], function_visibility: FunctionVisibility, state_mutability: StateMutability, + from_interface: bool = False, nonreentrant: Optional[str] = None, ast_def: Optional[vy_ast.VyperNode] = None, ) -> None: @@ -104,9 +105,12 @@ def __init__( self.visibility = function_visibility self.mutability = state_mutability self.nonreentrant = nonreentrant + self.from_interface = from_interface self.ast_def = ast_def + self._analysed = False + # a list of internal functions this function calls. # to be populated during analysis self.called_functions: OrderedSet[ContractFunctionT] = OrderedSet() @@ -115,10 +119,10 @@ def __init__( self.reachable_internal_functions: OrderedSet[ContractFunctionT] = OrderedSet() # writes to variables from this function - self._variable_writes: OrderedSet[VarInfo] = OrderedSet() + self._variable_writes: OrderedSet[VarAccess] = OrderedSet() # reads of variables from this function - self._variable_reads: OrderedSet[VarInfo] = OrderedSet() + self._variable_reads: OrderedSet[VarAccess] = OrderedSet() # list of modules used (accessed state) by this function self._used_modules: OrderedSet[ModuleInfo] = OrderedSet() @@ -127,10 +131,35 @@ def __init__( self._ir_info: Any = None self._function_id: Optional[int] = None + def mark_analysed(self): + assert not self._analysed + self._analysed = True + @property - def _variable_accesses(self): + def analysed(self): + return self._analysed + + def get_variable_reads(self): + return self._variable_reads + + def get_variable_writes(self): + return self._variable_writes + + def get_variable_accesses(self): return self._variable_reads | self._variable_writes + def get_used_modules(self): + return self._used_modules + + def mark_used_module(self, module_info): + self._used_modules.add(module_info) + + def mark_variable_writes(self, var_infos): + self._variable_writes.update(var_infos) + + def mark_variable_reads(self, var_infos): + self._variable_reads.update(var_infos) + @property def modifiability(self): return Modifiability.from_state_mutability(self.mutability) @@ -189,6 +218,7 @@ def from_abi(cls, abi: dict) -> "ContractFunctionT": positional_args, [], return_type, + from_interface=True, function_visibility=FunctionVisibility.EXTERNAL, state_mutability=StateMutability.from_abi(abi), ) @@ -248,6 +278,7 @@ def from_InterfaceDef(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": return_type, function_visibility, state_mutability, + from_interface=True, nonreentrant=None, ast_def=funcdef, ) @@ -300,6 +331,7 @@ def from_vyi(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": return_type, function_visibility, state_mutability, + from_interface=True, nonreentrant=nonreentrant_key, ast_def=funcdef, ) @@ -370,6 +402,7 @@ def from_FunctionDef(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": return_type, function_visibility, state_mutability, + from_interface=False, nonreentrant=nonreentrant_key, ast_def=funcdef, ) @@ -410,6 +443,7 @@ def getter_from_VariableDecl(cls, node: vy_ast.VariableDecl) -> "ContractFunctio args, [], return_type, + from_interface=False, function_visibility=FunctionVisibility.EXTERNAL, state_mutability=StateMutability.VIEW, ast_def=node, diff --git a/vyper/semantics/types/primitives.py b/vyper/semantics/types/primitives.py index 07d1a21a94..d383f72ab2 100644 --- a/vyper/semantics/types/primitives.py +++ b/vyper/semantics/types/primitives.py @@ -340,3 +340,13 @@ def validate_literal(self, node: vy_ast.Constant) -> None: f"address, the correct checksummed form is: {checksum_encode(addr)}", node, ) + + +# type for "self" +# refactoring note: it might be best for this to be a ModuleT actually +class SelfT(AddressT): + _id = "self" + + def compare_type(self, other): + # compares true to AddressT + return isinstance(other, type(self)) or isinstance(self, type(other)) From 199f2b65e43e3d3f055756039ef4a9bce7f6f3cf Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 13 Feb 2024 06:56:41 -0800 Subject: [PATCH 10/12] feat[lang]: remove named reentrancy locks (#3769) this commit removes "fine-grained" nonreentrancy locks (i.e., reentrancy locks with names) from vyper. they aren't really used (all known production contracts just use a single global named lock) , and in any case such a use case should better be implemented manually by the user. this simplifies the language and allows moderate simplification to the storage allocator, although some complexity is added because the global restriction has to have special handling (it cannot be handled simply in the recursion into child modules). refactors: - the routine for allocating nonreentrant keys has been refactored into a helper function. --- docs/control-structures.rst | 12 +- .../features/decorators/test_nonreentrant.py | 139 ++++++++++++++---- .../exceptions/test_structure_exception.py | 31 ---- .../test_invalid_function_decorators.py | 15 +- .../cli/storage_layout/test_storage_layout.py | 75 ++++++---- .../test_storage_layout_overrides.py | 34 ++++- tests/unit/semantics/test_storage_slots.py | 11 +- vyper/semantics/analysis/data_positions.py | 83 ++++++----- vyper/semantics/types/function.py | 70 ++++----- 9 files changed, 291 insertions(+), 179 deletions(-) diff --git a/docs/control-structures.rst b/docs/control-structures.rst index a0aa927261..4e18a21bd8 100644 --- a/docs/control-structures.rst +++ b/docs/control-structures.rst @@ -100,22 +100,24 @@ Functions marked with ``@pure`` cannot call non-``pure`` functions. Re-entrancy Locks ----------------- -The ``@nonreentrant()`` decorator places a lock on a function, and all functions with the same ```` value. An attempt by an external contract to call back into any of these functions causes the transaction to revert. +The ``@nonreentrant`` decorator places a global nonreentrancy lock on a function. An attempt by an external contract to call back into any other ``@nonreentrant`` function causes the transaction to revert. .. code-block:: vyper @external - @nonreentrant("lock") + @nonreentrant def make_a_call(_addr: address): # this function is protected from re-entrancy ... -You can put the ``@nonreentrant()`` decorator on a ``__default__`` function but we recommend against it because in most circumstances it will not work in a meaningful way. +You can put the ``@nonreentrant`` decorator on a ``__default__`` function but we recommend against it because in most circumstances it will not work in a meaningful way. Nonreentrancy locks work by setting a specially allocated storage slot to a ```` value on function entrance, and setting it to an ```` value on function exit. On function entrance, if the storage slot is detected to be the ```` value, execution reverts. You cannot put the ``@nonreentrant`` decorator on a ``pure`` function. You can put it on a ``view`` function, but it only checks that the function is not in a callback (the storage slot is not in the ```` state), as ``view`` functions can only read the state, not change it. +You can view where the nonreentrant key is physically laid out in storage by using ``vyper`` with the ``-f layout`` option (e.g., ``vyper -f layout foo.vy``). Unless it is overriden, the compiler will allocate it at slot ``0``. + .. note:: A mutable function can protect a ``view`` function from being called back into (which is useful for instance, if a ``view`` function would return inconsistent state during a mutable function), but a ``view`` function cannot protect itself from being called back into. Note that mutable functions can never be called from a ``view`` function because all external calls out from a ``view`` function are protected by the use of the ``STATICCALL`` opcode. @@ -123,6 +125,8 @@ You cannot put the ``@nonreentrant`` decorator on a ``pure`` function. You can p A nonreentrant lock has an ```` value of 3, and a ```` value of 2. Nonzero values are used to take advantage of net gas metering - as of the Berlin hard fork, the net cost for utilizing a nonreentrant lock is 2300 gas. Prior to v0.3.4, the ```` and ```` values were 0 and 1, respectively. +.. note:: + Prior to 0.4.0, nonreentrancy keys took a "key" argument for fine-grained nonreentrancy control. As of 0.4.0, only a global nonreentrancy lock is available. The ``__default__`` Function ---------------------------- @@ -194,7 +198,7 @@ Decorator Description ``@pure`` Function does not read contract state or environment variables ``@view`` Function does not alter contract state ``@payable`` Function is able to receive Ether -``@nonreentrant()`` Function cannot be called back into during an external call +``@nonreentrant`` Function cannot be called back into during an external call =============================== =========================================================== ``if`` statements diff --git a/tests/functional/codegen/features/decorators/test_nonreentrant.py b/tests/functional/codegen/features/decorators/test_nonreentrant.py index 9329605678..92a21cd302 100644 --- a/tests/functional/codegen/features/decorators/test_nonreentrant.py +++ b/tests/functional/codegen/features/decorators/test_nonreentrant.py @@ -2,30 +2,103 @@ from vyper.exceptions import FunctionDeclarationException - # TODO test functions in this module across all evm versions # once we have cancun support. + + def test_nonreentrant_decorator(get_contract, tx_failed): - calling_contract_code = """ -interface SpecialContract: + malicious_code = """ +interface ProtectedContract: + def protected_function(callback_address: address): nonpayable + +@external +def do_callback(): + ProtectedContract(msg.sender).protected_function(self) + """ + + protected_code = """ +interface Callbackable: + def do_callback(): nonpayable + +@external +@nonreentrant +def protected_function(c: Callbackable): + c.do_callback() + +# add a default function so we know the callback didn't fail for any reason +# besides nonreentrancy +@external +def __default__(): + pass + """ + contract = get_contract(protected_code) + malicious = get_contract(malicious_code) + + with tx_failed(): + contract.protected_function(malicious.address) + + +def test_nonreentrant_view_function(get_contract, tx_failed): + malicious_code = """ +interface ProtectedContract: + def protected_function(): nonpayable + def protected_view_fn() -> uint256: view + +@external +def do_callback() -> uint256: + return ProtectedContract(msg.sender).protected_view_fn() + """ + + protected_code = """ +interface Callbackable: + def do_callback(): nonpayable + +@external +@nonreentrant +def protected_function(c: Callbackable): + c.do_callback() + +@external +@nonreentrant +@view +def protected_view_fn() -> uint256: + return 10 + +# add a default function so we know the callback didn't fail for any reason +# besides nonreentrancy +@external +def __default__(): + pass + """ + contract = get_contract(protected_code) + malicious = get_contract(malicious_code) + + with tx_failed(): + contract.protected_function(malicious.address) + + +def test_multi_function_nonreentrant(get_contract, tx_failed): + malicious_code = """ +interface ProtectedContract: def unprotected_function(val: String[100], do_callback: bool): nonpayable def protected_function(val: String[100], do_callback: bool): nonpayable def special_value() -> String[100]: nonpayable @external def updated(): - SpecialContract(msg.sender).unprotected_function('surprise!', False) + ProtectedContract(msg.sender).unprotected_function('surprise!', False) @external def updated_protected(): # This should fail. - SpecialContract(msg.sender).protected_function('surprise protected!', False) + ProtectedContract(msg.sender).protected_function('surprise protected!', False) """ - reentrant_code = """ + protected_code = """ interface Callback: def updated(): nonpayable def updated_protected(): nonpayable + interface Self: def protected_function(val: String[100], do_callback: bool) -> uint256: nonpayable def protected_function2(val: String[100], do_callback: bool) -> uint256: nonpayable @@ -39,7 +112,7 @@ def set_callback(c: address): self.callback = Callback(c) @external -@nonreentrant('protect_special_value') +@nonreentrant def protected_function(val: String[100], do_callback: bool) -> uint256: self.special_value = val @@ -50,7 +123,7 @@ def protected_function(val: String[100], do_callback: bool) -> uint256: return 2 @external -@nonreentrant('protect_special_value') +@nonreentrant def protected_function2(val: String[100], do_callback: bool) -> uint256: self.special_value = val if do_callback: @@ -60,7 +133,7 @@ def protected_function2(val: String[100], do_callback: bool) -> uint256: return 2 @external -@nonreentrant('protect_special_value') +@nonreentrant def protected_function3(val: String[100], do_callback: bool) -> uint256: self.special_value = val if do_callback: @@ -71,7 +144,8 @@ def protected_function3(val: String[100], do_callback: bool) -> uint256: @external -@nonreentrant('protect_special_value') +@nonreentrant +@view def protected_view_fn() -> String[100]: return self.special_value @@ -81,37 +155,42 @@ def unprotected_function(val: String[100], do_callback: bool): if do_callback: self.callback.updated() - """ - reentrant_contract = get_contract(reentrant_code) - calling_contract = get_contract(calling_contract_code) +# add a default function so we know the callback didn't fail for any reason +# besides nonreentrancy +@external +def __default__(): + pass + """ + contract = get_contract(protected_code) + malicious = get_contract(malicious_code) - reentrant_contract.set_callback(calling_contract.address, transact={}) - assert reentrant_contract.callback() == calling_contract.address + contract.set_callback(malicious.address, transact={}) + assert contract.callback() == malicious.address # Test unprotected function. - reentrant_contract.unprotected_function("some value", True, transact={}) - assert reentrant_contract.special_value() == "surprise!" + contract.unprotected_function("some value", True, transact={}) + assert contract.special_value() == "surprise!" # Test protected function. - reentrant_contract.protected_function("some value", False, transact={}) - assert reentrant_contract.special_value() == "some value" - assert reentrant_contract.protected_view_fn() == "some value" + contract.protected_function("some value", False, transact={}) + assert contract.special_value() == "some value" + assert contract.protected_view_fn() == "some value" with tx_failed(): - reentrant_contract.protected_function("zzz value", True, transact={}) + contract.protected_function("zzz value", True, transact={}) - reentrant_contract.protected_function2("another value", False, transact={}) - assert reentrant_contract.special_value() == "another value" + contract.protected_function2("another value", False, transact={}) + assert contract.special_value() == "another value" with tx_failed(): - reentrant_contract.protected_function2("zzz value", True, transact={}) + contract.protected_function2("zzz value", True, transact={}) - reentrant_contract.protected_function3("another value", False, transact={}) - assert reentrant_contract.special_value() == "another value" + contract.protected_function3("another value", False, transact={}) + assert contract.special_value() == "another value" with tx_failed(): - reentrant_contract.protected_function3("zzz value", True, transact={}) + contract.protected_function3("zzz value", True, transact={}) def test_nonreentrant_decorator_for_default(w3, get_contract, tx_failed): @@ -145,7 +224,7 @@ def set_callback(c: address): @external @payable -@nonreentrant("lock") +@nonreentrant def protected_function(val: String[100], do_callback: bool) -> uint256: self.special_value = val _amount: uint256 = msg.value @@ -169,7 +248,7 @@ def unprotected_function(val: String[100], do_callback: bool): @external @payable -@nonreentrant("lock") +@nonreentrant def __default__(): pass """ @@ -209,7 +288,7 @@ def test_disallow_on_init_function(get_contract): code = """ @external -@nonreentrant("lock") +@nonreentrant def __init__(): foo: uint256 = 0 """ diff --git a/tests/functional/syntax/exceptions/test_structure_exception.py b/tests/functional/syntax/exceptions/test_structure_exception.py index afc7a35012..e530487fea 100644 --- a/tests/functional/syntax/exceptions/test_structure_exception.py +++ b/tests/functional/syntax/exceptions/test_structure_exception.py @@ -44,42 +44,11 @@ def foo() -> int128: return x.codesize() """, """ -@external -@nonreentrant("B") -@nonreentrant("C") -def double_nonreentrant(): - pass - """, - """ struct X: int128[5]: int128[7] """, """ @external -@nonreentrant(" ") -def invalid_nonreentrant_key(): - pass - """, - """ -@external -@nonreentrant("") -def invalid_nonreentrant_key(): - pass - """, - """ -@external -@nonreentrant("123") -def invalid_nonreentrant_key(): - pass - """, - """ -@external -@nonreentrant("!123abcd") -def invalid_nonreentrant_key(): - pass - """, - """ -@external def foo(): true: int128 = 3 """, diff --git a/tests/functional/syntax/signatures/test_invalid_function_decorators.py b/tests/functional/syntax/signatures/test_invalid_function_decorators.py index b3d4219a2d..a7a500efc7 100644 --- a/tests/functional/syntax/signatures/test_invalid_function_decorators.py +++ b/tests/functional/syntax/signatures/test_invalid_function_decorators.py @@ -7,10 +7,23 @@ """ @external @pure -@nonreentrant('lock') +@nonreentrant def nonreentrant_foo() -> uint256: return 1 + """, """ +@external +@nonreentrant +@nonreentrant +def nonreentrant_foo() -> uint256: + return 1 + """, + """ +@external +@nonreentrant("foo") +def nonreentrant_foo() -> uint256: + return 1 + """, ] diff --git a/tests/unit/cli/storage_layout/test_storage_layout.py b/tests/unit/cli/storage_layout/test_storage_layout.py index f0ee25f747..9724dd723c 100644 --- a/tests/unit/cli/storage_layout/test_storage_layout.py +++ b/tests/unit/cli/storage_layout/test_storage_layout.py @@ -6,18 +6,18 @@ def test_storage_layout(): foo: HashMap[address, uint256] @external -@nonreentrant("foo") +@nonreentrant def public_foo1(): pass @external -@nonreentrant("foo") +@nonreentrant def public_foo2(): pass @internal -@nonreentrant("bar") +@nonreentrant def _bar(): pass @@ -28,12 +28,12 @@ def _bar(): bar: uint256 @external -@nonreentrant("bar") +@nonreentrant def public_bar(): pass @external -@nonreentrant("foo") +@nonreentrant def public_foo3(): pass """ @@ -41,12 +41,11 @@ def public_foo3(): out = compile_code(code, output_formats=["layout"]) assert out["layout"]["storage_layout"] == { - "nonreentrant.foo": {"type": "nonreentrant lock", "slot": 0}, - "nonreentrant.bar": {"type": "nonreentrant lock", "slot": 1}, - "foo": {"type": "HashMap[address, uint256]", "slot": 2}, - "arr": {"type": "DynArray[uint256, 3]", "slot": 3}, - "baz": {"type": "Bytes[65]", "slot": 7}, - "bar": {"type": "uint256", "slot": 11}, + "$.nonreentrant_key": {"slot": 0, "type": "nonreentrant lock"}, + "foo": {"slot": 1, "type": "HashMap[address, uint256]"}, + "arr": {"slot": 2, "type": "DynArray[uint256, 3]"}, + "baz": {"slot": 6, "type": "Bytes[65]"}, + "bar": {"slot": 10, "type": "uint256"}, } @@ -64,10 +63,13 @@ def __init__(): expected_layout = { "code_layout": { - "DECIMALS": {"length": 32, "offset": 64, "type": "uint8"}, "SYMBOL": {"length": 64, "offset": 0, "type": "String[32]"}, + "DECIMALS": {"length": 32, "offset": 64, "type": "uint8"}, + }, + "storage_layout": { + "$.nonreentrant_key": {"slot": 0, "type": "nonreentrant lock"}, + "name": {"slot": 1, "type": "String[32]"}, }, - "storage_layout": {"name": {"slot": 0, "type": "String[32]"}}, } out = compile_code(code, output_formats=["layout"]) @@ -107,14 +109,15 @@ def __init__(): "code_layout": { "some_immutable": {"length": 352, "offset": 0, "type": "DynArray[uint256, 10]"}, "a_library": { - "DECIMALS": {"length": 32, "offset": 416, "type": "uint8"}, "SYMBOL": {"length": 64, "offset": 352, "type": "String[32]"}, + "DECIMALS": {"length": 32, "offset": 416, "type": "uint8"}, }, }, "storage_layout": { - "counter": {"slot": 0, "type": "uint256"}, - "counter2": {"slot": 1, "type": "uint256"}, - "a_library": {"supply": {"slot": 2, "type": "uint256"}}, + "$.nonreentrant_key": {"slot": 0, "type": "nonreentrant lock"}, + "counter": {"slot": 1, "type": "uint256"}, + "counter2": {"slot": 2, "type": "uint256"}, + "a_library": {"supply": {"slot": 3, "type": "uint256"}}, }, } @@ -160,9 +163,10 @@ def __init__(): }, }, "storage_layout": { - "counter": {"slot": 0, "type": "uint256"}, - "a_library": {"supply": {"slot": 1, "type": "uint256"}}, - "counter2": {"slot": 2, "type": "uint256"}, + "$.nonreentrant_key": {"slot": 0, "type": "nonreentrant lock"}, + "counter": {"slot": 1, "type": "uint256"}, + "a_library": {"supply": {"slot": 2, "type": "uint256"}}, + "counter2": {"slot": 3, "type": "uint256"}, }, } @@ -171,7 +175,8 @@ def __init__(): def test_storage_layout_module_uses(make_input_bundle): - # test module storage layout, with initializes/uses + # test module storage layout, with initializes/uses and a nonreentrant + # lock lib1 = """ supply: uint256 SYMBOL: immutable(String[32]) @@ -197,6 +202,11 @@ def __init__(s: uint256): @internal def decimals() -> uint8: return lib1.DECIMALS + +@external +@nonreentrant +def foo(): + pass """ code = """ import lib1 as a_library @@ -218,6 +228,11 @@ def __init__(): some_immutable = [1, 2, 3] lib2.__init__(17) + +@external +@nonreentrant +def bar(): + pass """ input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) @@ -231,10 +246,11 @@ def __init__(): }, }, "storage_layout": { - "counter": {"slot": 0, "type": "uint256"}, - "lib2": {"storage_variable": {"slot": 1, "type": "uint256"}}, - "counter2": {"slot": 2, "type": "uint256"}, - "a_library": {"supply": {"slot": 3, "type": "uint256"}}, + "$.nonreentrant_key": {"slot": 0, "type": "nonreentrant lock"}, + "counter": {"slot": 1, "type": "uint256"}, + "lib2": {"storage_variable": {"slot": 2, "type": "uint256"}}, + "counter2": {"slot": 3, "type": "uint256"}, + "a_library": {"supply": {"slot": 4, "type": "uint256"}}, }, } @@ -309,12 +325,13 @@ def foo() -> uint256: }, }, "storage_layout": { - "counter": {"slot": 0, "type": "uint256"}, + "$.nonreentrant_key": {"slot": 0, "type": "nonreentrant lock"}, + "counter": {"slot": 1, "type": "uint256"}, "lib2": { - "lib1": {"supply": {"slot": 1, "type": "uint256"}}, - "storage_variable": {"slot": 2, "type": "uint256"}, + "lib1": {"supply": {"slot": 2, "type": "uint256"}}, + "storage_variable": {"slot": 3, "type": "uint256"}, }, - "counter2": {"slot": 3, "type": "uint256"}, + "counter2": {"slot": 4, "type": "uint256"}, }, } diff --git a/tests/unit/cli/storage_layout/test_storage_layout_overrides.py b/tests/unit/cli/storage_layout/test_storage_layout_overrides.py index f4c11b7ae6..707c94c3fc 100644 --- a/tests/unit/cli/storage_layout/test_storage_layout_overrides.py +++ b/tests/unit/cli/storage_layout/test_storage_layout_overrides.py @@ -1,3 +1,5 @@ +import re + import pytest from vyper.compiler import compile_code @@ -28,18 +30,18 @@ def test_storage_layout_for_more_complex(): foo: HashMap[address, uint256] @external -@nonreentrant("foo") +@nonreentrant def public_foo1(): pass @external -@nonreentrant("foo") +@nonreentrant def public_foo2(): pass @internal -@nonreentrant("bar") +@nonreentrant def _bar(): pass @@ -48,19 +50,18 @@ def _bar(): bar: uint256 @external -@nonreentrant("bar") +@nonreentrant def public_bar(): pass @external -@nonreentrant("foo") +@nonreentrant def public_foo3(): pass """ storage_layout_override = { - "nonreentrant.foo": {"type": "nonreentrant lock", "slot": 8}, - "nonreentrant.bar": {"type": "nonreentrant lock", "slot": 7}, + "$.nonreentrant_key": {"type": "nonreentrant lock", "slot": 8}, "foo": {"type": "HashMap[address, uint256]", "slot": 1}, "baz": {"type": "Bytes[65]", "slot": 2}, "bar": {"type": "uint256", "slot": 6}, @@ -110,6 +111,25 @@ def test_overflow(): ) +def test_override_nonreentrant_slot(): + code = """ +@nonreentrant +@external +def foo(): + pass + """ + + storage_layout_override = {"$.nonreentrant_key": {"slot": 2**256, "type": "nonreentrant key"}} + + exception_regex = re.escape( + f"Invalid storage slot for var $.nonreentrant_key, out of bounds: {2**256}" + ) + with pytest.raises(StorageLayoutException, match=exception_regex): + compile_code( + code, output_formats=["layout"], storage_layout_override=storage_layout_override + ) + + def test_incomplete_overrides(): code = """ name: public(String[64]) diff --git a/tests/unit/semantics/test_storage_slots.py b/tests/unit/semantics/test_storage_slots.py index 3620ef64b9..1dc70fd1ba 100644 --- a/tests/unit/semantics/test_storage_slots.py +++ b/tests/unit/semantics/test_storage_slots.py @@ -47,15 +47,9 @@ def __init__(): self.foo[1] = [123, 456, 789] @external -@nonreentrant('lock') +@nonreentrant def with_lock(): pass - - -@external -@nonreentrant('otherlock') -def with_other_lock(): - pass """ @@ -84,7 +78,6 @@ def test_reentrancy_lock(get_contract): # if re-entrancy locks are incorrectly placed within storage, these # calls will either revert or correupt the data that we read later c.with_lock() - c.with_other_lock() assert c.a() == ("ok", [4, 5, 6]) assert [c.b(i) for i in range(2)] == [7, 8] @@ -105,7 +98,7 @@ def test_reentrancy_lock(get_contract): def test_allocator_overflow(get_contract): code = """ -x: uint256 +# --> global nonreentrancy slot allocated here <-- y: uint256[max_value(uint256)] """ with pytest.raises( diff --git a/vyper/semantics/analysis/data_positions.py b/vyper/semantics/analysis/data_positions.py index 604bc6b594..bb4322c7b2 100644 --- a/vyper/semantics/analysis/data_positions.py +++ b/vyper/semantics/analysis/data_positions.py @@ -43,10 +43,15 @@ def __setitem__(self, k, v): super().__setitem__(k, v) +# some name that the user cannot assign to a variable +GLOBAL_NONREENTRANT_KEY = "$.nonreentrant_key" + + class SimpleAllocator: def __init__(self, max_slot: int = 2**256, starting_slot: int = 0): # Allocate storage slots from 0 # note storage is word-addressable, not byte-addressable + self._starting_slot = starting_slot self._slot = starting_slot self._max_slot = max_slot @@ -61,12 +66,19 @@ def allocate_slot(self, n, var_name, node=None): self._slot += n return ret + def allocate_global_nonreentrancy_slot(self): + slot = self.allocate_slot(1, GLOBAL_NONREENTRANT_KEY) + assert slot == self._starting_slot + return slot + class Allocators: storage_allocator: SimpleAllocator transient_storage_allocator: SimpleAllocator immutables_allocator: SimpleAllocator + _global_nonreentrancy_key_slot: int + def __init__(self): self.storage_allocator = SimpleAllocator(max_slot=2**256) self.transient_storage_allocator = SimpleAllocator(max_slot=2**256) @@ -82,6 +94,16 @@ def get_allocator(self, location: DataLocation): raise CompilerPanic("unreachable") # pragma: nocover + def allocate_global_nonreentrancy_slot(self): + location = get_reentrancy_key_location() + + allocator = self.get_allocator(location) + slot = allocator.allocate_global_nonreentrancy_slot() + self._global_nonreentrancy_key_slot = slot + + def get_global_nonreentrant_key_slot(self): + return self._global_nonreentrancy_key_slot + class OverridingStorageAllocator: """ @@ -127,7 +149,6 @@ def set_storage_slots_with_overrides( Returns the layout as a dict of variable name -> variable info (Doesn't handle modules, or transient storage) """ - ret: InsertableOnceDict[str, dict] = InsertableOnceDict() reserved_slots = OverridingStorageAllocator() @@ -136,15 +157,13 @@ def set_storage_slots_with_overrides( type_ = node._metadata["func_type"] # Ignore functions without non-reentrant - if type_.nonreentrant is None: + if not type_.nonreentrant: continue - variable_name = f"nonreentrant.{type_.nonreentrant}" + variable_name = GLOBAL_NONREENTRANT_KEY # re-entrant key was already identified if variable_name in ret: - _slot = ret[variable_name]["slot"] - type_.set_reentrancy_key_position(VarOffset(_slot)) continue # Expect to find this variable within the storage layout override @@ -210,6 +229,20 @@ def get_reentrancy_key_location() -> DataLocation: } +def _allocate_nonreentrant_keys(vyper_module, allocators): + SLOT = allocators.get_global_nonreentrant_key_slot() + + for node in vyper_module.get_children(vy_ast.FunctionDef): + type_ = node._metadata["func_type"] + if not type_.nonreentrant: + continue + + # a nonreentrant key can appear many times in a module but it + # only takes one slot. after the first time we see it, do not + # increment the storage slot. + type_.set_reentrancy_key_position(VarOffset(SLOT)) + + def _allocate_layout_r( vyper_module: vy_ast.Module, allocators: Allocators = None, immutables_only=False ) -> StorageLayout: @@ -217,42 +250,26 @@ def _allocate_layout_r( Parse module-level Vyper AST to calculate the layout of storage variables. Returns the layout as a dict of variable name -> variable info """ + global_ = False if allocators is None: + global_ = True allocators = Allocators() + # always allocate nonreentrancy slot, so that adding or removing + # reentrancy protection from a contract does not change its layout + allocators.allocate_global_nonreentrancy_slot() ret: defaultdict[str, InsertableOnceDict[str, dict]] = defaultdict(InsertableOnceDict) - for node in vyper_module.get_children(vy_ast.FunctionDef): - if immutables_only: - break - - type_ = node._metadata["func_type"] - if type_.nonreentrant is None: - continue - - variable_name = f"nonreentrant.{type_.nonreentrant}" - reentrancy_key_location = get_reentrancy_key_location() - layout_key = _LAYOUT_KEYS[reentrancy_key_location] - - # a nonreentrant key can appear many times in a module but it - # only takes one slot. after the first time we see it, do not - # increment the storage slot. - if variable_name in ret[layout_key]: - _slot = ret[layout_key][variable_name]["slot"] - type_.set_reentrancy_key_position(VarOffset(_slot)) - continue - - # TODO use one byte - or bit - per reentrancy key - # requires either an extra SLOAD or caching the value of the - # location in memory at entrance - allocator = allocators.get_allocator(reentrancy_key_location) - slot = allocator.allocate_slot(1, variable_name, node) - - type_.set_reentrancy_key_position(VarOffset(slot)) + # tag functions with the global nonreentrant key + if not immutables_only: + _allocate_nonreentrant_keys(vyper_module, allocators) + layout_key = _LAYOUT_KEYS[get_reentrancy_key_location()] # TODO this could have better typing but leave it untyped until # we nail down the format better - ret[layout_key][variable_name] = {"type": "nonreentrant lock", "slot": slot} + if global_ and GLOBAL_NONREENTRANT_KEY not in ret[layout_key]: + slot = allocators.get_global_nonreentrant_key_slot() + ret[layout_key][GLOBAL_NONREENTRANT_KEY] = {"type": "nonreentrant lock", "slot": slot} for node in _get_allocatable(vyper_module): if isinstance(node, vy_ast.InitializesDecl): diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index 705470a798..43d553288e 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -5,7 +5,6 @@ from typing import Any, Dict, List, Optional, Tuple from vyper import ast as vy_ast -from vyper.ast.identifiers import validate_identifier from vyper.ast.validation import validate_call_args from vyper.exceptions import ( ArgumentException, @@ -78,8 +77,8 @@ class ContractFunctionT(VyperType): enum indicating the external visibility of a function. state_mutability : StateMutability enum indicating the authority a function has to mutate it's own state. - nonreentrant : Optional[str] - Re-entrancy lock name. + nonreentrant : bool + Whether this function is marked `@nonreentrant` or not """ _is_callable = True @@ -93,7 +92,7 @@ def __init__( function_visibility: FunctionVisibility, state_mutability: StateMutability, from_interface: bool = False, - nonreentrant: Optional[str] = None, + nonreentrant: bool = False, ast_def: Optional[vy_ast.VyperNode] = None, ) -> None: super().__init__() @@ -107,6 +106,9 @@ def __init__( self.nonreentrant = nonreentrant self.from_interface = from_interface + # sanity check, nonreentrant used to be Optional[str] + assert isinstance(self.nonreentrant, bool) + self.ast_def = ast_def self._analysed = False @@ -279,7 +281,7 @@ def from_InterfaceDef(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": function_visibility, state_mutability, from_interface=True, - nonreentrant=None, + nonreentrant=False, ast_def=funcdef, ) @@ -298,12 +300,10 @@ def from_vyi(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": ------- ContractFunctionT """ - function_visibility, state_mutability, nonreentrant_key = _parse_decorators(funcdef) + function_visibility, state_mutability, nonreentrant = _parse_decorators(funcdef) - if nonreentrant_key is not None: - raise FunctionDeclarationException( - "nonreentrant key not allowed in interfaces", funcdef - ) + if nonreentrant: + raise FunctionDeclarationException("`@nonreentrant` not allowed in interfaces", funcdef) if funcdef.name == "__init__": raise FunctionDeclarationException("Constructors cannot appear in interfaces", funcdef) @@ -332,7 +332,7 @@ def from_vyi(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": function_visibility, state_mutability, from_interface=True, - nonreentrant=nonreentrant_key, + nonreentrant=nonreentrant, ast_def=funcdef, ) @@ -350,7 +350,7 @@ def from_FunctionDef(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": ------- ContractFunctionT """ - function_visibility, state_mutability, nonreentrant_key = _parse_decorators(funcdef) + function_visibility, state_mutability, nonreentrant = _parse_decorators(funcdef) positional_args, keyword_args = _parse_args(funcdef) @@ -403,15 +403,16 @@ def from_FunctionDef(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": function_visibility, state_mutability, from_interface=False, - nonreentrant=nonreentrant_key, + nonreentrant=nonreentrant, ast_def=funcdef, ) def set_reentrancy_key_position(self, position: VarOffset) -> None: if hasattr(self, "reentrancy_key_position"): raise CompilerPanic("Position was already assigned") - if self.nonreentrant is None: - raise CompilerPanic(f"No reentrant key {self}") + if not self.nonreentrant: + raise CompilerPanic(f"Not nonreentrant {self}", self.ast_def) + self.reentrancy_key_position = position @classmethod @@ -660,32 +661,30 @@ def _parse_return_type(funcdef: vy_ast.FunctionDef) -> Optional[VyperType]: def _parse_decorators( funcdef: vy_ast.FunctionDef, -) -> tuple[Optional[FunctionVisibility], StateMutability, Optional[str]]: +) -> tuple[Optional[FunctionVisibility], StateMutability, bool]: function_visibility = None state_mutability = None - nonreentrant_key = None + nonreentrant_node = None for decorator in funcdef.decorator_list: if isinstance(decorator, vy_ast.Call): - if nonreentrant_key is not None: - raise StructureException( - "nonreentrant decorator is already set with key: " f"{nonreentrant_key}", - funcdef, - ) - - if decorator.get("func.id") != "nonreentrant": - raise StructureException("Decorator is not callable", decorator) - if len(decorator.args) != 1 or not isinstance(decorator.args[0], vy_ast.Str): - raise StructureException( - "@nonreentrant name must be given as a single string literal", decorator - ) + msg = "Decorator is not callable" + hint = None + if decorator.get("func.id") == "nonreentrant": + hint = "use `@nonreentrant` with no arguments. the " + hint += "`@nonreentrant` decorator does not accept any " + hint += "arguments since vyper 0.4.0." + raise StructureException(msg, decorator, hint=hint) + + if decorator.get("id") == "nonreentrant": + if nonreentrant_node is not None: + raise StructureException("nonreentrant decorator is already set", nonreentrant_node) if funcdef.name == "__init__": - msg = "Nonreentrant decorator disallowed on `__init__`" + msg = "`@nonreentrant` decorator disallowed on `__init__`" raise FunctionDeclarationException(msg, decorator) - nonreentrant_key = decorator.args[0].value - validate_identifier(nonreentrant_key, decorator.args[0]) + nonreentrant_node = decorator elif isinstance(decorator, vy_ast.Name): if FunctionVisibility.is_valid_value(decorator.id): @@ -726,12 +725,13 @@ def _parse_decorators( # default to nonpayable state_mutability = StateMutability.NONPAYABLE - if state_mutability == StateMutability.PURE and nonreentrant_key is not None: - raise StructureException("Cannot use reentrancy guard on pure functions", funcdef) + if state_mutability == StateMutability.PURE and nonreentrant_node is not None: + raise StructureException("Cannot use reentrancy guard on pure functions", nonreentrant_node) # assert function_visibility is not None # mypy # assert state_mutability is not None # mypy - return function_visibility, state_mutability, nonreentrant_key + nonreentrant = nonreentrant_node is not None + return function_visibility, state_mutability, nonreentrant def _parse_args( From b3e2fd9c67eb43caaef04ee494368ac618763dc0 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 13 Feb 2024 12:36:07 -0500 Subject: [PATCH 11/12] update a comment --- vyper/semantics/analysis/local.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 882989d776..cefbbb01d9 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -493,7 +493,7 @@ def _analyse_list_iter(self, iter_node, target_type): except (InvalidType, StructureException): raise InvalidType("Not an iterable type", iter_node) - # CMC 2024-02-09 TODO: use validate_expected_type once we have DArrays + # CMC 2024-02-09 TODO: use infer_type once we have DArrays # with generic length. if not isinstance(iter_type, (DArrayT, SArrayT)): raise InvalidType("Not an iterable type", iter_node) From 3fd9fb864100702eb01ed2c45f4cfb5c0ba8e9df Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 13 Feb 2024 12:41:13 -0500 Subject: [PATCH 12/12] improve type inference for revert reason strings --- vyper/semantics/analysis/local.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index cefbbb01d9..d787ba6a41 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -313,17 +313,20 @@ def visit_AnnAssign(self, node): self.expr_visitor.visit(node.target, typ) def _validate_revert_reason(self, msg_node: vy_ast.VyperNode) -> None: + if isinstance(msg_node, vy_ast.Name) and msg_node.id == "UNREACHABLE": + # CMC 2023-10-19 nice to have: tag UNREACHABLE nodes with a special type + return + if isinstance(msg_node, vy_ast.Str): if not msg_node.value.strip(): raise StructureException("Reason string cannot be empty", msg_node) - self.expr_visitor.visit(msg_node, get_exact_type_from_node(msg_node)) - elif not (isinstance(msg_node, vy_ast.Name) and msg_node.id == "UNREACHABLE"): - try: - _ = infer_type(msg_node, StringT(1024)) - except TypeMismatch as e: - raise InvalidType("revert reason must fit within String[1024]") from e - self.expr_visitor.visit(msg_node, get_exact_type_from_node(msg_node)) - # CMC 2023-10-19 nice to have: tag UNREACHABLE nodes with a special type + try: + self.expr_visitor.visit(msg_node, StringT.any()) + except TypeMismatch as e: + # improve the error message + msg = "reason must be a string or the special `UNREACHABLE` value" + raise TypeMismatch(msg, msg_node) from e + def visit_Assert(self, node): if node.msg: