diff --git a/Makefile b/Makefile index 645b800e79..649b381012 100644 --- a/Makefile +++ b/Makefile @@ -17,11 +17,8 @@ dev-init: test: pytest -mypy: - tox -e mypy - lint: - tox -e lint + tox -e lint,mypy docs: rm -f docs/vyper.rst diff --git a/examples/auctions/blind_auction.vy b/examples/auctions/blind_auction.vy index 597aed57c7..966565138f 100644 --- a/examples/auctions/blind_auction.vy +++ b/examples/auctions/blind_auction.vy @@ -1,3 +1,5 @@ +#pragma version >0.3.10 + # Blind Auction. Adapted to Vyper from [Solidity by Example](https://github.com/ethereum/solidity/blob/develop/docs/solidity-by-example.rst#blind-auction-1) struct Bid: @@ -36,7 +38,7 @@ pendingReturns: HashMap[address, uint256] # Create a blinded auction with `_biddingTime` seconds bidding time and # `_revealTime` seconds reveal time on behalf of the beneficiary address # `_beneficiary`. -@external +@deploy def __init__(_beneficiary: address, _biddingTime: uint256, _revealTime: uint256): self.beneficiary = _beneficiary self.biddingEnd = block.timestamp + _biddingTime diff --git a/examples/auctions/simple_open_auction.vy b/examples/auctions/simple_open_auction.vy index 6d5ce06f17..499e12af16 100644 --- a/examples/auctions/simple_open_auction.vy +++ b/examples/auctions/simple_open_auction.vy @@ -1,3 +1,5 @@ +#pragma version >0.3.10 + # Open Auction # Auction params @@ -19,7 +21,7 @@ pendingReturns: public(HashMap[address, uint256]) # Create a simple auction with `_auction_start` and # `_bidding_time` seconds bidding time on behalf of the # beneficiary address `_beneficiary`. -@external +@deploy def __init__(_beneficiary: address, _auction_start: uint256, _bidding_time: uint256): self.beneficiary = _beneficiary self.auctionStart = _auction_start # auction start time can be in the past, present or future diff --git a/examples/crowdfund.vy b/examples/crowdfund.vy index 6d07e15bc4..50ec005924 100644 --- a/examples/crowdfund.vy +++ b/examples/crowdfund.vy @@ -1,3 +1,5 @@ +#pragma version >0.3.10 + ########################################################################### ## THIS IS EXAMPLE CODE, NOT MEANT TO BE USED IN PRODUCTION! CAVEAT EMPTOR! ########################################################################### @@ -11,7 +13,7 @@ goal: public(uint256) timelimit: public(uint256) # Setup global variables -@external +@deploy def __init__(_beneficiary: address, _goal: uint256, _timelimit: uint256): self.beneficiary = _beneficiary self.deadline = block.timestamp + _timelimit diff --git a/examples/factory/Exchange.vy b/examples/factory/Exchange.vy index 77f47984bc..e66c60743a 100644 --- a/examples/factory/Exchange.vy +++ b/examples/factory/Exchange.vy @@ -1,3 +1,5 @@ +#pragma version >0.3.10 + from ethereum.ercs import ERC20 @@ -9,7 +11,7 @@ token: public(ERC20) factory: Factory -@external +@deploy def __init__(_token: ERC20, _factory: Factory): self.token = _token self.factory = _factory diff --git a/examples/factory/Factory.vy b/examples/factory/Factory.vy index 50e7a81bf6..4fec723197 100644 --- a/examples/factory/Factory.vy +++ b/examples/factory/Factory.vy @@ -1,3 +1,5 @@ +#pragma version >0.3.10 + from ethereum.ercs import ERC20 interface Exchange: @@ -11,7 +13,7 @@ exchange_codehash: public(bytes32) exchanges: public(HashMap[ERC20, Exchange]) -@external +@deploy def __init__(_exchange_codehash: bytes32): # Register the exchange code hash during deployment of the factory self.exchange_codehash = _exchange_codehash diff --git a/examples/market_maker/on_chain_market_maker.vy b/examples/market_maker/on_chain_market_maker.vy index 4f9859584c..74b1307dc1 100644 --- a/examples/market_maker/on_chain_market_maker.vy +++ b/examples/market_maker/on_chain_market_maker.vy @@ -1,3 +1,5 @@ +#pragma version >0.3.10 + from ethereum.ercs import ERC20 diff --git a/examples/name_registry/name_registry.vy b/examples/name_registry/name_registry.vy index 7152851dac..937b41856b 100644 --- a/examples/name_registry/name_registry.vy +++ b/examples/name_registry/name_registry.vy @@ -1,3 +1,4 @@ +#pragma version >0.3.10 registry: HashMap[Bytes[100], address] diff --git a/examples/safe_remote_purchase/safe_remote_purchase.vy b/examples/safe_remote_purchase/safe_remote_purchase.vy index edc2163b85..91f0159a2d 100644 --- a/examples/safe_remote_purchase/safe_remote_purchase.vy +++ b/examples/safe_remote_purchase/safe_remote_purchase.vy @@ -1,3 +1,5 @@ +#pragma version >0.3.10 + # Safe Remote Purchase # Originally from # https://github.com/ethereum/solidity/blob/develop/docs/solidity-by-example.rst @@ -19,7 +21,7 @@ buyer: public(address) unlocked: public(bool) ended: public(bool) -@external +@deploy @payable def __init__(): assert (msg.value % 2) == 0 diff --git a/examples/stock/company.vy b/examples/stock/company.vy index 6293e6eea4..355432830d 100644 --- a/examples/stock/company.vy +++ b/examples/stock/company.vy @@ -1,3 +1,5 @@ +#pragma version >0.3.10 + # Financial events the contract logs event Transfer: @@ -27,7 +29,7 @@ price: public(uint256) holdings: HashMap[address, uint256] # Set up the company. -@external +@deploy def __init__(_company: address, _total_shares: uint256, initial_price: uint256): assert _total_shares > 0 assert initial_price > 0 diff --git a/examples/storage/advanced_storage.vy b/examples/storage/advanced_storage.vy index 2ba50280d7..42a455cbf1 100644 --- a/examples/storage/advanced_storage.vy +++ b/examples/storage/advanced_storage.vy @@ -1,10 +1,12 @@ +#pragma version >0.3.10 + event DataChange: setter: indexed(address) value: int128 storedData: public(int128) -@external +@deploy def __init__(_x: int128): self.storedData = _x diff --git a/examples/storage/storage.vy b/examples/storage/storage.vy index 7d05e4708c..30f570f212 100644 --- a/examples/storage/storage.vy +++ b/examples/storage/storage.vy @@ -1,9 +1,11 @@ +#pragma version >0.3.10 + storedData: public(int128) -@external +@deploy def __init__(_x: int128): self.storedData = _x @external def set(_x: int128): - self.storedData = _x \ No newline at end of file + self.storedData = _x diff --git a/examples/tokens/ERC1155ownable.vy b/examples/tokens/ERC1155ownable.vy index d1e88dcd04..d88d459d64 100644 --- a/examples/tokens/ERC1155ownable.vy +++ b/examples/tokens/ERC1155ownable.vy @@ -1,8 +1,9 @@ +#pragma version >0.3.10 + ########################################################################### ## THIS IS EXAMPLE CODE, NOT MEANT TO BE USED IN PRODUCTION! CAVEAT EMPTOR! ########################################################################### -# @version >=0.3.4 """ @dev example implementation of ERC-1155 non-fungible token standard ownable, with approval, OPENSEA compatible (name, symbol) @author Dr. Pixel (github: @Doc-Pixel) @@ -122,7 +123,7 @@ interface IERC1155MetadataURI: ############### functions ############### -@external +@deploy def __init__(name: String[128], symbol: String[16], uri: String[MAX_URI_LENGTH], contractUri: String[MAX_URI_LENGTH]): """ @dev contract initialization on deployment diff --git a/examples/tokens/ERC20.vy b/examples/tokens/ERC20.vy index 77550c3f5a..0e94b32b9d 100644 --- a/examples/tokens/ERC20.vy +++ b/examples/tokens/ERC20.vy @@ -1,3 +1,5 @@ +#pragma version >0.3.10 + ########################################################################### ## THIS IS EXAMPLE CODE, NOT MEANT TO BE USED IN PRODUCTION! CAVEAT EMPTOR! ########################################################################### @@ -38,7 +40,7 @@ totalSupply: public(uint256) minter: address -@external +@deploy def __init__(_name: String[32], _symbol: String[32], _decimals: uint8, _supply: uint256): init_supply: uint256 = _supply * 10 ** convert(_decimals, uint256) self.name = _name diff --git a/examples/tokens/ERC4626.vy b/examples/tokens/ERC4626.vy index 73721fdb98..699b5edd42 100644 --- a/examples/tokens/ERC4626.vy +++ b/examples/tokens/ERC4626.vy @@ -1,3 +1,5 @@ +#pragma version >0.3.10 + # NOTE: Copied from https://github.com/fubuloubu/ERC4626/blob/1a10b051928b11eeaad15d80397ed36603c2a49b/contracts/VyperVault.vy # example implementation of an ERC4626 vault @@ -50,7 +52,7 @@ event Withdraw: shares: uint256 -@external +@deploy def __init__(asset: ERC20): self.asset = asset diff --git a/examples/tokens/ERC721.vy b/examples/tokens/ERC721.vy index d3a8d1f13d..70dff96051 100644 --- a/examples/tokens/ERC721.vy +++ b/examples/tokens/ERC721.vy @@ -1,3 +1,5 @@ +#pragma version >0.3.10 + ########################################################################### ## THIS IS EXAMPLE CODE, NOT MEANT TO BE USED IN PRODUCTION! CAVEAT EMPTOR! ########################################################################### @@ -82,7 +84,7 @@ SUPPORTED_INTERFACES: constant(bytes4[2]) = [ 0x80ac58cd, ] -@external +@deploy def __init__(): """ @dev Contract constructor. diff --git a/examples/voting/ballot.vy b/examples/voting/ballot.vy index 107716accf..daaf712e0f 100644 --- a/examples/voting/ballot.vy +++ b/examples/voting/ballot.vy @@ -1,3 +1,5 @@ +#pragma version >0.3.10 + # Voting with delegation. # Information about voters @@ -50,7 +52,7 @@ def directlyVoted(addr: address) -> bool: # Setup global variables -@external +@deploy def __init__(_proposalNames: bytes32[2]): self.chairperson = msg.sender self.voterCount = 0 diff --git a/examples/wallet/wallet.vy b/examples/wallet/wallet.vy index 231f538ecf..7e92c7e89c 100644 --- a/examples/wallet/wallet.vy +++ b/examples/wallet/wallet.vy @@ -1,3 +1,5 @@ +#pragma version >0.3.10 + ########################################################################### ## THIS IS EXAMPLE CODE, NOT MEANT TO BE USED IN PRODUCTION! CAVEAT EMPTOR! ########################################################################### @@ -12,7 +14,7 @@ threshold: int128 seq: public(int128) -@external +@deploy def __init__(_owners: address[5], _threshold: int128): for i: uint256 in range(5): if _owners[i] != empty(address): diff --git a/tests/functional/builtins/codegen/test_abi.py b/tests/functional/builtins/codegen/test_abi.py index 4ddfcf50c1..335f728a37 100644 --- a/tests/functional/builtins/codegen/test_abi.py +++ b/tests/functional/builtins/codegen/test_abi.py @@ -8,14 +8,14 @@ """ x: int128 -@external +@deploy def __init__(): self.x = 1 """, """ x: int128 -@external +@deploy def __init__(): pass """, diff --git a/tests/functional/builtins/codegen/test_abi_decode.py b/tests/functional/builtins/codegen/test_abi_decode.py index 69bfef63ea..96cbbe4c2d 100644 --- a/tests/functional/builtins/codegen/test_abi_decode.py +++ b/tests/functional/builtins/codegen/test_abi_decode.py @@ -224,7 +224,7 @@ def test_side_effects_evaluation(get_contract): contract_1 = """ counter: uint256 -@external +@deploy def __init__(): self.counter = 0 diff --git a/tests/functional/builtins/codegen/test_abi_encode.py b/tests/functional/builtins/codegen/test_abi_encode.py index f4b7d57a04..8709e31470 100644 --- a/tests/functional/builtins/codegen/test_abi_encode.py +++ b/tests/functional/builtins/codegen/test_abi_encode.py @@ -263,7 +263,7 @@ def test_side_effects_evaluation(get_contract): contract_1 = """ counter: uint256 -@external +@deploy def __init__(): self.counter = 0 diff --git a/tests/functional/builtins/codegen/test_ceil.py b/tests/functional/builtins/codegen/test_ceil.py index daa9cb7c1b..191e2adfef 100644 --- a/tests/functional/builtins/codegen/test_ceil.py +++ b/tests/functional/builtins/codegen/test_ceil.py @@ -6,7 +6,7 @@ def test_ceil(get_contract_with_gas_estimation): code = """ x: decimal -@external +@deploy def __init__(): self.x = 504.0000000001 @@ -53,7 +53,7 @@ def test_ceil_negative(get_contract_with_gas_estimation): code = """ x: decimal -@external +@deploy def __init__(): self.x = -504.0000000001 diff --git a/tests/functional/builtins/codegen/test_concat.py b/tests/functional/builtins/codegen/test_concat.py index 7354515989..37bdaaaf7b 100644 --- a/tests/functional/builtins/codegen/test_concat.py +++ b/tests/functional/builtins/codegen/test_concat.py @@ -79,7 +79,7 @@ def test_concat_buffer2(get_contract): code = """ i: immutable(int256) -@external +@deploy def __init__(): i = -1 s: String[2] = concat("a", "b") @@ -99,7 +99,7 @@ def test_concat_buffer3(get_contract): s2: String[33] s3: String[34] -@external +@deploy def __init__(): self.s = "a" self.s2 = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" # 33*'a' diff --git a/tests/functional/builtins/codegen/test_create_functions.py b/tests/functional/builtins/codegen/test_create_functions.py index afa729ac8a..0aa718157c 100644 --- a/tests/functional/builtins/codegen/test_create_functions.py +++ b/tests/functional/builtins/codegen/test_create_functions.py @@ -214,7 +214,7 @@ def test_create_from_blueprint_bad_code_offset( deployer_code = """ BLUEPRINT: immutable(address) -@external +@deploy def __init__(blueprint_address: address): BLUEPRINT = blueprint_address @@ -269,7 +269,7 @@ def test_create_from_blueprint_args( FOO: immutable(String[128]) BAR: immutable(Bar) -@external +@deploy def __init__(foo: String[128], bar: Bar): FOO = foo BAR = bar @@ -450,7 +450,7 @@ def test_create_from_blueprint_complex_value( code = """ var: uint256 -@external +@deploy @payable def __init__(x: uint256): self.var = x @@ -507,7 +507,7 @@ def test_create_from_blueprint_complex_salt_raw_args( code = """ var: uint256 -@external +@deploy @payable def __init__(x: uint256): self.var = x @@ -565,7 +565,7 @@ def test_create_from_blueprint_complex_salt_no_constructor_args( code = """ var: uint256 -@external +@deploy @payable def __init__(): self.var = 12 diff --git a/tests/functional/builtins/codegen/test_ecrecover.py b/tests/functional/builtins/codegen/test_ecrecover.py index 8571948c3d..ce24868afe 100644 --- a/tests/functional/builtins/codegen/test_ecrecover.py +++ b/tests/functional/builtins/codegen/test_ecrecover.py @@ -68,7 +68,7 @@ def test_invalid_signature2(get_contract): owner: immutable(address) -@external +@deploy def __init__(): owner = 0x7E5F4552091A69125d5DfCb7b8C2659029395Bdf diff --git a/tests/functional/builtins/codegen/test_floor.py b/tests/functional/builtins/codegen/test_floor.py index d2fd993785..5caffd5551 100644 --- a/tests/functional/builtins/codegen/test_floor.py +++ b/tests/functional/builtins/codegen/test_floor.py @@ -6,7 +6,7 @@ def test_floor(get_contract_with_gas_estimation): code = """ x: decimal -@external +@deploy def __init__(): self.x = 504.0000000001 @@ -55,7 +55,7 @@ def test_floor_negative(get_contract_with_gas_estimation): code = """ x: decimal -@external +@deploy def __init__(): self.x = -504.0000000001 diff --git a/tests/functional/builtins/codegen/test_raw_call.py b/tests/functional/builtins/codegen/test_raw_call.py index b30a94502d..e5201e9bb2 100644 --- a/tests/functional/builtins/codegen/test_raw_call.py +++ b/tests/functional/builtins/codegen/test_raw_call.py @@ -137,7 +137,7 @@ def set_owner(i: int128, o: address): owners: public(address[5]) -@external +@deploy def __init__(_owner_setter: address): self.owner_setter_contract = _owner_setter diff --git a/tests/functional/builtins/codegen/test_slice.py b/tests/functional/builtins/codegen/test_slice.py index 80936bbf82..0c5a8fc485 100644 --- a/tests/functional/builtins/codegen/test_slice.py +++ b/tests/functional/builtins/codegen/test_slice.py @@ -57,7 +57,7 @@ def test_slice_immutable( IMMUTABLE_BYTES: immutable(Bytes[{length_bound}]) IMMUTABLE_SLICE: immutable(Bytes[{length_bound}]) -@external +@deploy def __init__(inp: Bytes[{length_bound}], start: uint256, length: uint256): IMMUTABLE_BYTES = inp IMMUTABLE_SLICE = slice(IMMUTABLE_BYTES, {_start}, {_length}) @@ -119,7 +119,7 @@ def test_slice_bytes_fuzz( elif location == "code": preamble = f""" IMMUTABLE_BYTES: immutable(Bytes[{length_bound}]) -@external +@deploy def __init__(foo: Bytes[{length_bound}]): IMMUTABLE_BYTES = foo """ @@ -230,7 +230,7 @@ def test_slice_immutable_length_arg(get_contract_with_gas_estimation): code = """ LENGTH: immutable(uint256) -@external +@deploy def __init__(): LENGTH = 5 @@ -314,7 +314,7 @@ def f() -> bytes32: """ foo: bytes32 -@external +@deploy def __init__(): self.foo = 0x000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f @@ -325,7 +325,7 @@ def bar() -> Bytes[{length}]: """ foo: bytes32 -@external +@deploy def __init__(): self.foo = 0x000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f diff --git a/tests/functional/builtins/folding/test_bitwise.py b/tests/functional/builtins/folding/test_bitwise.py index c1ff7674bb..f63ef8484a 100644 --- a/tests/functional/builtins/folding/test_bitwise.py +++ b/tests/functional/builtins/folding/test_bitwise.py @@ -1,9 +1,9 @@ import pytest -from hypothesis import given, settings +from hypothesis import example, given, settings from hypothesis import strategies as st from tests.utils import parse_and_fold -from vyper.exceptions import InvalidType, OverflowException +from vyper.exceptions import OverflowException, TypeMismatch from vyper.semantics.analysis.utils import validate_expected_type from vyper.semantics.types.shortcuts import INT256_T, UINT256_T from vyper.utils import unsigned_to_signed @@ -66,9 +66,10 @@ def foo(a: uint256, b: uint256) -> uint256: @pytest.mark.fuzzing -@settings(max_examples=50) +@settings(max_examples=51) @pytest.mark.parametrize("op", ["<<", ">>"]) @given(a=st_sint256, b=st.integers(min_value=0, max_value=256)) +@example(a=128, b=248) # throws TypeMismatch def test_bitwise_shift_signed(get_contract, a, b, op): source = f""" @external @@ -84,7 +85,7 @@ def foo(a: int256, b: uint256) -> int256: validate_expected_type(new_node, INT256_T) # force bounds check # compile time behavior does not match runtime behavior. # compile-time will throw on OOB, runtime will wrap. - except (InvalidType, OverflowException): + except (TypeMismatch, OverflowException): # check the wrapped value matches runtime assert op == "<<" assert contract.foo(a, b) == unsigned_to_signed((a << b) % (2**256), 256) diff --git a/tests/functional/codegen/calling_convention/test_default_function.py b/tests/functional/codegen/calling_convention/test_default_function.py index cf55607877..411f38eac9 100644 --- a/tests/functional/codegen/calling_convention/test_default_function.py +++ b/tests/functional/codegen/calling_convention/test_default_function.py @@ -2,7 +2,7 @@ def test_throw_on_sending(w3, tx_failed, get_contract_with_gas_estimation): code = """ x: public(int128) -@external +@deploy def __init__(): self.x = 123 """ diff --git a/tests/functional/codegen/calling_convention/test_erc20_abi.py b/tests/functional/codegen/calling_convention/test_erc20_abi.py index b9dc5c663f..59c4131fb2 100644 --- a/tests/functional/codegen/calling_convention/test_erc20_abi.py +++ b/tests/functional/codegen/calling_convention/test_erc20_abi.py @@ -33,7 +33,7 @@ def allowance(_owner: address, _spender: address) -> uint256: nonpayable token_address: ERC20Contract -@external +@deploy def __init__(token_addr: address): self.token_address = ERC20Contract(token_addr) diff --git a/tests/functional/codegen/calling_convention/test_external_contract_calls.py b/tests/functional/codegen/calling_convention/test_external_contract_calls.py index a7cf4d0ecf..8b3f30b5a5 100644 --- a/tests/functional/codegen/calling_convention/test_external_contract_calls.py +++ b/tests/functional/codegen/calling_convention/test_external_contract_calls.py @@ -41,7 +41,7 @@ def test_complicated_external_contract_calls(get_contract, get_contract_with_gas contract_1 = """ lucky: public(int128) -@external +@deploy def __init__(_lucky: int128): self.lucky = _lucky @@ -898,26 +898,31 @@ def set_lucky(arg1: address, arg2: int128): print("Successfully executed an external contract call state change") -def test_constant_external_contract_call_cannot_change_state( - assert_compile_failed, get_contract_with_gas_estimation -): +def test_constant_external_contract_call_cannot_change_state(): c = """ interface Foo: def set_lucky(_lucky: int128) -> int128: nonpayable @external @view -def set_lucky_expr(arg1: address, arg2: int128): +def set_lucky_stmt(arg1: address, arg2: int128): Foo(arg1).set_lucky(arg2) + """ + with pytest.raises(StateAccessViolation): + compile_code(c) + + c2 = """ +interface Foo: + def set_lucky(_lucky: int128) -> int128: nonpayable @external @view -def set_lucky_stmt(arg1: address, arg2: int128) -> int128: +def set_lucky_expr(arg1: address, arg2: int128) -> int128: return Foo(arg1).set_lucky(arg2) """ - assert_compile_failed(lambda: get_contract_with_gas_estimation(c), StateAccessViolation) - print("Successfully blocked an external contract call from a constant function") + with pytest.raises(StateAccessViolation): + compile_code(c2) def test_external_contract_can_be_changed_based_on_address(get_contract): @@ -968,7 +973,7 @@ def test_external_contract_calls_with_public_globals(get_contract): contract_1 = """ lucky: public(int128) -@external +@deploy def __init__(_lucky: int128): self.lucky = _lucky """ @@ -994,7 +999,7 @@ def test_external_contract_calls_with_multiple_contracts(get_contract): contract_1 = """ lucky: public(int128) -@external +@deploy def __init__(_lucky: int128): self.lucky = _lucky """ @@ -1008,7 +1013,7 @@ def lucky() -> int128: view magic_number: public(int128) -@external +@deploy def __init__(arg1: address): self.magic_number = Foo(arg1).lucky() """ @@ -1020,7 +1025,7 @@ def magic_number() -> int128: view best_number: public(int128) -@external +@deploy def __init__(arg1: address): self.best_number = Bar(arg1).magic_number() """ @@ -1145,7 +1150,7 @@ def test_invalid_contract_reference_declaration(tx_failed, get_contract): best_number: public(int128) -@external +@deploy def __init__(): pass """ diff --git a/tests/functional/codegen/calling_convention/test_modifiable_external_contract_calls.py b/tests/functional/codegen/calling_convention/test_modifiable_external_contract_calls.py index e6b2402016..aa7130fd6a 100644 --- a/tests/functional/codegen/calling_convention/test_modifiable_external_contract_calls.py +++ b/tests/functional/codegen/calling_convention/test_modifiable_external_contract_calls.py @@ -20,7 +20,7 @@ def set_lucky(_lucky: int128): view modifiable_bar_contract: ModBar static_bar_contract: ConstBar -@external +@deploy def __init__(contract_address: address): self.modifiable_bar_contract = ModBar(contract_address) self.static_bar_contract = ConstBar(contract_address) @@ -64,7 +64,7 @@ def set_lucky(_lucky: int128) -> int128: view modifiable_bar_contract: ModBar static_bar_contract: ConstBar -@external +@deploy def __init__(contract_address: address): self.modifiable_bar_contract = ModBar(contract_address) self.static_bar_contract = ConstBar(contract_address) @@ -108,7 +108,7 @@ def set_lucky(_lucky: int128): view modifiable_bar_contract: ModBar static_bar_contract: ConstBar -@external +@deploy def __init__(contract_address: address): self.modifiable_bar_contract = ModBar(contract_address) self.static_bar_contract = ConstBar(contract_address) @@ -134,7 +134,7 @@ def static_set_lucky(_lucky: int128): view modifiable_bar_contract: ModBar static_bar_contract: ConstBar -@external +@deploy def __init__(contract_address: address): self.modifiable_bar_contract = ModBar(contract_address) self.static_bar_contract = ConstBar(contract_address) diff --git a/tests/functional/codegen/calling_convention/test_return_tuple.py b/tests/functional/codegen/calling_convention/test_return_tuple.py index 266555ead6..74929c9496 100644 --- a/tests/functional/codegen/calling_convention/test_return_tuple.py +++ b/tests/functional/codegen/calling_convention/test_return_tuple.py @@ -16,7 +16,7 @@ def test_return_type(get_contract_with_gas_estimation): c: int128 chunk: Chunk -@external +@deploy def __init__(): self.chunk.a = b"hello" self.chunk.b = b"world" diff --git a/tests/functional/codegen/features/decorators/test_payable.py b/tests/functional/codegen/features/decorators/test_payable.py index ced58e1af0..955501a0e6 100644 --- a/tests/functional/codegen/features/decorators/test_payable.py +++ b/tests/functional/codegen/features/decorators/test_payable.py @@ -122,7 +122,7 @@ def bar() -> bool: """, """ # payable init function -@external +@deploy @payable def __init__(): a: int128 = 1 @@ -279,7 +279,7 @@ def baz() -> bool: """, """ # init function -@external +@deploy def __init__(): a: int128 = 1 diff --git a/tests/functional/codegen/features/decorators/test_private.py b/tests/functional/codegen/features/decorators/test_private.py index 39ea1bb9ae..193112f02b 100644 --- a/tests/functional/codegen/features/decorators/test_private.py +++ b/tests/functional/codegen/features/decorators/test_private.py @@ -120,7 +120,7 @@ def test_private_bytes(get_contract_with_gas_estimation): private_test_code = """ greeting: public(Bytes[100]) -@external +@deploy def __init__(): self.greeting = b"Hello " @@ -143,7 +143,7 @@ def test_private_statement(get_contract_with_gas_estimation): private_test_code = """ greeting: public(Bytes[20]) -@external +@deploy def __init__(): self.greeting = b"Hello " diff --git a/tests/functional/codegen/features/iteration/test_for_in_list.py b/tests/functional/codegen/features/iteration/test_for_in_list.py index 36252701c4..e1bd8f313d 100644 --- a/tests/functional/codegen/features/iteration/test_for_in_list.py +++ b/tests/functional/codegen/features/iteration/test_for_in_list.py @@ -3,6 +3,7 @@ import pytest +from vyper.compiler import compile_code from vyper.exceptions import ( ArgumentException, ImmutableViolation, @@ -841,6 +842,59 @@ def foo(): ] +# TODO: move these to tests/functional/syntax @pytest.mark.parametrize("code,err", BAD_CODE, ids=bad_code_names) def test_bad_code(assert_compile_failed, get_contract, code, err): - assert_compile_failed(lambda: get_contract(code), err) + with pytest.raises(err): + compile_code(code) + + +def test_iterator_modification_module_attribute(make_input_bundle): + # test modifying iterator via attribute + lib1 = """ +queue: DynArray[uint256, 5] + """ + main = """ +import lib1 + +initializes: lib1 + +@external +def foo(): + for i: uint256 in lib1.queue: + lib1.queue.pop() + """ + + input_bundle = make_input_bundle({"lib1.vy": lib1}) + + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "Cannot modify loop variable `queue`" + + +def test_iterator_modification_module_function_call(make_input_bundle): + lib1 = """ +queue: DynArray[uint256, 5] + +@internal +def popqueue(): + self.queue.pop() + """ + main = """ +import lib1 + +initializes: lib1 + +@external +def foo(): + for i: uint256 in lib1.queue: + lib1.popqueue() + """ + + input_bundle = make_input_bundle({"lib1.vy": lib1}) + + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "Cannot modify loop variable `queue`" diff --git a/tests/functional/codegen/features/iteration/test_range_in.py b/tests/functional/codegen/features/iteration/test_range_in.py index 7540049778..f381f60b35 100644 --- a/tests/functional/codegen/features/iteration/test_range_in.py +++ b/tests/functional/codegen/features/iteration/test_range_in.py @@ -115,7 +115,7 @@ def test_ownership(w3, tx_failed, get_contract_with_gas_estimation): owners: address[2] -@external +@deploy def __init__(): self.owners[0] = msg.sender diff --git a/tests/functional/codegen/features/test_bytes_map_keys.py b/tests/functional/codegen/features/test_bytes_map_keys.py index 4913182d52..22df767f02 100644 --- a/tests/functional/codegen/features/test_bytes_map_keys.py +++ b/tests/functional/codegen/features/test_bytes_map_keys.py @@ -80,7 +80,7 @@ def test_extended_bytes_key_from_storage(get_contract): code = """ a: HashMap[Bytes[100000], int128] -@external +@deploy def __init__(): self.a[b"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"] = 1069 @@ -114,7 +114,7 @@ def test_struct_bytes_key_memory(get_contract): a: HashMap[Bytes[100000], int128] -@external +@deploy def __init__(): self.a[b"hello"] = 1069 self.a[b"potato"] = 31337 @@ -145,7 +145,7 @@ def test_struct_bytes_key_storage(get_contract): a: HashMap[Bytes[100000], int128] b: Foo -@external +@deploy def __init__(): self.a[b"hello"] = 1069 self.a[b"potato"] = 31337 @@ -172,7 +172,7 @@ def test_bytes_key_storage(get_contract): a: HashMap[Bytes[100000], int128] b: Bytes[5] -@external +@deploy def __init__(): self.a[b"hello"] = 1069 self.b = b"hello" @@ -193,7 +193,7 @@ def test_bytes_key_calldata(get_contract): a: HashMap[Bytes[100000], int128] -@external +@deploy def __init__(): self.a[b"hello"] = 1069 @@ -215,7 +215,7 @@ def test_struct_bytes_hashmap_as_key_in_other_hashmap(get_contract): bar: public(HashMap[uint256, Thing]) foo: public(HashMap[Bytes[64], uint256]) -@external +@deploy def __init__(): self.foo[b"hello"] = 31337 self.bar[12] = Thing({name: b"hello"}) diff --git a/tests/functional/codegen/features/test_clampers.py b/tests/functional/codegen/features/test_clampers.py index 6db8570fc7..c028805c6a 100644 --- a/tests/functional/codegen/features/test_clampers.py +++ b/tests/functional/codegen/features/test_clampers.py @@ -67,7 +67,7 @@ def test_bytes_clamper_on_init(tx_failed, get_contract_with_gas_estimation): clamper_test_code = """ foo: Bytes[3] -@external +@deploy def __init__(x: Bytes[3]): self.foo = x diff --git a/tests/functional/codegen/features/test_constructor.py b/tests/functional/codegen/features/test_constructor.py index c9dfcfc5df..9146ace8a6 100644 --- a/tests/functional/codegen/features/test_constructor.py +++ b/tests/functional/codegen/features/test_constructor.py @@ -6,7 +6,7 @@ def test_init_argument_test(get_contract_with_gas_estimation): init_argument_test = """ moose: int128 -@external +@deploy def __init__(_moose: int128): self.moose = _moose @@ -26,7 +26,7 @@ def test_constructor_mapping(get_contract_with_gas_estimation): X: constant(bytes4) = 0x01ffc9a7 -@external +@deploy def __init__(): self.foo[X] = True @@ -44,7 +44,7 @@ def test_constructor_advanced_code(get_contract_with_gas_estimation): constructor_advanced_code = """ twox: int128 -@external +@deploy def __init__(x: int128): self.twox = x * 2 @@ -60,7 +60,7 @@ def test_constructor_advanced_code2(get_contract_with_gas_estimation): constructor_advanced_code2 = """ comb: uint256 -@external +@deploy def __init__(x: uint256[2], y: Bytes[3], z: uint256): self.comb = x[0] * 1000 + x[1] * 100 + len(y) * 10 + z @@ -90,7 +90,7 @@ def foo(x: int128) -> int128: def test_large_input_code_2(w3, get_contract_with_gas_estimation): large_input_code_2 = """ -@external +@deploy def __init__(x: int128): y: int128 = x @@ -113,7 +113,7 @@ def test_initialise_array_with_constant_key(get_contract_with_gas_estimation): foo: int16[X] -@external +@deploy def __init__(): self.foo[X-1] = -2 @@ -133,7 +133,7 @@ def test_initialise_dynarray_with_constant_key(get_contract_with_gas_estimation) foo: DynArray[int16, X] -@external +@deploy def __init__(): self.foo = [X - 3, X - 4, X - 5, X - 6] @@ -151,7 +151,7 @@ def test_nested_dynamic_array_constructor_arg(w3, get_contract_with_gas_estimati code = """ foo: uint256 -@external +@deploy def __init__(x: DynArray[DynArray[uint256, 3], 3]): self.foo = x[0][2] + x[1][1] + x[2][0] @@ -167,7 +167,7 @@ def test_nested_dynamic_array_constructor_arg_2(w3, get_contract_with_gas_estima code = """ foo: int128 -@external +@deploy def __init__(x: DynArray[DynArray[DynArray[int128, 3], 3], 3]): self.foo = x[0][1][2] * x[1][1][1] * x[2][1][0] - x[0][0][0] - x[1][1][1] - x[2][2][2] @@ -192,7 +192,7 @@ def test_initialise_nested_dynamic_array(w3, get_contract_with_gas_estimation): code = """ foo: DynArray[DynArray[uint256, 3], 3] -@external +@deploy def __init__(x: uint256, y: uint256, z: uint256): self.foo = [ [x, y, z], @@ -212,7 +212,7 @@ def test_initialise_nested_dynamic_array_2(w3, get_contract_with_gas_estimation) code = """ foo: DynArray[DynArray[DynArray[int128, 3], 3], 3] -@external +@deploy def __init__(x: int128, y: int128, z: int128): self.foo = [ [[x, y, z], [y, z, x], [z, y, x]], diff --git a/tests/functional/codegen/features/test_immutable.py b/tests/functional/codegen/features/test_immutable.py index 47f7fc748e..d0bc47c238 100644 --- a/tests/functional/codegen/features/test_immutable.py +++ b/tests/functional/codegen/features/test_immutable.py @@ -20,7 +20,7 @@ def test_value_storage_retrieval(typ, value, get_contract): code = f""" VALUE: immutable({typ}) -@external +@deploy def __init__(_value: {typ}): VALUE = _value @@ -41,7 +41,7 @@ def test_usage_in_constructor(get_contract, val): a: public(uint256) -@external +@deploy def __init__(_a: uint256): A = _a self.a = A @@ -63,7 +63,7 @@ def test_multiple_immutable_values(get_contract): b: immutable(address) c: immutable(String[64]) -@external +@deploy def __init__(_a: uint256, _b: address, _c: String[64]): a = _a b = _b @@ -89,7 +89,7 @@ def test_struct_immutable(get_contract): my_struct: immutable(MyStruct) -@external +@deploy def __init__(_a: uint256, _b: uint256, _c: address, _d: int256): my_struct = MyStruct({ a: _a, @@ -108,11 +108,34 @@ def get_my_struct() -> MyStruct: assert c.get_my_struct() == values +def test_complex_immutable_modifiable(get_contract): + code = """ +struct MyStruct: + a: uint256 + +my_struct: immutable(MyStruct) + +@deploy +def __init__(a: uint256): + my_struct = MyStruct({a: a}) + + # struct members are modifiable after initialization + my_struct.a += 1 + +@view +@external +def get_my_struct() -> MyStruct: + return my_struct + """ + c = get_contract(code, 1) + assert c.get_my_struct() == (2,) + + def test_list_immutable(get_contract): code = """ my_list: immutable(uint256[3]) -@external +@deploy def __init__(_a: uint256, _b: uint256, _c: uint256): my_list = [_a, _b, _c] @@ -130,7 +153,7 @@ def test_dynarray_immutable(get_contract): code = """ my_list: immutable(DynArray[uint256, 3]) -@external +@deploy def __init__(_a: uint256, _b: uint256, _c: uint256): my_list = [_a, _b, _c] @@ -154,7 +177,7 @@ def test_nested_dynarray_immutable_2(get_contract): code = """ my_list: immutable(DynArray[DynArray[uint256, 3], 3]) -@external +@deploy def __init__(_a: uint256, _b: uint256, _c: uint256): my_list = [[_a, _b, _c], [_b, _a, _c], [_c, _b, _a]] @@ -179,7 +202,7 @@ def test_nested_dynarray_immutable(get_contract): code = """ my_list: immutable(DynArray[DynArray[DynArray[int128, 3], 3], 3]) -@external +@deploy def __init__(x: int128, y: int128, z: int128): my_list = [ [[x, y, z], [y, z, x], [z, y, x]], @@ -227,7 +250,7 @@ def foo() -> uint256: counter: uint256 VALUE: immutable(uint256) -@external +@deploy def __init__(x: uint256): self.counter = x self.foo() @@ -257,7 +280,7 @@ def foo() -> uint256: b: public(uint256) @payable -@external +@deploy def __init__(to_copy: address): c: address = create_copy_of(to_copy) self.b = a @@ -281,7 +304,7 @@ def test_immutables_initialized2(get_contract, get_contract_from_ir): b: public(uint256) @payable -@external +@deploy def __init__(to_copy: address): c: address = create_copy_of(to_copy) self.b = a @@ -299,7 +322,7 @@ def test_internal_functions_called_by_ctor_location(get_contract): d: uint256 x: immutable(uint256) -@external +@deploy def __init__(): self.d = 1 x = 2 @@ -323,7 +346,7 @@ def test_nested_internal_function_immutables(get_contract): d: public(uint256) x: public(immutable(uint256)) -@external +@deploy def __init__(): self.d = 1 x = 2 @@ -348,7 +371,7 @@ def test_immutable_read_ctor_and_runtime(get_contract): d: public(uint256) x: public(immutable(uint256)) -@external +@deploy def __init__(): self.d = 1 x = 2 diff --git a/tests/functional/codegen/features/test_init.py b/tests/functional/codegen/features/test_init.py index fc765f8ab3..84d224f632 100644 --- a/tests/functional/codegen/features/test_init.py +++ b/tests/functional/codegen/features/test_init.py @@ -5,7 +5,7 @@ def test_basic_init_function(get_contract): code = """ val: public(uint256) -@external +@deploy def __init__(a: uint256): self.val = a """ @@ -27,10 +27,12 @@ def __init__(a: uint256): def test_init_calls_internal(get_contract, assert_compile_failed, tx_failed): code = """ foo: public(uint8) + @internal def bar(x: uint256) -> uint8: return convert(x, uint8) * 7 -@external + +@deploy def __init__(a: uint256): self.foo = self.bar(a) @@ -61,7 +63,7 @@ def test_nested_internal_call_from_ctor(get_contract): code = """ x: uint256 -@external +@deploy def __init__(): self.a() diff --git a/tests/functional/codegen/features/test_logging.py b/tests/functional/codegen/features/test_logging.py index 0cb8ad9abc..8b80811d02 100644 --- a/tests/functional/codegen/features/test_logging.py +++ b/tests/functional/codegen/features/test_logging.py @@ -646,7 +646,7 @@ def test_logging_fails_with_over_three_topics(tx_failed, get_contract_with_gas_e arg3: indexed(int128) arg4: indexed(int128) -@external +@deploy def __init__(): log MyLog(1, 2, 3, 4) """ @@ -1033,7 +1033,7 @@ def test_mixed_var_list_packing(get_logs, get_contract_with_gas_estimation): x: int128[4] y: int128[2] -@external +@deploy def __init__(): self.y = [1024, 2048] diff --git a/tests/functional/codegen/features/test_ternary.py b/tests/functional/codegen/features/test_ternary.py index c5480286c8..661fdc86c9 100644 --- a/tests/functional/codegen/features/test_ternary.py +++ b/tests/functional/codegen/features/test_ternary.py @@ -195,7 +195,7 @@ def test_ternary_tuple(get_contract, code, test): def test_ternary_immutable(get_contract, test): code = """ IMM: public(immutable(uint256)) -@external +@deploy def __init__(test: bool): IMM = 1 if test else 2 """ diff --git a/tests/functional/codegen/integration/test_crowdfund.py b/tests/functional/codegen/integration/test_crowdfund.py index 891ed5aebe..1a8b3f7e9f 100644 --- a/tests/functional/codegen/integration/test_crowdfund.py +++ b/tests/functional/codegen/integration/test_crowdfund.py @@ -13,7 +13,7 @@ def test_crowdfund(w3, tester, get_contract_with_gas_estimation_for_constants): refundIndex: int128 timelimit: public(uint256) -@external +@deploy def __init__(_beneficiary: address, _goal: uint256, _timelimit: uint256): self.beneficiary = _beneficiary self.deadline = block.timestamp + _timelimit @@ -109,7 +109,7 @@ def test_crowdfund2(w3, tester, get_contract_with_gas_estimation_for_constants): refundIndex: int128 timelimit: public(uint256) -@external +@deploy def __init__(_beneficiary: address, _goal: uint256, _timelimit: uint256): self.beneficiary = _beneficiary self.deadline = block.timestamp + _timelimit diff --git a/tests/functional/codegen/integration/test_escrow.py b/tests/functional/codegen/integration/test_escrow.py index 70e7cb4594..f86b4aa516 100644 --- a/tests/functional/codegen/integration/test_escrow.py +++ b/tests/functional/codegen/integration/test_escrow.py @@ -41,7 +41,7 @@ def test_arbitration_code_with_init(w3, tx_failed, get_contract_with_gas_estimat seller: address arbitrator: address -@external +@deploy @payable def __init__(_seller: address, _arbitrator: address): if self.buyer == empty(address): diff --git a/tests/functional/codegen/modules/test_module_constants.py b/tests/functional/codegen/modules/test_module_constants.py index aafbb69252..ebfefb4546 100644 --- a/tests/functional/codegen/modules/test_module_constants.py +++ b/tests/functional/codegen/modules/test_module_constants.py @@ -76,3 +76,23 @@ def foo(ix: uint256) -> uint256: assert c.foo(2) == 3 with tx_failed(): c.foo(3) + + +def test_module_constant_builtin(make_input_bundle, get_contract): + # test empty builtin, which is not (currently) foldable 2024-02-06 + mod1 = """ +X: constant(uint256) = empty(uint256) + """ + contract = """ +import mod1 + +@external +def foo() -> uint256: + return mod1.X + """ + + input_bundle = make_input_bundle({"mod1.vy": mod1}) + + c = get_contract(contract, input_bundle=input_bundle) + + assert c.foo() == 0 diff --git a/tests/functional/codegen/modules/test_module_variables.py b/tests/functional/codegen/modules/test_module_variables.py new file mode 100644 index 0000000000..6bb1f9072c --- /dev/null +++ b/tests/functional/codegen/modules/test_module_variables.py @@ -0,0 +1,318 @@ +def test_simple_import(get_contract, make_input_bundle): + lib1 = """ +counter: uint256 + +@internal +def increment_counter(): + self.counter += 1 + """ + + contract = """ +import lib + +initializes: lib + +@external +def increment_counter(): + lib.increment_counter() + +@external +def get_counter() -> uint256: + return lib.counter + """ + + input_bundle = make_input_bundle({"lib.vy": lib1}) + + c = get_contract(contract, input_bundle=input_bundle) + + assert c.get_counter() == 0 + c.increment_counter(transact={}) + assert c.get_counter() == 1 + + +def test_import_namespace(get_contract, make_input_bundle): + # test what happens when things in current and imported modules share names + lib = """ +counter: uint256 + +@internal +def increment_counter(): + self.counter += 1 + """ + + contract = """ +import library as lib + +counter: uint256 + +initializes: lib + +@external +def increment_counter(): + self.counter += 1 + +@external +def increment_lib_counter(): + lib.increment_counter() + +@external +def increment_lib_counter2(): + # modify lib.counter directly + lib.counter += 5 + +@external +def get_counter() -> uint256: + return self.counter + +@external +def get_lib_counter() -> uint256: + return lib.counter + """ + + input_bundle = make_input_bundle({"library.vy": lib}) + + c = get_contract(contract, input_bundle=input_bundle) + + assert c.get_counter() == c.get_lib_counter() == 0 + + c.increment_counter(transact={}) + assert c.get_counter() == 1 + assert c.get_lib_counter() == 0 + + c.increment_lib_counter(transact={}) + assert c.get_lib_counter() == 1 + assert c.get_counter() == 1 + + c.increment_lib_counter2(transact={}) + assert c.get_lib_counter() == 6 + assert c.get_counter() == 1 + + +def test_init_function_side_effects(get_contract, make_input_bundle): + lib = """ +counter: uint256 + +MY_IMMUTABLE: immutable(uint256) + +@deploy +def __init__(initial_value: uint256): + self.counter = initial_value + MY_IMMUTABLE = initial_value * 2 + +@internal +def increment_counter(): + self.counter += 1 + """ + + contract = """ +import library as lib + +counter: public(uint256) + +MY_IMMUTABLE: public(immutable(uint256)) + +initializes: lib + +@deploy +def __init__(): + self.counter = 1 + MY_IMMUTABLE = 3 + lib.__init__(5) + +@external +def get_lib_counter() -> uint256: + return lib.counter + +@external +def get_lib_immutable() -> uint256: + return lib.MY_IMMUTABLE + """ + + input_bundle = make_input_bundle({"library.vy": lib}) + + c = get_contract(contract, input_bundle=input_bundle) + + assert c.counter() == 1 + assert c.MY_IMMUTABLE() == 3 + assert c.get_lib_counter() == 5 + assert c.get_lib_immutable() == 10 + + +def test_indirect_variable_uses(get_contract, make_input_bundle): + lib1 = """ +counter: uint256 + +MY_IMMUTABLE: immutable(uint256) + +@deploy +def __init__(initial_value: uint256): + self.counter = initial_value + MY_IMMUTABLE = initial_value * 2 + +@internal +def increment_counter(): + self.counter += 1 + """ + lib2 = """ +import lib1 + +uses: lib1 + +@internal +def get_lib1_counter() -> uint256: + return lib1.counter + +@internal +def get_lib1_my_immutable() -> uint256: + return lib1.MY_IMMUTABLE + """ + + contract = """ +import lib1 +import lib2 + +initializes: lib1 +initializes: lib2[lib1 := lib1] + +@deploy +def __init__(): + lib1.__init__(5) + +@external +def get_storage_via_lib1() -> uint256: + return lib1.counter + +@external +def get_immutable_via_lib1() -> uint256: + return lib1.MY_IMMUTABLE + +@external +def get_storage_via_lib2() -> uint256: + return lib2.get_lib1_counter() + +@external +def get_immutable_via_lib2() -> uint256: + return lib2.get_lib1_my_immutable() + """ + + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + c = get_contract(contract, input_bundle=input_bundle) + + assert c.get_storage_via_lib1() == c.get_storage_via_lib2() == 5 + assert c.get_immutable_via_lib1() == c.get_immutable_via_lib2() == 10 + + +def test_uses_already_initialized(get_contract, make_input_bundle): + lib1 = """ +counter: uint256 +MY_IMMUTABLE: immutable(uint256) + +@deploy +def __init__(initial_value: uint256): + self.counter = initial_value * 2 + MY_IMMUTABLE = initial_value * 3 + +@internal +def increment_counter(): + self.counter += 1 + """ + lib2 = """ +import lib1 + +initializes: lib1 + +@deploy +def __init__(): + lib1.__init__(5) + +@internal +def get_lib1_counter() -> uint256: + return lib1.counter + +@internal +def get_lib1_my_immutable() -> uint256: + return lib1.MY_IMMUTABLE + """ + + contract = """ +import lib1 +import lib2 + +uses: lib1 +initializes: lib2 + +@deploy +def __init__(): + lib2.__init__() + +@external +def get_storage_via_lib1() -> uint256: + return lib1.counter + +@external +def get_immutable_via_lib1() -> uint256: + return lib1.MY_IMMUTABLE + +@external +def get_storage_via_lib2() -> uint256: + return lib2.get_lib1_counter() + +@external +def get_immutable_via_lib2() -> uint256: + return lib2.get_lib1_my_immutable() + """ + + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + c = get_contract(contract, input_bundle=input_bundle) + + assert c.get_storage_via_lib1() == c.get_storage_via_lib2() == 10 + assert c.get_immutable_via_lib1() == c.get_immutable_via_lib2() == 15 + + +def test_import_complex_types(get_contract, make_input_bundle): + lib1 = """ +an_array: uint256[3] +a_hashmap: HashMap[address, HashMap[uint256, uint256]] + +@internal +def set_array_value(ix: uint256, new_value: uint256): + self.an_array[ix] = new_value + +@internal +def set_hashmap_value(ix0: address, ix1: uint256, new_value: uint256): + self.a_hashmap[ix0][ix1] = new_value + """ + + contract = """ +import lib + +initializes: lib + +@external +def do_things(): + lib.set_array_value(1, 5) + lib.set_hashmap_value(msg.sender, 6, 100) + +@external +def get_array_value(ix: uint256) -> uint256: + return lib.an_array[ix] + +@external +def get_hashmap_value(ix: uint256) -> uint256: + return lib.a_hashmap[msg.sender][ix] + """ + + input_bundle = make_input_bundle({"lib.vy": lib1}) + + c = get_contract(contract, input_bundle=input_bundle) + + assert c.get_array_value(0) == 0 + assert c.get_hashmap_value(0) == 0 + c.do_things(transact={}) + + assert c.get_array_value(0) == 0 + assert c.get_hashmap_value(0) == 0 + assert c.get_array_value(1) == 5 + assert c.get_hashmap_value(6) == 100 diff --git a/tests/functional/codegen/storage_variables/test_getters.py b/tests/functional/codegen/storage_variables/test_getters.py index a2d9c6d0bb..9e72bed075 100644 --- a/tests/functional/codegen/storage_variables/test_getters.py +++ b/tests/functional/codegen/storage_variables/test_getters.py @@ -41,7 +41,7 @@ def foo(): nonpayable f: public(constant(uint256[2])) = [3, 7] g: public(constant(V)) = V(0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF) -@external +@deploy def __init__(): self.x = as_wei_value(7, "wei") self.y[1] = 9 @@ -87,7 +87,7 @@ def test_getter_mutability(get_contract): nyoro: public(constant(uint256)) = 2 kune: public(immutable(uint256)) -@external +@deploy def __init__(): kune = 2 """ diff --git a/tests/functional/codegen/storage_variables/test_storage_variable.py b/tests/functional/codegen/storage_variables/test_storage_variable.py index 4636fa77e0..7a22d35e4b 100644 --- a/tests/functional/codegen/storage_variables/test_storage_variable.py +++ b/tests/functional/codegen/storage_variables/test_storage_variable.py @@ -10,7 +10,7 @@ def test_permanent_variables_test(get_contract_with_gas_estimation): b: int128 var: Var -@external +@deploy def __init__(a: int128, b: int128): self.var.a = a self.var.b = b diff --git a/tests/functional/codegen/test_interfaces.py b/tests/functional/codegen/test_interfaces.py index 3344ff113b..85efe904a0 100644 --- a/tests/functional/codegen/test_interfaces.py +++ b/tests/functional/codegen/test_interfaces.py @@ -305,7 +305,7 @@ def test() -> uint256: view token_address: IToken -@external +@deploy def __init__(_token_address: address): self.token_address = IToken(_token_address) @@ -388,7 +388,7 @@ def transfer(to: address, amount: uint256) -> bool: token_address: ERC20 -@external +@deploy def __init__(_token_address: address): self.token_address = ERC20(_token_address) @@ -445,7 +445,7 @@ def should_fail() -> {typ}: view foo: BadContract -@external +@deploy def __init__(addr: BadContract): self.foo = addr @@ -501,7 +501,7 @@ def should_fail() -> Bytes[2]: view foo: BadContract -@external +@deploy def __init__(addr: BadContract): self.foo = addr @@ -551,7 +551,7 @@ def foo(x: BadJSONInterface) -> Bytes[2]: foo: BadJSONInterface -@external +@deploy def __init__(addr: BadJSONInterface): self.foo = addr @@ -667,7 +667,7 @@ def foo() -> uint256: view bar_contract: Bar -@external +@deploy def __init__(): self.bar_contract = Bar(self) diff --git a/tests/functional/codegen/types/test_bytes.py b/tests/functional/codegen/types/test_bytes.py index 325f9d7923..99e5835f6e 100644 --- a/tests/functional/codegen/types/test_bytes.py +++ b/tests/functional/codegen/types/test_bytes.py @@ -51,7 +51,7 @@ def test_test_bytes3(get_contract_with_gas_estimation): maa: Bytes[60] y: int128 -@external +@deploy def __init__(): self.x = 27 self.y = 37 diff --git a/tests/functional/codegen/types/test_dynamic_array.py b/tests/functional/codegen/types/test_dynamic_array.py index d3d945740b..fc3223caaf 100644 --- a/tests/functional/codegen/types/test_dynamic_array.py +++ b/tests/functional/codegen/types/test_dynamic_array.py @@ -1665,7 +1665,7 @@ def ix(i: uint256) -> decimal: def test_public_dynarray(get_contract): code = """ my_list: public(DynArray[uint256, 5]) -@external +@deploy def __init__(): self.my_list = [1,2,3] """ @@ -1678,7 +1678,7 @@ def __init__(): def test_nested_public_dynarray(get_contract): code = """ my_list: public(DynArray[DynArray[uint256, 5], 5]) -@external +@deploy def __init__(): self.my_list = [[1,2,3]] """ diff --git a/tests/functional/codegen/types/test_flag.py b/tests/functional/codegen/types/test_flag.py index 5da6d57558..dd9c867a96 100644 --- a/tests/functional/codegen/types/test_flag.py +++ b/tests/functional/codegen/types/test_flag.py @@ -160,7 +160,7 @@ def test_augassign_storage(get_contract, w3, tx_failed): roles: public(HashMap[address, Roles]) -@external +@deploy def __init__(): self.roles[msg.sender] = Roles.ADMIN diff --git a/tests/functional/codegen/types/test_string.py b/tests/functional/codegen/types/test_string.py index 9d50f8df38..9d596eda32 100644 --- a/tests/functional/codegen/types/test_string.py +++ b/tests/functional/codegen/types/test_string.py @@ -90,7 +90,7 @@ def test_private_string(get_contract_with_gas_estimation): private_test_code = """ greeting: public(String[100]) -@external +@deploy def __init__(): self.greeting = "Hello " diff --git a/tests/functional/examples/safe_remote_purchase/test_safe_remote_purchase.py b/tests/functional/examples/safe_remote_purchase/test_safe_remote_purchase.py index e21a113f61..f6eb3966d4 100644 --- a/tests/functional/examples/safe_remote_purchase/test_safe_remote_purchase.py +++ b/tests/functional/examples/safe_remote_purchase/test_safe_remote_purchase.py @@ -118,7 +118,7 @@ def unlocked() -> bool: view purchase_contract: PurchaseContract -@external +@deploy def __init__(_purchase_contract: address): self.purchase_contract = PurchaseContract(_purchase_contract) diff --git a/tests/functional/syntax/exceptions/test_call_violation.py b/tests/functional/syntax/exceptions/test_call_violation.py index d310a2b42a..d96df07e74 100644 --- a/tests/functional/syntax/exceptions/test_call_violation.py +++ b/tests/functional/syntax/exceptions/test_call_violation.py @@ -27,6 +27,15 @@ def goo(): def foo(): self.goo() """, + """ +@deploy +def __init__(): + pass + +@internal +def foo(): + self.__init__() + """, ] diff --git a/tests/functional/syntax/exceptions/test_constancy_exception.py b/tests/functional/syntax/exceptions/test_constancy_exception.py index 7adf9538c7..6bfb8fee57 100644 --- a/tests/functional/syntax/exceptions/test_constancy_exception.py +++ b/tests/functional/syntax/exceptions/test_constancy_exception.py @@ -78,7 +78,7 @@ def foo(): """ f:int128 -@external +@internal def a (x:int128): self.f = 100 @@ -86,6 +86,63 @@ def a (x:int128): @external def b(): self.a(10)""", + """ +interface A: + def bar() -> uint16: view +@external +@pure +def test(to:address): + a:A = A(to) + x:uint16 = a.bar() + """, + """ +interface A: + def bar() -> uint16: view +@external +@pure +def test(to:address): + a:A = A(to) + a.bar() + """, + """ +interface A: + def bar() -> uint16: nonpayable +@external +@view +def test(to:address): + a:A = A(to) + x:uint16 = a.bar() + """, + """ +interface A: + def bar() -> uint16: nonpayable +@external +@view +def test(to:address): + a:A = A(to) + a.bar() + """, + """ +a:DynArray[uint16,3] +@deploy +def __init__(): + self.a = [1,2,3] +@view +@external +def bar()->DynArray[uint16,3]: + x:uint16 = self.a.pop() + return self.a # return [1,2] + """, + """ +from ethereum.ercs import ERC20 + +token: ERC20 + +@external +@view +def topup(amount: uint256): + assert self.token.transferFrom(msg.sender, self, amount) + """, ], ) def test_statefulness_violations(bad_code): diff --git a/tests/functional/syntax/exceptions/test_function_declaration_exception.py b/tests/functional/syntax/exceptions/test_function_declaration_exception.py index 3fe23e0ec7..878c7f3e29 100644 --- a/tests/functional/syntax/exceptions/test_function_declaration_exception.py +++ b/tests/functional/syntax/exceptions/test_function_declaration_exception.py @@ -34,17 +34,17 @@ def test_func() -> int128: return (1, 2) """, """ -@external +@deploy def __init__(a: int128 = 12): pass """, """ -@external +@deploy def __init__() -> uint256: return 1 """, """ -@external +@deploy def __init__() -> bool: pass """, @@ -58,7 +58,7 @@ def __init__(): """ a: immutable(uint256) -@external +@deploy @pure def __init__(): a = 1 @@ -66,7 +66,7 @@ def __init__(): """ a: immutable(uint256) -@external +@deploy @view def __init__(): a = 1 diff --git a/tests/functional/syntax/exceptions/test_instantiation_exception.py b/tests/functional/syntax/exceptions/test_instantiation_exception.py index 0d641f154a..4dd0bf6e02 100644 --- a/tests/functional/syntax/exceptions/test_instantiation_exception.py +++ b/tests/functional/syntax/exceptions/test_instantiation_exception.py @@ -69,7 +69,7 @@ def foo(): """ b: immutable(HashMap[uint256, uint256]) -@external +@deploy def __init__(): b = empty(HashMap[uint256, uint256]) """, diff --git a/tests/functional/syntax/exceptions/test_invalid_reference.py b/tests/functional/syntax/exceptions/test_invalid_reference.py index fe315e5cbf..7519d1406e 100644 --- a/tests/functional/syntax/exceptions/test_invalid_reference.py +++ b/tests/functional/syntax/exceptions/test_invalid_reference.py @@ -47,7 +47,7 @@ def foo(): """ a: public(immutable(uint256)) -@external +@deploy def __init__(): a = 123 diff --git a/tests/functional/syntax/exceptions/test_structure_exception.py b/tests/functional/syntax/exceptions/test_structure_exception.py index c6d733fc90..afc7a35012 100644 --- a/tests/functional/syntax/exceptions/test_structure_exception.py +++ b/tests/functional/syntax/exceptions/test_structure_exception.py @@ -94,7 +94,7 @@ def foo(): a: immutable(uint256) n: public(HashMap[uint256, bool][a]) -@external +@deploy def __init__(): a = 3 """, @@ -105,14 +105,14 @@ def __init__(): m1: HashMap[uint8, uint8] m2: HashMap[uint8, uint8] -@external +@deploy def __init__(): self.m1 = self.m2 """, """ m1: HashMap[uint8, uint8] -@external +@deploy def __init__(): self.m1 = 234 """, diff --git a/tests/functional/syntax/exceptions/test_vyper_exception_pos.py b/tests/functional/syntax/exceptions/test_vyper_exception_pos.py index a261cb0a11..9e0767cb83 100644 --- a/tests/functional/syntax/exceptions/test_vyper_exception_pos.py +++ b/tests/functional/syntax/exceptions/test_vyper_exception_pos.py @@ -22,7 +22,7 @@ def test_multiple_exceptions(get_contract, assert_compile_failed): foo: immutable(uint256) bar: immutable(uint256) -@external +@deploy def __init__(): self.foo = 1 # SyntaxException self.bar = 2 # SyntaxException diff --git a/tests/functional/syntax/modules/test_deploy_visibility.py b/tests/functional/syntax/modules/test_deploy_visibility.py new file mode 100644 index 0000000000..f51bf9575b --- /dev/null +++ b/tests/functional/syntax/modules/test_deploy_visibility.py @@ -0,0 +1,27 @@ +import pytest + +from vyper.compiler import compile_code +from vyper.exceptions import CallViolation + + +def test_call_deploy_from_external(make_input_bundle): + lib1 = """ +@deploy +def __init__(): + pass + """ + + main = """ +import lib1 + +@external +def foo(): + lib1.__init__() + """ + + input_bundle = make_input_bundle({"lib1.vy": lib1}) + + with pytest.raises(CallViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value.message == "Cannot call an @deploy function from an @external function!" diff --git a/tests/functional/syntax/modules/test_implements.py b/tests/functional/syntax/modules/test_implements.py new file mode 100644 index 0000000000..c292e198d9 --- /dev/null +++ b/tests/functional/syntax/modules/test_implements.py @@ -0,0 +1,51 @@ +from vyper.compiler import compile_code + + +def test_implements_from_vyi(make_input_bundle): + vyi = """ +@external +def foo(): + ... + """ + lib1 = """ +import some_interface + """ + main = """ +import lib1 + +implements: lib1.some_interface + +@external +def foo(): # implementation + pass + """ + input_bundle = make_input_bundle({"some_interface.vyi": vyi, "lib1.vy": lib1}) + + assert compile_code(main, input_bundle=input_bundle) is not None + + +def test_implements_from_vyi2(make_input_bundle): + # test implements via nested imported vyi file + vyi = """ +@external +def foo(): + ... + """ + lib1 = """ +import some_interface + """ + lib2 = """ +import lib1 + """ + main = """ +import lib2 + +implements: lib2.lib1.some_interface + +@external +def foo(): # implementation + pass + """ + input_bundle = make_input_bundle({"some_interface.vyi": vyi, "lib1.vy": lib1, "lib2.vy": lib2}) + + assert compile_code(main, input_bundle=input_bundle) is not None diff --git a/tests/functional/syntax/modules/test_initializers.py b/tests/functional/syntax/modules/test_initializers.py new file mode 100644 index 0000000000..0412e83c7d --- /dev/null +++ b/tests/functional/syntax/modules/test_initializers.py @@ -0,0 +1,1181 @@ +""" +tests for the uses/initializes checker +main properties to test: +- state usage -- if a module uses state, it must `used` or `initialized` +- conversely, if a module does not touch state, it should not be `used` +- global initializer check: each used module is `initialized` exactly once +""" + +import pytest + +from vyper.compiler import compile_code +from vyper.exceptions import ( + BorrowException, + ImmutableViolation, + InitializerException, + StructureException, + UndeclaredDefinition, +) + + +def test_initialize_uses(make_input_bundle): + lib1 = """ +counter: uint256 + +@deploy +def __init__(): + pass + """ + lib2 = """ +import lib1 + +uses: lib1 + +counter: uint256 + +@deploy +def __init__(): + pass + +@internal +def foo(): + lib1.counter += 1 + """ + main = """ +import lib2 +import lib1 + +initializes: lib2[lib1 := lib1] +initializes: lib1 + +@deploy +def __init__(): + lib1.__init__() + lib2.__init__() + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + assert compile_code(main, input_bundle=input_bundle) is not None + + +def test_initialize_multiple_uses(make_input_bundle): + lib1 = """ +counter: uint256 + +@deploy +def __init__(): + pass + """ + lib2 = """ +totalSupply: uint256 + """ + lib3 = """ +import lib1 +import lib2 + +# multiple uses on one line +uses: ( + lib1, + lib2 +) + +counter: uint256 + +@deploy +def __init__(): + pass + +@internal +def foo(): + x: uint256 = lib2.totalSupply + lib1.counter += 1 + """ + main = """ +import lib1 +import lib2 +import lib3 + +initializes: lib1 +initializes: lib2 +initializes: lib3[ + lib1 := lib1, + lib2 := lib2 +] + +@deploy +def __init__(): + lib1.__init__() + lib3.__init__() + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2, "lib3.vy": lib3}) + + assert compile_code(main, input_bundle=input_bundle) is not None + + +def test_initialize_multi_line_uses(make_input_bundle): + lib1 = """ +counter: uint256 + +@deploy +def __init__(): + pass + """ + lib2 = """ +totalSupply: uint256 + """ + lib3 = """ +import lib1 +import lib2 + +uses: lib1 +uses: lib2 + +counter: uint256 + +@deploy +def __init__(): + pass + +@internal +def foo(): + x: uint256 = lib2.totalSupply + lib1.counter += 1 + """ + main = """ +import lib1 +import lib2 +import lib3 + +initializes: lib1 +initializes: lib2 +initializes: lib3[ + lib1 := lib1, + lib2 := lib2 +] + +@deploy +def __init__(): + lib1.__init__() + lib3.__init__() + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2, "lib3.vy": lib3}) + + assert compile_code(main, input_bundle=input_bundle) is not None + + +def test_initialize_uses_attribute(make_input_bundle): + lib1 = """ +counter: uint256 + +@deploy +def __init__(): + pass + """ + lib2 = """ +import lib1 + +uses: lib1 + +counter: uint256 + +@deploy +def __init__(): + pass + +@internal +def foo(): + lib1.counter += 1 + """ + main = """ +import lib1 +import lib2 + +initializes: lib2[lib1 := lib1] +initializes: lib1 + +@deploy +def __init__(): + lib2.__init__() + # demonstrate we can call lib1.__init__ through lib2.lib1 + # (not sure this should be allowed, really. + lib2.lib1.__init__() + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + assert compile_code(main, input_bundle=input_bundle) is not None + + +def test_initializes_without_init_function(make_input_bundle): + lib1 = """ +counter: uint256 + """ + lib2 = """ +import lib1 + +uses: lib1 + +counter: uint256 + +@internal +def foo(): + lib1.counter += 1 + """ + main = """ +import lib1 +import lib2 + +initializes: lib2[lib1 := lib1] +initializes: lib1 + +@deploy +def __init__(): + pass + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + assert compile_code(main, input_bundle=input_bundle) is not None + + +def test_imported_as_different_names(make_input_bundle): + lib1 = """ +counter: uint256 + """ + lib2 = """ +import lib1 as m + +uses: m + +counter: uint256 + +@internal +def foo(): + m.counter += 1 + """ + main = """ +import lib1 as some_module +import lib2 + +initializes: lib2[m := some_module] +initializes: some_module + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + assert compile_code(main, input_bundle=input_bundle) is not None + + +def test_initializer_list_module_mismatch(make_input_bundle): + lib1 = """ +counter: uint256 + """ + lib2 = """ +something: uint256 + """ + lib3 = """ +import lib1 + +uses: lib1 + +@internal +def foo(): + lib1.counter += 1 + """ + main = """ +import lib1 +import lib2 +import lib3 + +initializes: lib1 +initializes: lib3[lib1 := lib2] # typo -- should be [lib1 := lib1] + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2, "lib3.vy": lib3}) + + with pytest.raises(StructureException) as e: + assert compile_code(main, input_bundle=input_bundle) is not None + + assert e.value._message == "lib1 is not lib2!" + + +def test_imported_as_different_names_error(make_input_bundle): + lib1 = """ +counter: uint256 + """ + lib2 = """ +import lib1 as m + +uses: m + +counter: uint256 + +@internal +def foo(): + m.counter += 1 + """ + main = """ +import lib1 +import lib2 + +initializes: lib2[lib1 := lib1] +initializes: lib1 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(UndeclaredDefinition) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "unknown module `lib1`" + assert e.value._hint == "did you mean `m := lib1`?" + + +def test_global_initializer_constraint(make_input_bundle): + lib1 = """ +counter: uint256 + """ + lib2 = """ +import lib1 + +uses: lib1 + +counter: uint256 + +@internal +def foo(): + lib1.counter += 1 + """ + main = """ +import lib1 +import lib2 + +initializes: lib2[lib1 := lib1] +# forgot to initialize lib1! + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(InitializerException) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "module `lib1.vy` is used but never initialized!" + assert e.value._hint == "add `initializes: lib1` to the top level of your main contract" + + +def test_initializer_no_references(make_input_bundle): + lib1 = """ +counter: uint256 + """ + lib2 = """ +import lib1 + +uses: lib1 + +counter: uint256 + +@internal +def foo(): + lib1.counter += 1 + """ + main = """ +import lib1 +import lib2 + +initializes: lib2 +initializes: lib1 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(InitializerException) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "`lib2` uses `lib1`, but it is not initialized with `lib1`" + assert e.value._hint == "add `lib1` to its initializer list" + + +def test_missing_uses(make_input_bundle): + lib1 = """ +counter: uint256 + """ + lib2 = """ +import lib1 + +# forgot `uses: lib1`! + +counter: uint256 + +@internal +def foo(): + lib1.counter += 1 + """ + main = """ +import lib1 +import lib2 + +initializes: lib2 +initializes: lib1 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "Cannot access `lib1` state!" + + expected_hint = "add `uses: lib1` or `initializes: lib1` as a " + expected_hint += "top-level statement to your contract" + assert e.value._hint == expected_hint + + +def test_missing_uses_for_read(make_input_bundle): + lib1 = """ +counter: uint256 + """ + lib2 = """ +import lib1 + +# forgot `uses: lib1`! + +counter: uint256 + +@internal +def foo() -> uint256: + return lib1.counter + """ + main = """ +import lib1 +import lib2 + +initializes: lib1 + +@deploy +def __init__(): + lib1.counter = 100 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "Cannot access `lib1` state!" + + expected_hint = "add `uses: lib1` or `initializes: lib1` as a " + expected_hint += "top-level statement to your contract" + assert e.value._hint == expected_hint + + +def test_missing_uses_for_read_immutable(make_input_bundle): + lib1 = """ +MY_IMMUTABLE: immutable(uint256) + +@deploy +def __init__(): + MY_IMMUTABLE = 7 + """ + lib2 = """ +import lib1 + +# forgot `uses: lib1`! + +counter: uint256 + +@internal +def foo() -> uint256: + return lib1.MY_IMMUTABLE + """ + main = """ +import lib1 +import lib2 + +initializes: lib1 + +@deploy +def __init__(): + lib1.counter = 100 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "Cannot access `lib1` state!" + + expected_hint = "add `uses: lib1` or `initializes: lib1` as a " + expected_hint += "top-level statement to your contract" + assert e.value._hint == expected_hint + + +def test_missing_uses_for_read_inside_call(make_input_bundle): + lib1 = """ +MY_IMMUTABLE: immutable(uint256) + +@deploy +def __init__(): + MY_IMMUTABLE = 9 + +@internal +def get_counter() -> uint256: + return MY_IMMUTABLE + """ + lib2 = """ +import lib1 + +# forgot `uses: lib1`! + +counter: uint256 + +@internal +def foo() -> uint256: + return lib1.get_counter() + """ + main = """ +import lib1 +import lib2 + +initializes: lib1 + +@deploy +def __init__(): + lib1.counter = 100 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "Cannot access `lib1` state!" + + expected_hint = "add `uses: lib1` or `initializes: lib1` as a " + expected_hint += "top-level statement to your contract" + assert e.value._hint == expected_hint + + +def test_missing_uses_for_hashmap(make_input_bundle): + lib1 = """ +counter: HashMap[uint256, HashMap[uint256, uint256]] + """ + lib2 = """ +import lib1 + +# forgot `uses: lib1`! + +@internal +def foo() -> uint256: + return lib1.counter[1][2] + """ + main = """ +import lib1 +import lib2 + +initializes: lib1 + +@deploy +def __init__(): + lib1.counter = 100 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "Cannot access `lib1` state!" + + expected_hint = "add `uses: lib1` or `initializes: lib1` as a " + expected_hint += "top-level statement to your contract" + assert e.value._hint == expected_hint + + +def test_missing_uses_for_tuple(make_input_bundle): + lib1 = """ +counter: HashMap[uint256, HashMap[uint256, uint256]] + """ + lib2 = """ +import lib1 + +interface Foo: + def foo() -> (uint256, uint256): nonpayable + +something: uint256 + +# forgot `uses: lib1`! + +@internal +def foo() -> uint256: + lib1.counter[1][2], self.something = Foo(msg.sender).foo() + """ + main = """ +import lib1 +import lib2 + +initializes: lib1 +initializes: lib2 + +@deploy +def __init__(): + lib1.counter = 100 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "Cannot access `lib1` state!" + + expected_hint = "add `uses: lib1` or `initializes: lib1` as a " + expected_hint += "top-level statement to your contract" + assert e.value._hint == expected_hint + + +def test_missing_uses_for_tuple_function_call(make_input_bundle): + lib1 = """ +counter: HashMap[uint256, HashMap[uint256, uint256]] + +something: uint256 + +interface Foo: + def foo() -> (uint256, uint256): nonpayable + +@internal +def write_tuple(): + self.counter[1][2], self.something = Foo(msg.sender).foo() + """ + lib2 = """ +import lib1 + +# forgot `uses: lib1`! +@internal +def foo(): + lib1.write_tuple() + """ + main = """ +import lib1 +import lib2 + +initializes: lib1 +initializes: lib2 + +@deploy +def __init__(): + lib1.counter = 100 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "Cannot access `lib1` state!" + + expected_hint = "add `uses: lib1` or `initializes: lib1` as a " + expected_hint += "top-level statement to your contract" + assert e.value._hint == expected_hint + + +def test_missing_uses_function_call(make_input_bundle): + # test missing uses through function call + lib1 = """ +counter: uint256 + +@internal +def update_counter(new_value: uint256): + self.counter = new_value + """ + lib2 = """ +import lib1 + +# forgot `uses: lib1`! + +counter: uint256 + +@internal +def foo(): + lib1.update_counter(lib1.counter + 1) + """ + main = """ +import lib1 +import lib2 + +initializes: lib2 +initializes: lib1 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "Cannot access `lib1` state!" + + expected_hint = "add `uses: lib1` or `initializes: lib1` as a " + expected_hint += "top-level statement to your contract" + assert e.value._hint == expected_hint + + +def test_missing_uses_nested_attribute(make_input_bundle): + # test missing uses through nested attribute access + lib1 = """ +counter: uint256 + """ + lib2 = """ +import lib1 + +counter: uint256 + +@internal +def foo(): + pass + """ + main = """ +import lib1 +import lib2 + +initializes: lib1 + +# did not `use` or `initialize` lib2! + +@external +def foo(new_value: uint256): + # cannot access lib1 state through lib2 + lib2.lib1.counter = new_value + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "Cannot access `lib2` state!" + + expected_hint = "add `uses: lib2` or `initializes: lib2` as a " + expected_hint += "top-level statement to your contract" + assert e.value._hint == expected_hint + + +def test_missing_uses_subscript(make_input_bundle): + # test missing uses through nested subscript/attribute access + lib1 = """ +struct Foo: + array: uint256[5] + +foos: Foo[5] + """ + lib2 = """ +import lib1 + +counter: uint256 + +@internal +def foo(): + pass + """ + main = """ +import lib1 +import lib2 + +initializes: lib1 + +# did not `use` or `initialize` lib2! + +@external +def foo(new_value: uint256): + # cannot access lib1 state through lib2 + lib2.lib1.foos[0].array[1] = new_value + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "Cannot access `lib2` state!" + + expected_hint = "add `uses: lib2` or `initializes: lib2` as a " + expected_hint += "top-level statement to your contract" + assert e.value._hint == expected_hint + + +def test_missing_uses_nested_attribute_function_call(make_input_bundle): + # test missing uses through nested attribute access + lib1 = """ +counter: uint256 + +@internal +def update_counter(new_value: uint256): + self.counter = new_value + """ + lib2 = """ +import lib1 + +counter: uint256 + +@internal +def foo(): + pass + """ + main = """ +import lib1 +import lib2 + +initializes: lib1 + +# did not `use` or `initialize` lib2! + +@external +def foo(new_value: uint256): + # cannot access lib1 state through lib2 + lib2.lib1.update_counter(new_value) + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "Cannot access `lib2` state!" + + expected_hint = "add `uses: lib2` or `initializes: lib2` as a " + expected_hint += "top-level statement to your contract" + assert e.value._hint == expected_hint + + +def test_uses_skip_import(make_input_bundle): + lib1 = """ +counter: uint256 + """ + lib2 = """ +import lib1 + +@internal +def foo(): + pass + """ + main = """ +import lib1 +import lib2 + +initializes: lib2 + +@external +def foo(new_value: uint256): + # can access lib1 state through lib2? + lib2.lib1.counter = new_value + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "Cannot access `lib1` state!" + + expected_hint = "add `uses: lib1` or `initializes: lib1` as a " + expected_hint += "top-level statement to your contract" + assert e.value._hint == expected_hint + + +def test_invalid_uses(make_input_bundle): + lib1 = """ +counter: uint256 + """ + lib2 = """ +import lib1 + +uses: lib1 # not necessary! + +counter: uint256 + +@internal +def foo(): + pass + """ + main = """ +import lib1 +import lib2 + +initializes: lib2[lib1 := lib1] +initializes: lib1 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(BorrowException) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "`lib1` is declared as used, but it is not actually used in lib2.vy!" + assert e.value._hint == "delete `uses: lib1`" + + +def test_invalid_uses2(make_input_bundle): + # test a more complicated invalid uses + lib1 = """ +counter: uint256 + +@internal +def foo(addr: address): + # sends value -- modifies ethereum state + to_send_value: uint256 = 100 + raw_call(addr, b"someFunction()", value=to_send_value) + """ + lib2 = """ +import lib1 + +uses: lib1 # not necessary! + +counter: uint256 + +@internal +def foo(): + lib1.foo(msg.sender) + """ + main = """ +import lib1 +import lib2 + +initializes: lib2[lib1 := lib1] +initializes: lib1 + +@external +def foo(): + lib2.foo() + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(BorrowException) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "`lib1` is declared as used, but it is not actually used in lib2.vy!" + assert e.value._hint == "delete `uses: lib1`" + + +def test_initializes_uses_conflict(make_input_bundle): + lib1 = """ +counter: uint256 + """ + main = """ +import lib1 + +initializes: lib1 +uses: lib1 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + + with pytest.raises(StructureException) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "ownership already set to `initializes`" + + +def test_uses_initializes_conflict(make_input_bundle): + lib1 = """ +counter: uint256 + """ + main = """ +import lib1 + +uses: lib1 +initializes: lib1 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + + with pytest.raises(StructureException) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "ownership already set to `uses`" + + +def test_uses_twice(make_input_bundle): + lib1 = """ +counter: uint256 + """ + main = """ +import lib1 + +uses: lib1 + +random_variable: constant(uint256) = 3 + +uses: lib1 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + + with pytest.raises(StructureException) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "ownership already set to `uses`" + + +def test_initializes_twice(make_input_bundle): + lib1 = """ +counter: uint256 + """ + main = """ +import lib1 + +initializes: lib1 + +random_variable: constant(uint256) = 3 + +initializes: lib1 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + + with pytest.raises(StructureException) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "ownership already set to `initializes`" + + +def test_no_initialize_unused_module(make_input_bundle): + lib1 = """ +counter: uint256 + +@internal +def set_counter(new_value: uint256): + self.counter = new_value + +@internal +@pure +def add(x: uint256, y: uint256) -> uint256: + return x + y + """ + main = """ +import lib1 + +# not needed: `initializes: lib1` + +@external +def do_add(x: uint256, y: uint256) -> uint256: + return lib1.add(x, y) + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + assert compile_code(main, input_bundle=input_bundle) is not None + + +def test_no_initialize_unused_module2(make_input_bundle): + # slightly more complicated + lib1 = """ +counter: uint256 + +@internal +def set_counter(new_value: uint256): + self.counter = new_value + +@internal +@pure +def add(x: uint256, y: uint256) -> uint256: + return x + y + """ + lib2 = """ +import lib1 + +@internal +@pure +def addmul(x: uint256, y: uint256, z: uint256) -> uint256: + return lib1.add(x, y) * z + """ + main = """ +import lib1 +import lib2 + +@external +def do_addmul(x: uint256, y: uint256) -> uint256: + return lib2.addmul(x, y, 5) + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + assert compile_code(main, input_bundle=input_bundle) is not None + + +def test_init_uninitialized_function(make_input_bundle): + lib1 = """ +counter: uint256 + +@deploy +def __init__(): + pass + """ + main = """ +import lib1 + +# missing `initializes: lib1`! + +@deploy +def __init__(): + lib1.__init__() + """ + + input_bundle = make_input_bundle({"lib1.vy": lib1}) + with pytest.raises(InitializerException) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "tried to initialize `lib1`, but it is not in initializer list!" + assert e.value._hint == "add `initializes: lib1` as a top-level statement to your contract" + + +def test_init_uninitialized_function2(make_input_bundle): + # test that we can't call module.__init__() even when we call `uses` + lib1 = """ +counter: uint256 + +@deploy +def __init__(): + pass + """ + main = """ +import lib1 + +uses: lib1 +# missing `initializes: lib1`! + +@deploy +def __init__(): + lib1.__init__() + """ + + input_bundle = make_input_bundle({"lib1.vy": lib1}) + with pytest.raises(InitializerException) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "tried to initialize `lib1`, but it is not in initializer list!" + assert e.value._hint == "add `initializes: lib1` as a top-level statement to your contract" + + +def test_noinit_initialized_function(make_input_bundle): + lib1 = """ +counter: uint256 + +@deploy +def __init__(): + self.counter = 5 + """ + main = """ +import lib1 + +initializes: lib1 + +@deploy +def __init__(): + pass # missing `lib1.__init__()`! + """ + + input_bundle = make_input_bundle({"lib1.vy": lib1}) + with pytest.raises(InitializerException) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "not initialized!" + assert e.value._hint == "add `lib1.__init__()` to your `__init__()` function" + + +def test_noinit_initialized_function2(make_input_bundle): + lib1 = """ +counter: uint256 + +@deploy +def __init__(): + self.counter = 5 + """ + main = """ +import lib1 + +initializes: lib1 + +# missing `lib1.__init__()`! + """ + + input_bundle = make_input_bundle({"lib1.vy": lib1}) + with pytest.raises(InitializerException) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "not initialized!" + assert e.value._hint == "add `lib1.__init__()` to your `__init__()` function" + + +def test_ownership_decl_errors_not_swallowed(make_input_bundle): + lib1 = """ +counter: uint256 + """ + main = """ +import lib1 +# forgot to import lib2 + +uses: (lib1, lib2) # should get UndeclaredDefinition + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + with pytest.raises(UndeclaredDefinition) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "'lib2' has not been declared. " diff --git a/tests/functional/syntax/test_address_code.py b/tests/functional/syntax/test_address_code.py index fa6ed20117..5873eb5af8 100644 --- a/tests/functional/syntax/test_address_code.py +++ b/tests/functional/syntax/test_address_code.py @@ -165,7 +165,7 @@ def test_address_code_self_success(get_contract, optimize): code = """ code_deployment: public(Bytes[32]) -@external +@deploy def __init__(): self.code_deployment = slice(self.code, 0, 32) @@ -186,7 +186,7 @@ def test_address_code_self_runtime_error_deployment(get_contract): code = """ dummy: public(Bytes[1000000]) -@external +@deploy def __init__(): self.dummy = slice(self.code, 0, 1000000) """ diff --git a/tests/functional/syntax/test_codehash.py b/tests/functional/syntax/test_codehash.py index c2d9a2e274..8aada22da7 100644 --- a/tests/functional/syntax/test_codehash.py +++ b/tests/functional/syntax/test_codehash.py @@ -11,7 +11,7 @@ def test_get_extcodehash(get_contract, evm_version, optimize): code = """ a: address -@external +@deploy def __init__(): self.a = self diff --git a/tests/functional/syntax/test_constants.py b/tests/functional/syntax/test_constants.py index 57922f28e2..63abf24485 100644 --- a/tests/functional/syntax/test_constants.py +++ b/tests/functional/syntax/test_constants.py @@ -94,7 +94,7 @@ VAL: immutable(uint256) VAL: uint256 -@external +@deploy def __init__(): VAL = 1 """, @@ -106,7 +106,7 @@ def __init__(): VAL: uint256 VAL: immutable(uint256) -@external +@deploy def __init__(): VAL = 1 """, diff --git a/tests/functional/syntax/test_immutables.py b/tests/functional/syntax/test_immutables.py index 1027d9fe66..59fb1a69d9 100644 --- a/tests/functional/syntax/test_immutables.py +++ b/tests/functional/syntax/test_immutables.py @@ -8,7 +8,7 @@ """ VALUE: immutable(uint256) -@external +@deploy def __init__(): pass """, @@ -25,7 +25,7 @@ def get_value() -> uint256: """ VALUE: immutable(uint256) = 3 -@external +@deploy def __init__(): pass """, @@ -33,7 +33,7 @@ def __init__(): """ VALUE: immutable(uint256) -@external +@deploy def __init__(): VALUE = 0 @@ -45,7 +45,7 @@ def set_value(_value: uint256): """ VALUE: immutable(uint256) -@external +@deploy def __init__(_value: uint256): VALUE = _value * 3 VALUE = VALUE + 1 @@ -54,7 +54,7 @@ def __init__(_value: uint256): """ VALUE: immutable(public(uint256)) -@external +@deploy def __init__(_value: uint256): VALUE = _value * 3 """, @@ -85,7 +85,7 @@ def test_compilation_simple_usage(typ): code = f""" VALUE: immutable({typ}) -@external +@deploy def __init__(_value: {typ}): VALUE = _value @@ -103,7 +103,7 @@ def get_value() -> {typ}: """ VALUE: immutable(uint256) -@external +@deploy def __init__(_value: uint256): VALUE = _value * 3 x: uint256 = VALUE + 1 @@ -121,7 +121,7 @@ def test_compilation_success(good_code): """ imm: immutable(uint256) -@external +@deploy def __init__(x: uint256): self.imm = x """, @@ -131,7 +131,7 @@ def __init__(x: uint256): """ imm: immutable(uint256) -@external +@deploy def __init__(x: uint256): x = imm @@ -145,7 +145,7 @@ def report(): """ imm: immutable(uint256) -@external +@deploy def __init__(x: uint256): imm = x @@ -163,7 +163,7 @@ def report(): x: immutable(Foo) -@external +@deploy def __init__(): x = Foo({a:1}) diff --git a/tests/functional/syntax/test_init.py b/tests/functional/syntax/test_init.py new file mode 100644 index 0000000000..389b5ad681 --- /dev/null +++ b/tests/functional/syntax/test_init.py @@ -0,0 +1,64 @@ +import pytest + +from vyper.compiler import compile_code +from vyper.exceptions import FunctionDeclarationException + +good_list = [ + """ +@deploy +def __init__(): + pass + """, + """ +@deploy +@payable +def __init__(): + pass + """, + """ +counter: uint256 +SOME_IMMUTABLE: immutable(uint256) + +@deploy +def __init__(): + SOME_IMMUTABLE = 5 + self.counter = 1 + """, +] + + +@pytest.mark.parametrize("code", good_list) +def test_good_init_funcs(code): + assert compile_code(code) is not None + + +fail_list = [ + """ +@internal +def __init__(): + pass + """, + """ +@deploy +@view +def __init__(): + pass + """, + """ +@deploy +@pure +def __init__(): + pass + """, + """ +@deploy +def some_function(): # for now, only __init__() functions can be marked @deploy + pass + """, +] + + +@pytest.mark.parametrize("code", fail_list) +def test_bad_init_funcs(code): + with pytest.raises(FunctionDeclarationException): + compile_code(code) diff --git a/tests/functional/syntax/test_interfaces.py b/tests/functional/syntax/test_interfaces.py index 584e497534..a07ec4e3dc 100644 --- a/tests/functional/syntax/test_interfaces.py +++ b/tests/functional/syntax/test_interfaces.py @@ -304,7 +304,7 @@ def some_func(): nonpayable my_interface: MyInterface[3] idx: uint256 -@external +@deploy def __init__(): self.my_interface[self.idx] = MyInterface(empty(address)) """, @@ -348,7 +348,7 @@ def foo() -> uint256: view foo: public(immutable(uint256)) -@external +@deploy def __init__(x: uint256): foo = x """, diff --git a/tests/functional/syntax/test_public.py b/tests/functional/syntax/test_public.py index 71bff753f4..217fcea998 100644 --- a/tests/functional/syntax/test_public.py +++ b/tests/functional/syntax/test_public.py @@ -10,7 +10,7 @@ x: public(constant(int128)) = 0 y: public(immutable(int128)) -@external +@deploy def __init__(): y = 0 """, diff --git a/tests/functional/syntax/test_tuple_assign.py b/tests/functional/syntax/test_tuple_assign.py index 49b63ee614..bb23804e30 100644 --- a/tests/functional/syntax/test_tuple_assign.py +++ b/tests/functional/syntax/test_tuple_assign.py @@ -92,7 +92,7 @@ def test(a: bytes32) -> (bytes32, uint256, int128): """ B: immutable(uint256) -@external +@deploy def __init__(b: uint256): B = b diff --git a/tests/unit/ast/test_ast_dict.py b/tests/unit/ast/test_ast_dict.py index 20390f3d5e..9fec61cb90 100644 --- a/tests/unit/ast/test_ast_dict.py +++ b/tests/unit/ast/test_ast_dict.py @@ -109,16 +109,6 @@ def foo() -> uint256: "node_id": 9, "src": "48:15:0", "ast_type": "ImplementsDecl", - "target": { - "col_offset": 0, - "end_col_offset": 10, - "node_id": 10, - "src": "48:10:0", - "ast_type": "Name", - "end_lineno": 5, - "lineno": 5, - "id": "implements", - }, "end_lineno": 5, "lineno": 5, } diff --git a/tests/unit/cli/storage_layout/test_storage_layout.py b/tests/unit/cli/storage_layout/test_storage_layout.py index 1aa8901881..f0ee25f747 100644 --- a/tests/unit/cli/storage_layout/test_storage_layout.py +++ b/tests/unit/cli/storage_layout/test_storage_layout.py @@ -56,7 +56,7 @@ def test_storage_and_immutables_layout(): SYMBOL: immutable(String[32]) DECIMALS: immutable(uint8) -@external +@deploy def __init__(): SYMBOL = "VYPR" DECIMALS = 18 @@ -72,3 +72,251 @@ def __init__(): out = compile_code(code, output_formats=["layout"]) assert out["layout"] == expected_layout + + +def test_storage_layout_module(make_input_bundle): + lib1 = """ +supply: uint256 +SYMBOL: immutable(String[32]) +DECIMALS: immutable(uint8) + +@deploy +def __init__(): + SYMBOL = "VYPR" + DECIMALS = 18 + """ + code = """ +import lib1 as a_library + +counter: uint256 +some_immutable: immutable(DynArray[uint256, 10]) + +counter2: uint256 + +initializes: a_library + +@deploy +def __init__(): + some_immutable = [1, 2, 3] + a_library.__init__() + """ + + input_bundle = make_input_bundle({"lib1.vy": lib1}) + + expected_layout = { + "code_layout": { + "some_immutable": {"length": 352, "offset": 0, "type": "DynArray[uint256, 10]"}, + "a_library": { + "DECIMALS": {"length": 32, "offset": 416, "type": "uint8"}, + "SYMBOL": {"length": 64, "offset": 352, "type": "String[32]"}, + }, + }, + "storage_layout": { + "counter": {"slot": 0, "type": "uint256"}, + "counter2": {"slot": 1, "type": "uint256"}, + "a_library": {"supply": {"slot": 2, "type": "uint256"}}, + }, + } + + out = compile_code(code, input_bundle=input_bundle, output_formats=["layout"]) + assert out["layout"] == expected_layout + + +def test_storage_layout_module2(make_input_bundle): + # test module storage layout, but initializes is in a different order + lib1 = """ +supply: uint256 +SYMBOL: immutable(String[32]) +DECIMALS: immutable(uint8) + +@deploy +def __init__(): + SYMBOL = "VYPR" + DECIMALS = 18 + """ + code = """ +import lib1 as a_library + +counter: uint256 +some_immutable: immutable(DynArray[uint256, 10]) + +initializes: a_library + +counter2: uint256 + +@deploy +def __init__(): + a_library.__init__() + some_immutable = [1, 2, 3] + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + + expected_layout = { + "code_layout": { + "some_immutable": {"length": 352, "offset": 0, "type": "DynArray[uint256, 10]"}, + "a_library": { + "SYMBOL": {"length": 64, "offset": 352, "type": "String[32]"}, + "DECIMALS": {"length": 32, "offset": 416, "type": "uint8"}, + }, + }, + "storage_layout": { + "counter": {"slot": 0, "type": "uint256"}, + "a_library": {"supply": {"slot": 1, "type": "uint256"}}, + "counter2": {"slot": 2, "type": "uint256"}, + }, + } + + out = compile_code(code, input_bundle=input_bundle, output_formats=["layout"]) + assert out["layout"] == expected_layout + + +def test_storage_layout_module_uses(make_input_bundle): + # test module storage layout, with initializes/uses + lib1 = """ +supply: uint256 +SYMBOL: immutable(String[32]) +DECIMALS: immutable(uint8) + +@deploy +def __init__(): + SYMBOL = "VYPR" + DECIMALS = 18 + """ + lib2 = """ +import lib1 + +uses: lib1 + +storage_variable: uint256 +immutable_variable: immutable(uint256) + +@deploy +def __init__(s: uint256): + immutable_variable = s + +@internal +def decimals() -> uint8: + return lib1.DECIMALS + """ + code = """ +import lib1 as a_library +import lib2 + +counter: uint256 +some_immutable: immutable(DynArray[uint256, 10]) + +# for fun: initialize lib2 in front of lib1 +initializes: lib2[lib1 := a_library] + +counter2: uint256 + +initializes: a_library + +@deploy +def __init__(): + a_library.__init__() + some_immutable = [1, 2, 3] + + lib2.__init__(17) + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + expected_layout = { + "code_layout": { + "some_immutable": {"length": 352, "offset": 0, "type": "DynArray[uint256, 10]"}, + "lib2": {"immutable_variable": {"length": 32, "offset": 352, "type": "uint256"}}, + "a_library": { + "SYMBOL": {"length": 64, "offset": 384, "type": "String[32]"}, + "DECIMALS": {"length": 32, "offset": 448, "type": "uint8"}, + }, + }, + "storage_layout": { + "counter": {"slot": 0, "type": "uint256"}, + "lib2": {"storage_variable": {"slot": 1, "type": "uint256"}}, + "counter2": {"slot": 2, "type": "uint256"}, + "a_library": {"supply": {"slot": 3, "type": "uint256"}}, + }, + } + + out = compile_code(code, input_bundle=input_bundle, output_formats=["layout"]) + assert out["layout"] == expected_layout + + +def test_storage_layout_module_nested_initializes(make_input_bundle): + # test module storage layout, with initializes in an imported module + lib1 = """ +supply: uint256 +SYMBOL: immutable(String[32]) +DECIMALS: immutable(uint8) + +@deploy +def __init__(): + SYMBOL = "VYPR" + DECIMALS = 18 + """ + lib2 = """ +import lib1 + +initializes: lib1 + +storage_variable: uint256 +immutable_variable: immutable(uint256) + +@deploy +def __init__(s: uint256): + immutable_variable = s + lib1.__init__() + +@internal +def decimals() -> uint8: + return lib1.DECIMALS + """ + code = """ +import lib1 as a_library +import lib2 + +counter: uint256 +some_immutable: immutable(DynArray[uint256, 10]) + +# for fun: initialize lib2 in front of lib1 +initializes: lib2 + +counter2: uint256 + +uses: a_library + +@deploy +def __init__(): + some_immutable = [1, 2, 3] + + lib2.__init__(17) + +@external +def foo() -> uint256: + return a_library.supply + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + expected_layout = { + "code_layout": { + "some_immutable": {"length": 352, "offset": 0, "type": "DynArray[uint256, 10]"}, + "lib2": { + "lib1": { + "SYMBOL": {"length": 64, "offset": 352, "type": "String[32]"}, + "DECIMALS": {"length": 32, "offset": 416, "type": "uint8"}, + }, + "immutable_variable": {"length": 32, "offset": 448, "type": "uint256"}, + }, + }, + "storage_layout": { + "counter": {"slot": 0, "type": "uint256"}, + "lib2": { + "lib1": {"supply": {"slot": 1, "type": "uint256"}}, + "storage_variable": {"slot": 2, "type": "uint256"}, + }, + "counter2": {"slot": 3, "type": "uint256"}, + }, + } + + out = compile_code(code, input_bundle=input_bundle, output_formats=["layout"]) + assert out["layout"] == expected_layout diff --git a/tests/unit/compiler/asm/test_asm_optimizer.py b/tests/unit/compiler/asm/test_asm_optimizer.py index b2851e908a..ce32249202 100644 --- a/tests/unit/compiler/asm/test_asm_optimizer.py +++ b/tests/unit/compiler/asm/test_asm_optimizer.py @@ -20,7 +20,7 @@ def runtime_only(): def bar(): self.runtime_only() -@external +@deploy def __init__(): self.ctor_only() """, @@ -44,7 +44,7 @@ def ctor_only(): def bar(): self.foo() -@external +@deploy def __init__(): self.ctor_only() """, @@ -65,7 +65,7 @@ def runtime_only(): def bar(): self.runtime_only() -@external +@deploy def __init__(): self.ctor_only() """, @@ -73,6 +73,9 @@ def __init__(): # check dead code eliminator works on unreachable functions +# CMC 2024-02-05 this is not really the asm eliminator anymore, +# it happens during function code generation in module.py. so we don't +# need to test this using asm anymore. @pytest.mark.parametrize("code", codes) def test_dead_code_eliminator(code): c = CompilerData(code, settings=Settings(optimize=OptimizationLevel.NONE)) @@ -88,20 +91,9 @@ def test_dead_code_eliminator(code): assert any(ctor_only in instr for instr in initcode_asm) assert all(runtime_only not in instr for instr in initcode_asm) - # all labels should be in unoptimized runtime asm - for s in (ctor_only, runtime_only): - assert any(s in instr for instr in runtime_asm) - - c = CompilerData(code, settings=Settings(optimize=OptimizationLevel.GAS)) - initcode_asm = [i for i in c.assembly if isinstance(i, str)] - runtime_asm = [i for i in c.assembly_runtime if isinstance(i, str)] - - # ctor only label should not be in runtime code + assert any(runtime_only in instr for instr in runtime_asm) assert all(ctor_only not in instr for instr in runtime_asm) - # runtime only label should not be in initcode asm - assert all(runtime_only not in instr for instr in initcode_asm) - def test_library_code_eliminator(make_input_bundle): library = """ diff --git a/tests/unit/compiler/test_bytecode_runtime.py b/tests/unit/compiler/test_bytecode_runtime.py index 613ee4d2b8..64cee3a75c 100644 --- a/tests/unit/compiler/test_bytecode_runtime.py +++ b/tests/unit/compiler/test_bytecode_runtime.py @@ -35,7 +35,7 @@ def foo5(): has_immutables = """ A_GOOD_PRIME: public(immutable(uint256)) -@external +@deploy def __init__(): A_GOOD_PRIME = 967 """ diff --git a/tests/unit/semantics/analysis/test_for_loop.py b/tests/unit/semantics/analysis/test_for_loop.py index 607587cc28..c97c9c095e 100644 --- a/tests/unit/semantics/analysis/test_for_loop.py +++ b/tests/unit/semantics/analysis/test_for_loop.py @@ -134,6 +134,111 @@ def baz(): validate_semantics(vyper_module, dummy_input_bundle) +def test_modify_iterator_recursive_function_call_topsort(dummy_input_bundle): + # test the analysis works no matter the order of functions + code = """ +a: uint256[3] + +@internal +def baz(): + for i: uint256 in self.a: + self.bar() + +@internal +def bar(): + self.foo() + +@internal +def foo(): + self.a[0] = 1 + """ + vyper_module = parse_to_ast(code) + with pytest.raises(ImmutableViolation) as e: + validate_semantics(vyper_module, dummy_input_bundle) + + assert e.value._message == "Cannot modify loop variable `a`" + + +def test_modify_iterator_through_struct(dummy_input_bundle): + # GH issue 3429 + code = """ +struct A: + iter: DynArray[uint256, 5] + +a: A + +@external +def foo(): + self.a.iter = [1, 2, 3] + for i: uint256 in self.a.iter: + self.a = A({iter: [1, 2, 3, 4]}) + """ + vyper_module = parse_to_ast(code) + with pytest.raises(ImmutableViolation) as e: + validate_semantics(vyper_module, dummy_input_bundle) + + assert e.value._message == "Cannot modify loop variable `a`" + + +def test_modify_iterator_complex_expr(dummy_input_bundle): + # GH issue 3429 + # avoid false positive! + code = """ +a: DynArray[uint256, 5] +b: uint256[10] + +@external +def foo(): + self.a = [1, 2, 3] + for i: uint256 in self.a: + self.b[self.a[1]] = i + """ + vyper_module = parse_to_ast(code) + validate_semantics(vyper_module, dummy_input_bundle) + + +def test_modify_iterator_siblings(dummy_input_bundle): + # test we can modify siblings in an access tree + code = """ +struct Foo: + a: uint256[2] + b: uint256 + +f: Foo + +@external +def foo(): + for i: uint256 in self.f.a: + self.f.b += i + """ + vyper_module = parse_to_ast(code) + validate_semantics(vyper_module, dummy_input_bundle) + + +def test_modify_subscript_barrier(dummy_input_bundle): + # test that Subscript nodes are a barrier for analysis + code = """ +struct Foo: + x: uint256[2] + y: uint256 + +struct Bar: + f: Foo[2] + +b: Bar + +@external +def foo(): + for i: uint256 in self.b.f[1].x: + self.b.f[0].y += i + """ + vyper_module = parse_to_ast(code) + with pytest.raises(ImmutableViolation) as e: + validate_semantics(vyper_module, dummy_input_bundle) + + assert e.value._message == "Cannot modify loop variable `b`" + + iterator_inference_codes = [ """ @external diff --git a/tests/unit/semantics/test_storage_slots.py b/tests/unit/semantics/test_storage_slots.py index ea2b2fe559..3620ef64b9 100644 --- a/tests/unit/semantics/test_storage_slots.py +++ b/tests/unit/semantics/test_storage_slots.py @@ -25,7 +25,7 @@ h: public(int256[1]) -@external +@deploy def __init__(): self.a = StructOne({a: "ok", b: [4,5,6]}) self.b = [7, 8] @@ -110,6 +110,6 @@ def test_allocator_overflow(get_contract): """ with pytest.raises( StorageLayoutException, - match=f"Invalid storage slot for var y, tried to allocate slots 1 through {2**256}", + match=f"Invalid storage slot, tried to allocate slots 1 through {2**256}", ): get_contract(code) diff --git a/vyper/ast/__init__.py b/vyper/ast/__init__.py index bc08626b59..0ae93e9710 100644 --- a/vyper/ast/__init__.py +++ b/vyper/ast/__init__.py @@ -5,7 +5,7 @@ from . import nodes, validation from .natspec import parse_natspec -from .nodes import compare_nodes +from .nodes import compare_nodes, as_tuple from .utils import ast_to_dict from .parse import parse_to_ast, parse_to_ast_with_settings @@ -15,6 +15,5 @@ ): setattr(sys.modules[__name__], name, obj) - # required to avoid circular dependency from . import expansion # noqa: E402 diff --git a/vyper/ast/grammar.lark b/vyper/ast/grammar.lark index 84429501e1..5ad465a1f1 100644 --- a/vyper/ast/grammar.lark +++ b/vyper/ast/grammar.lark @@ -182,13 +182,9 @@ loop_variable: NAME ":" type loop_iterator: _expr for_stmt: "for" loop_variable "in" loop_iterator ":" body -// ternary operator -ternary: _expr "if" _expr "else" _expr - // Expressions _expr: operation | dict - | ternary get_item: (variable_access | list) "[" _expr "]" get_attr: variable_access "." NAME @@ -214,7 +210,15 @@ dict: "{" "}" | "{" (NAME ":" _expr) ("," (NAME ":" _expr))* [","] "}" // See https://docs.python.org/3/reference/expressions.html#operator-precedence // NOTE: The recursive cycle here helps enforce operator precedence // Precedence goes up the lower down you go -?operation: bool_or +?operation: assignment_expr + +// "walrus" operator +?assignment_expr: ternary + | NAME ":=" assignment_expr + +// ternary operator +?ternary: bool_or + | ternary "if" ternary "else" ternary _AND: "and" _OR: "or" diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index 054145d33b..c4bce814a4 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -83,8 +83,20 @@ def get_node( if ast_struct["value"] is not None: _raise_syntax_exc("`implements` cannot have a value assigned", ast_struct) ast_struct["ast_type"] = "ImplementsDecl" + + # Replace "uses:" `AnnAssign` nodes with `UsesDecl` + elif getattr(ast_struct["target"], "id", None) == "uses": + if ast_struct["value"] is not None: + _raise_syntax_exc("`uses` cannot have a value assigned", ast_struct) + ast_struct["ast_type"] = "UsesDecl" + + # Replace "initializes:" `AnnAssign` nodes with `InitializesDecl` + elif getattr(ast_struct["target"], "id", None) == "initializes": + if ast_struct["value"] is not None: + _raise_syntax_exc("`initializes` cannot have a value assigned", ast_struct) + ast_struct["ast_type"] = "InitializesDecl" + # Replace state and local variable declarations `AnnAssign` with `VariableDecl` - # Parent node is required for context to determine whether replacement should happen. else: ast_struct["ast_type"] = "VariableDecl" @@ -730,6 +742,20 @@ def is_terminus(self): return self.value.is_terminus +class NamedExpr(Stmt): + __slots__ = ("target", "value") + + def validate(self): + # module[dep1 := dep2] + + # XXX: better error messages + if not isinstance(self.target, Name): + raise StructureException("not a Name") + + if not isinstance(self.value, Name): + raise StructureException("not a Name") + + class Log(Stmt): __slots__ = ("value",) @@ -756,6 +782,11 @@ class StructDef(TopLevel): class ExprNode(VyperNode): __slots__ = ("_expr_info",) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self._expr_info = None + class Constant(ExprNode): # inherited class for all simple constant node types @@ -1383,17 +1414,13 @@ class ImplementsDecl(Stmt): """ An `implements` declaration. - Excludes `simple` and `value` attributes from Python `AnnAssign` node. - Attributes ---------- - target : Name - Name node for the `implements` keyword annotation : Name Name node for the interface to be implemented """ - __slots__ = ("target", "annotation") + __slots__ = ("annotation",) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -1402,6 +1429,72 @@ def __init__(self, *args, **kwargs): raise StructureException("invalid implements", self.annotation) +def as_tuple(node: VyperNode): + """ + Convenience function for some AST nodes which allow either a Tuple + or single elements. Returns a python tuple of AST nodes. + """ + if isinstance(node, Tuple): + return node.elements + else: + return (node,) + + +class UsesDecl(Stmt): + """ + A `uses` declaration. + + Attributes + ---------- + annotation : Name | Attribute | Tuple + The module(s) which this uses + """ + + __slots__ = ("annotation",) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + items = as_tuple(self.annotation) + for item in items: + if not isinstance(item, (Name, Attribute)): + raise StructureException("invalid uses", item) + + +class InitializesDecl(Stmt): + """ + An `initializes` declaration. + + Attributes + ---------- + annotation : Name | Attribute | Subscript + An imported module which this module initializes + """ + + __slots__ = ("annotation",) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + module_ref = self.annotation + if isinstance(module_ref, Subscript): + dependencies = as_tuple(module_ref.slice) + module_ref = module_ref.value + + for item in dependencies: + if not isinstance(item, NamedExpr): + raise StructureException( + "invalid dependency (hint: should be [dependency := dependency]", item + ) + if not isinstance(item.target, (Name, Attribute)): + raise StructureException("invalid module", item.target) + if not isinstance(item.value, (Name, Attribute)): + raise StructureException("invalid module", item.target) + + if not isinstance(module_ref, (Name, Attribute)): + raise StructureException("invalid module", module_ref) + + class If(Stmt): __slots__ = ("test", "body", "orelse") diff --git a/vyper/ast/nodes.pyi b/vyper/ast/nodes.pyi index f71ed67821..342c84876a 100644 --- a/vyper/ast/nodes.pyi +++ b/vyper/ast/nodes.pyi @@ -101,7 +101,8 @@ class StructDef(VyperNode): body: list = ... name: str = ... -class ExprNode(VyperNode): ... +class ExprNode(VyperNode): + _expr_info: Any = ... class Constant(VyperNode): value: Any = ... @@ -145,19 +146,19 @@ class Name(VyperNode): _type: str = ... class Expr(VyperNode): - value: VyperNode = ... + value: ExprNode = ... class UnaryOp(ExprNode): op: VyperNode = ... - operand: VyperNode = ... + operand: ExprNode = ... class USub(VyperNode): ... class Not(VyperNode): ... class BinOp(ExprNode): - left: VyperNode = ... op: VyperNode = ... - right: VyperNode = ... + left: ExprNode = ... + right: ExprNode = ... class Add(VyperNode): ... class Sub(VyperNode): ... @@ -173,15 +174,15 @@ class BitXor(VyperNode): ... class BoolOp(ExprNode): op: VyperNode = ... - values: list[VyperNode] = ... + values: list[ExprNode] = ... class And(VyperNode): ... class Or(VyperNode): ... class Compare(ExprNode): op: VyperNode = ... - left: VyperNode = ... - right: VyperNode = ... + left: ExprNode = ... + right: ExprNode = ... class Eq(VyperNode): ... class NotEq(VyperNode): ... @@ -195,17 +196,17 @@ class NotIn(VyperNode): ... class Call(ExprNode): args: list = ... keywords: list = ... - func: VyperNode = ... + func: ExprNode = ... class keyword(VyperNode): ... -class Attribute(VyperNode): +class Attribute(ExprNode): attr: str = ... - value: VyperNode = ... + value: ExprNode = ... -class Subscript(VyperNode): - slice: VyperNode = ... - value: VyperNode = ... +class Subscript(ExprNode): + slice: ExprNode = ... + value: ExprNode = ... class Assign(VyperNode): ... @@ -224,8 +225,8 @@ class VariableDecl(VyperNode): class AugAssign(VyperNode): op: VyperNode = ... - target: VyperNode = ... - value: VyperNode = ... + target: ExprNode = ... + value: ExprNode = ... class Raise(VyperNode): ... class Assert(VyperNode): ... @@ -245,6 +246,12 @@ class ImplementsDecl(VyperNode): target: Name = ... annotation: Name = ... +class UsesDecl(VyperNode): + annotation: VyperNode = ... + +class InitializesDecl(VyperNode): + annotation: VyperNode = ... + class If(VyperNode): body: list = ... orelse: list = ... @@ -254,6 +261,10 @@ class IfExp(ExprNode): body: ExprNode = ... orelse: ExprNode = ... +class NamedExpr(ExprNode): + target: Name = ... + value: ExprNode = ... + class For(VyperNode): target: ExprNode iter: ExprNode diff --git a/vyper/ast/parse.py b/vyper/ast/parse.py index fc99af901b..a10a840da0 100644 --- a/vyper/ast/parse.py +++ b/vyper/ast/parse.py @@ -278,8 +278,8 @@ def visit_For(self, node): # specific error message than "invalid type annotation" raise SyntaxException( "missing type annotation\n\n" - "(hint: did you mean something like " - f"`for {node.target.id}: uint256 in ...`?)\n", + " (hint: did you mean something like " + f"`for {node.target.id}: uint256 in ...`?)", self._source_code, node.lineno, node.col_offset, diff --git a/vyper/builtins/_signatures.py b/vyper/builtins/_signatures.py index d2aefb2fd4..6e6cf4c662 100644 --- a/vyper/builtins/_signatures.py +++ b/vyper/builtins/_signatures.py @@ -85,13 +85,16 @@ class BuiltinFunctionT(VyperType): _kwargs: dict[str, KwargSettings] = {} _modifiability: Modifiability = Modifiability.MODIFIABLE _return_type: Optional[VyperType] = None + _equality_attrs = ("_id",) _is_terminus = False - # helper function to deal with TYPE_DEFINITIONs + @property + def modifiability(self): + return self._modifiability + + # helper function to deal with TYPE_Ts def _validate_single(self, arg: vy_ast.VyperNode, expected_type: VyperType) -> None: - # TODO using "TYPE_DEFINITION" is a kludge in derived classes, - # refactor me. - if expected_type == "TYPE_DEFINITION": + if TYPE_T.any().compare_type(expected_type): # try to parse the type - call type_from_annotation # for its side effects (will throw if is not a type) type_from_annotation(arg) @@ -130,7 +133,7 @@ def _validate_arg_types(self, node: vy_ast.Call) -> None: get_exact_type_from_node(arg) def check_modifiability_for_call(self, node: vy_ast.Call, modifiability: Modifiability) -> bool: - return self._modifiability >= modifiability + return self._modifiability <= modifiability def fetch_call_return(self, node: vy_ast.Call) -> Optional[VyperType]: self._validate_arg_types(node) diff --git a/vyper/builtins/_utils.py b/vyper/builtins/_utils.py index 72b05f15e3..3fad225b48 100644 --- a/vyper/builtins/_utils.py +++ b/vyper/builtins/_utils.py @@ -1,7 +1,7 @@ from vyper.ast import parse_to_ast from vyper.codegen.context import Context from vyper.codegen.stmt import parse_body -from vyper.semantics.analysis.local import FunctionNodeVisitor +from vyper.semantics.analysis.local import FunctionAnalyzer from vyper.semantics.namespace import Namespace, override_global_namespace from vyper.semantics.types.function import ContractFunctionT, FunctionVisibility, StateMutability from vyper.semantics.types.module import ModuleT @@ -25,9 +25,7 @@ def generate_inline_function(code, variables, variables_2, memory_allocator): ast_code.body[0]._metadata["func_type"] = ContractFunctionT( "sqrt_builtin", [], [], None, FunctionVisibility.INTERNAL, StateMutability.NONPAYABLE ) - # The FunctionNodeVisitor's constructor performs semantic checks - # annotate the AST as side effects. - analyzer = FunctionNodeVisitor(ast_code, ast_code.body[0], namespace) + analyzer = FunctionAnalyzer(ast_code, ast_code.body[0], namespace) analyzer.analyze() new_context = Context( diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index 50ab4dacd8..7575f4d77e 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -113,10 +113,7 @@ class TypenameFoldedFunctionT(FoldedFunctionT): # Base class for builtin functions that: # (1) take a typename as the only argument; and # (2) should always be folded. - - # "TYPE_DEFINITION" is a placeholder value for a type definition string, and - # will be replaced by a `TypeTypeDefinition` object in `infer_arg_types`. - _inputs = [("typename", "TYPE_DEFINITION")] + _inputs = [("typename", TYPE_T.any())] def fetch_call_return(self, node): type_ = self.infer_arg_types(node)[0].typedef @@ -711,7 +708,7 @@ def build_IR(self, expr, args, kwargs, context): class MethodID(FoldedFunctionT): _id = "method_id" _inputs = [("value", StringT.any())] - _kwargs = {"output_type": KwargSettings("TYPE_DEFINITION", BytesT(4))} + _kwargs = {"output_type": KwargSettings(TYPE_T.any(), BytesT(4))} def _try_fold(self, node): validate_call_args(node, 1, ["output_type"]) @@ -848,10 +845,7 @@ def _storage_element_getter(index): class Extract32(BuiltinFunctionT): _id = "extract32" _inputs = [("b", BytesT.any()), ("start", IntegerT.unsigneds())] - # "TYPE_DEFINITION" is a placeholder value for a type definition string, and - # will be replaced by a `TYPE_T` object in `infer_kwarg_types` - # (note that it is ignored in _validate_arg_types) - _kwargs = {"output_type": KwargSettings("TYPE_DEFINITION", BYTES32_T)} + _kwargs = {"output_type": KwargSettings(TYPE_T.any(), BYTES32_T)} def fetch_call_return(self, node): self._validate_arg_types(node) @@ -1976,18 +1970,22 @@ def build_IR(self, expr, args, kwargs, context): class UnsafeAdd(_UnsafeMath): + _id = "unsafe_add" op = "add" class UnsafeSub(_UnsafeMath): + _id = "unsafe_sub" op = "sub" class UnsafeMul(_UnsafeMath): + _id = "unsafe_mul" op = "mul" class UnsafeDiv(_UnsafeMath): + _id = "unsafe_div" op = "div" @@ -2474,7 +2472,7 @@ def build_IR(self, expr, args, kwargs, context): class ABIDecode(BuiltinFunctionT): _id = "_abi_decode" - _inputs = [("data", BytesT.any()), ("output_type", "TYPE_DEFINITION")] + _inputs = [("data", BytesT.any()), ("output_type", TYPE_T.any())] _kwargs = {"unwrap_tuple": KwargSettings(BoolT(), True, require_literal=True)} def fetch_call_return(self, node): diff --git a/vyper/codegen/context.py b/vyper/codegen/context.py index 4f644841f4..af01c5b504 100644 --- a/vyper/codegen/context.py +++ b/vyper/codegen/context.py @@ -44,7 +44,7 @@ def __repr__(self): return f"VariableRecord({ret})" -# Contains arguments, variables, etc +# compilation context for a function class Context: def __init__( self, @@ -59,19 +59,12 @@ def __init__( # In-memory variables, in the form (name, memory location, type) self.vars = vars_ or {} - # Global variables, in the form (name, storage location, type) - self.globals = module_ctx.variables - # Variables defined in for loops, e.g. for i in range(6): ... self.forvars = forvars or {} # Is the function constant? self.constancy = constancy - # Whether body is currently in an assert statement - # XXX: dead, never set to True - self.in_assertion = False - # Whether we are currently parsing a range expression self.in_range_expr = False @@ -87,6 +80,10 @@ def __init__( # Not intended to be accessed directly self.memory_allocator = memory_allocator + # save the starting memory location so we can find out (later) + # how much memory this function uses. + self.starting_memory = memory_allocator.next_mem + # Incremented values, used for internal IDs self._internal_var_iter = 0 self._scope_id_iter = 0 @@ -95,7 +92,7 @@ def __init__( self.is_ctor_context = is_ctor_context def is_constant(self): - return self.constancy is Constancy.Constant or self.in_assertion or self.in_range_expr + return self.constancy is Constancy.Constant or self.in_range_expr def check_is_not_constant(self, err, expr): if self.is_constant(): @@ -250,9 +247,7 @@ def lookup_var(self, varname): # Pretty print constancy for error messages def pp_constancy(self): - if self.in_assertion: - return "an assertion" - elif self.in_range_expr: + if self.in_range_expr: return "a range expression" elif self.constancy == Constancy.Constant: return "a constant function" diff --git a/vyper/codegen/core.py b/vyper/codegen/core.py index c3215f8c16..1a090ac316 100644 --- a/vyper/codegen/core.py +++ b/vyper/codegen/core.py @@ -3,9 +3,18 @@ from vyper.codegen.ir_node import Encoding, IRnode from vyper.compiler.settings import OptimizationLevel -from vyper.evm.address_space import CALLDATA, DATA, IMMUTABLES, MEMORY, STORAGE, TRANSIENT +from vyper.evm.address_space import ( + CALLDATA, + DATA, + IMMUTABLES, + MEMORY, + STORAGE, + TRANSIENT, + AddrSpace, +) from vyper.evm.opcodes import version_check from vyper.exceptions import CompilerPanic, TypeCheckFailure, TypeMismatch +from vyper.semantics.data_locations import DataLocation from vyper.semantics.types import ( AddressT, BoolT, @@ -100,6 +109,36 @@ def _codecopy_gas_bound(num_bytes): return GAS_COPY_WORD * ceil32(num_bytes) // 32 +def data_location_to_address_space(s: DataLocation, is_ctor_ctx: bool) -> AddrSpace: + if s == DataLocation.MEMORY: + return MEMORY + if s == DataLocation.STORAGE: + return STORAGE + if s == DataLocation.TRANSIENT: + return TRANSIENT + if s == DataLocation.CODE: + if is_ctor_ctx: + return IMMUTABLES + return DATA + + raise CompilerPanic("unreachable!") # pragma: nocover + + +def address_space_to_data_location(s: AddrSpace) -> DataLocation: + if s == MEMORY: + return DataLocation.MEMORY + if s == STORAGE: + return DataLocation.STORAGE + if s == TRANSIENT: + return DataLocation.TRANSIENT + if s in (IMMUTABLES, DATA): + return DataLocation.CODE + if s == CALLDATA: + return DataLocation.CALLDATA + + raise CompilerPanic("unreachable!") # pragma: nocover + + # Copy byte array word-for-word (including layout) # TODO make this a private function def make_byte_array_copier(dst, src): @@ -482,14 +521,10 @@ def _get_element_ptr_tuplelike(parent, key): return _getelemptr_abi_helper(parent, member_t, ofst) - if parent.location.word_addressable: - for i in range(index): - ofst += typ.member_types[attrs[i]].storage_size_in_words - elif parent.location.byte_addressable: - for i in range(index): - ofst += typ.member_types[attrs[i]].memory_bytes_required - else: - raise CompilerPanic(f"bad location {parent.location}") # pragma: notest + data_location = address_space_to_data_location(parent.location) + for i in range(index): + t = typ.member_types[attrs[i]] + ofst += t.get_size_in(data_location) return IRnode.from_list( add_ofst(parent, ofst), @@ -550,12 +585,8 @@ def _get_element_ptr_array(parent, key, array_bounds_check): return _getelemptr_abi_helper(parent, subtype, ofst) - if parent.location.word_addressable: - element_size = subtype.storage_size_in_words - elif parent.location.byte_addressable: - element_size = subtype.memory_bytes_required - else: - raise CompilerPanic("unreachable") # pragma: notest + data_location = address_space_to_data_location(parent.location) + element_size = subtype.get_size_in(data_location) ofst = _mul(ix, element_size) diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index f4c7948382..9c7f11dcb3 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -6,6 +6,7 @@ from vyper.codegen import external_call, self_call from vyper.codegen.core import ( clamp, + data_location_to_address_space, ensure_in_memory, get_dyn_array_count, get_element_ptr, @@ -23,7 +24,7 @@ ) from vyper.codegen.ir_node import IRnode from vyper.codegen.keccak256_helper import keccak256_helper -from vyper.evm.address_space import DATA, IMMUTABLES, MEMORY, STORAGE, TRANSIENT +from vyper.evm.address_space import MEMORY from vyper.evm.opcodes import version_check from vyper.exceptions import ( CodegenPanic, @@ -185,26 +186,24 @@ def parse_Name(self): ret._referenced_variables = {var} return ret - # TODO: use self.expr._expr_info - elif self.expr.id in self.context.globals: - varinfo = self.context.globals[self.expr.id] - + elif (varinfo := self.expr._expr_info.var_info) is not None: if varinfo.is_constant: return Expr.parse_value_expr(varinfo.decl_node.value, self.context) assert varinfo.is_immutable, "not an immutable!" - ofst = varinfo.position.offset + mutable = self.context.is_ctor_context - if self.context.is_ctor_context: - mutable = True - location = IMMUTABLES - else: - mutable = False - location = DATA + location = data_location_to_address_space( + varinfo.location, self.context.is_ctor_context + ) ret = IRnode.from_list( - ofst, typ=varinfo.typ, location=location, annotation=self.expr.id, mutable=mutable + varinfo.position.position, + typ=varinfo.typ, + location=location, + annotation=self.expr.id, + mutable=mutable, ) ret._referenced_variables = {varinfo} return ret @@ -264,20 +263,6 @@ def parse_Attribute(self): if addr.value == "address": # for `self.code` return IRnode.from_list(["~selfcode"], typ=BytesT(0)) return IRnode.from_list(["~extcode", addr], typ=BytesT(0)) - # self.x: global attribute - elif isinstance(self.expr.value, vy_ast.Name) and self.expr.value.id == "self": - varinfo = self.context.globals[self.expr.attr] - location = TRANSIENT if varinfo.is_transient else STORAGE - - ret = IRnode.from_list( - varinfo.position.position, - typ=varinfo.typ, - location=location, - annotation="self." + self.expr.attr, - ) - ret._referenced_variables = {varinfo} - - return ret # Reserved keywords elif ( @@ -333,17 +318,37 @@ def parse_Attribute(self): "chain.id is unavailable prior to istanbul ruleset", self.expr ) return IRnode.from_list(["chainid"], typ=UINT256_T) + # Other variables - else: - sub = Expr(self.expr.value, self.context).ir_node - # contract type - if isinstance(sub.typ, InterfaceT): - # MyInterface.address - assert self.expr.attr == "address" - sub.typ = typ - return sub - if isinstance(sub.typ, StructT) and self.expr.attr in sub.typ.member_types: - return get_element_ptr(sub, self.expr.attr) + + # self.x: global attribute + if (varinfo := self.expr._expr_info.var_info) is not None: + if varinfo.is_constant: + return Expr.parse_value_expr(varinfo.decl_node.value, self.context) + + location = data_location_to_address_space( + varinfo.location, self.context.is_ctor_context + ) + + ret = IRnode.from_list( + varinfo.position.position, + typ=varinfo.typ, + location=location, + annotation="self." + self.expr.attr, + ) + ret._referenced_variables = {varinfo} + + return ret + + sub = Expr(self.expr.value, self.context).ir_node + # contract type + if isinstance(sub.typ, InterfaceT): + # MyInterface.address + assert self.expr.attr == "address" + sub.typ = typ + return sub + if isinstance(sub.typ, StructT) and self.expr.attr in sub.typ.member_types: + return get_element_ptr(sub, self.expr.attr) def parse_Subscript(self): sub = Expr(self.expr.value, self.context).ir_node @@ -700,7 +705,7 @@ def parse_Call(self): return pop_dyn_array(darray, return_popped_item=True) if isinstance(func_type, ContractFunctionT): - if func_type.is_internal: + if func_type.is_internal or func_type.is_constructor: return self_call.ir_for_self_call(self.expr, self.context) else: return external_call.ir_for_external_call(self.expr, self.context) diff --git a/vyper/codegen/function_definitions/__init__.py b/vyper/codegen/function_definitions/__init__.py index 94617bef35..254b4df72c 100644 --- a/vyper/codegen/function_definitions/__init__.py +++ b/vyper/codegen/function_definitions/__init__.py @@ -1 +1,4 @@ -from .common import FuncIR, generate_ir_for_function # noqa +from .external_function import generate_ir_for_external_function +from .internal_function import generate_ir_for_internal_function + +__all__ = [generate_ir_for_internal_function, generate_ir_for_external_function] # type: ignore diff --git a/vyper/codegen/function_definitions/common.py b/vyper/codegen/function_definitions/common.py index 5877ff3d13..d017ba7b81 100644 --- a/vyper/codegen/function_definitions/common.py +++ b/vyper/codegen/function_definitions/common.py @@ -2,17 +2,14 @@ from functools import cached_property from typing import Optional -import vyper.ast as vy_ast from vyper.codegen.context import Constancy, Context -from vyper.codegen.function_definitions.external_function import generate_ir_for_external_function -from vyper.codegen.function_definitions.internal_function import generate_ir_for_internal_function from vyper.codegen.ir_node import IRnode from vyper.codegen.memory_allocator import MemoryAllocator -from vyper.exceptions import CompilerPanic +from vyper.evm.opcodes import version_check from vyper.semantics.types import VyperType -from vyper.semantics.types.function import ContractFunctionT +from vyper.semantics.types.function import ContractFunctionT, StateMutability from vyper.semantics.types.module import ModuleT -from vyper.utils import MemoryPositions, calc_mem_gas +from vyper.utils import MemoryPositions @dataclass @@ -53,9 +50,11 @@ def ir_identifier(self) -> str: return f"{self.visibility} {function_id} {name}({argz})" def set_frame_info(self, frame_info: FrameInfo) -> None: + # XXX: when can this happen? if self.frame_info is not None: - raise CompilerPanic(f"frame_info already set for {self.func_t}!") - self.frame_info = frame_info + assert frame_info == self.frame_info + else: + self.frame_info = frame_info @property # common entry point for external function with kwargs @@ -64,13 +63,15 @@ def external_function_base_entry_label(self) -> str: return self.ir_identifier + "_common" def internal_function_label(self, is_ctor_context: bool = False) -> str: - assert self.func_t.is_internal, "uh oh, should be internal" - suffix = "_deploy" if is_ctor_context else "_runtime" - return self.ir_identifier + suffix + f = self.func_t + assert f.is_internal or f.is_constructor, "uh oh, should be internal" + if f.is_constructor: + # sanity check - imported init functions only callable from main init + assert is_ctor_context -class FuncIR: - pass + suffix = "_deploy" if is_ctor_context else "_runtime" + return self.ir_identifier + suffix @dataclass @@ -80,7 +81,7 @@ class EntryPointInfo: ir_node: IRnode # the ir for this entry point def __post_init__(self): - # ABI v2 property guaranteed by the spec. + # sanity check ABI v2 properties guaranteed by the spec. # https://docs.soliditylang.org/en/v0.8.21/abi-spec.html#formal-specification-of-the-encoding states: # noqa: E501 # > Note that for any X, len(enc(X)) is a multiple of 32. assert self.min_calldatasize >= 4 @@ -88,34 +89,28 @@ def __post_init__(self): @dataclass -class ExternalFuncIR(FuncIR): +class ExternalFuncIR: entry_points: dict[str, EntryPointInfo] # map from abi sigs to entry points common_ir: IRnode # the "common" code for the function @dataclass -class InternalFuncIR(FuncIR): +class InternalFuncIR: func_ir: IRnode # the code for the function -# TODO: should split this into external and internal ir generation? -def generate_ir_for_function( - code: vy_ast.FunctionDef, module_ctx: ModuleT, is_ctor_context: bool = False -) -> FuncIR: - """ - Parse a function and produce IR code for the function, includes: - - Signature method if statement - - Argument handling - - Clamping and copying of arguments - - Function body - """ - func_t = code._metadata["func_type"] - - # generate _FuncIRInfo +def init_ir_info(func_t: ContractFunctionT): + # initialize IRInfo on the function func_t._ir_info = _FuncIRInfo(func_t) - callees = func_t.called_functions +def initialize_context( + func_t: ContractFunctionT, module_ctx: ModuleT, is_ctor_context: bool = False +): + init_ir_info(func_t) + + # calculate starting frame + callees = func_t.called_functions # we start our function frame from the largest callee frame max_callee_frame_size = 0 for c_func_t in callees: @@ -126,7 +121,7 @@ def generate_ir_for_function( memory_allocator = MemoryAllocator(allocate_start) - context = Context( + return Context( vars_=None, module_ctx=module_ctx, memory_allocator=memory_allocator, @@ -135,38 +130,41 @@ def generate_ir_for_function( is_ctor_context=is_ctor_context, ) - if func_t.is_internal: - ret: FuncIR = InternalFuncIR(generate_ir_for_internal_function(code, func_t, context)) - func_t._ir_info.gas_estimate = ret.func_ir.gas # type: ignore - else: - kwarg_handlers, common = generate_ir_for_external_function(code, func_t, context) - entry_points = { - k: EntryPointInfo(func_t, mincalldatasize, ir_node) - for k, (mincalldatasize, ir_node) in kwarg_handlers.items() - } - ret = ExternalFuncIR(entry_points, common) - # note: this ignores the cost of traversing selector table - func_t._ir_info.gas_estimate = ret.common_ir.gas +def tag_frame_info(func_t, context): frame_size = context.memory_allocator.size_of_mem - MemoryPositions.RESERVED_MEMORY + frame_start = context.starting_memory - frame_info = FrameInfo(allocate_start, frame_size, context.vars) + frame_info = FrameInfo(frame_start, frame_size, context.vars) + func_t._ir_info.set_frame_info(frame_info) - # XXX: when can this happen? - if func_t._ir_info.frame_info is None: - func_t._ir_info.set_frame_info(frame_info) - else: - assert frame_info == func_t._ir_info.frame_info - - if not func_t.is_internal: - # adjust gas estimate to include cost of mem expansion - # frame_size of external function includes all private functions called - # (note: internal functions do not need to adjust gas estimate since - mem_expansion_cost = calc_mem_gas(func_t._ir_info.frame_info.mem_used) # type: ignore - ret.common_ir.add_gas_estimate += mem_expansion_cost # type: ignore - ret.common_ir.passthrough_metadata["func_t"] = func_t # type: ignore - ret.common_ir.passthrough_metadata["frame_info"] = frame_info # type: ignore + return frame_info + + +def get_nonreentrant_lock(func_t): + if not func_t.nonreentrant: + return ["pass"], ["pass"] + + nkey = func_t.reentrancy_key_position.position + + LOAD, STORE = "sload", "sstore" + if version_check(begin="cancun"): + LOAD, STORE = "tload", "tstore" + + if version_check(begin="berlin"): + # any nonzero values would work here (see pricing as of net gas + # metering); these values are chosen so that downgrading to the + # 0,1 scheme (if it is somehow necessary) is safe. + final_value, temp_value = 3, 2 else: - ret.func_ir.passthrough_metadata["frame_info"] = frame_info # type: ignore + final_value, temp_value = 0, 1 + + check_notset = ["assert", ["ne", temp_value, [LOAD, nkey]]] - return ret + if func_t.mutability == StateMutability.VIEW: + return [check_notset], [["seq"]] + + else: + pre = ["seq", check_notset, [STORE, nkey, temp_value]] + post = [STORE, nkey, final_value] + return [pre], [post] diff --git a/vyper/codegen/function_definitions/external_function.py b/vyper/codegen/function_definitions/external_function.py index 65276469e7..b380eab2ce 100644 --- a/vyper/codegen/function_definitions/external_function.py +++ b/vyper/codegen/function_definitions/external_function.py @@ -2,12 +2,19 @@ from vyper.codegen.context import Context, VariableRecord from vyper.codegen.core import get_element_ptr, getpos, make_setter, needs_clamp from vyper.codegen.expr import Expr -from vyper.codegen.function_definitions.utils import get_nonreentrant_lock +from vyper.codegen.function_definitions.common import ( + EntryPointInfo, + ExternalFuncIR, + get_nonreentrant_lock, + initialize_context, + tag_frame_info, +) from vyper.codegen.ir_node import Encoding, IRnode from vyper.codegen.stmt import parse_body from vyper.evm.address_space import CALLDATA, DATA, MEMORY from vyper.semantics.types import TupleT from vyper.semantics.types.function import ContractFunctionT +from vyper.utils import calc_mem_gas # register function args with the local calling context. @@ -51,7 +58,7 @@ def _register_function_args(func_t: ContractFunctionT, context: Context) -> list def _generate_kwarg_handlers( func_t: ContractFunctionT, context: Context -) -> dict[str, tuple[int, IRnode]]: +) -> dict[str, EntryPointInfo]: # generate kwarg handlers. # since they might come in thru calldata or be default, # allocate them in memory and then fill it in based on calldata or default, @@ -126,34 +133,54 @@ def handler_for(calldata_kwargs, default_kwargs): default_kwargs = keyword_args[i:] sig, calldata_min_size, ir_node = handler_for(calldata_kwargs, default_kwargs) - ret[sig] = calldata_min_size, ir_node + assert sig not in ret + ret[sig] = EntryPointInfo(func_t, calldata_min_size, ir_node) sig, calldata_min_size, ir_node = handler_for(keyword_args, []) - ret[sig] = calldata_min_size, ir_node + assert sig not in ret + ret[sig] = EntryPointInfo(func_t, calldata_min_size, ir_node) return ret -def generate_ir_for_external_function(code, func_t, context): +def _adjust_gas_estimate(func_t, common_ir): + # adjust gas estimate to include cost of mem expansion + # frame_size of external function includes all private functions called + # (note: internal functions do not need to adjust gas estimate since + frame_info = func_t._ir_info.frame_info + + mem_expansion_cost = calc_mem_gas(frame_info.mem_used) + common_ir.add_gas_estimate += mem_expansion_cost + func_t._ir_info.gas_estimate = common_ir.gas + + # pass metadata through for venom pipeline: + common_ir.passthrough_metadata["func_t"] = func_t + common_ir.passthrough_metadata["frame_info"] = frame_info + + +def generate_ir_for_external_function(code, compilation_target): # TODO type hints: # def generate_ir_for_external_function( # code: vy_ast.FunctionDef, - # func_t: ContractFunctionT, - # context: Context, + # compilation_target: ModuleT, # ) -> IRnode: """ Return the IR for an external function. Returns IR for the body of the function, handle kwargs and exit the function. Also returns metadata required for `module.py` to construct the selector table. """ + func_t = code._metadata["func_type"] + assert func_t.is_external or func_t.is_constructor # sanity check + + context = initialize_context(func_t, compilation_target, func_t.is_constructor) nonreentrant_pre, nonreentrant_post = get_nonreentrant_lock(func_t) # generate handlers for base args and register the variable records handle_base_args = _register_function_args(func_t, context) # generate handlers for kwargs and register the variable records - kwarg_handlers = _generate_kwarg_handlers(func_t, context) + entry_points = _generate_kwarg_handlers(func_t, context) body = ["seq"] # once optional args have been handled, @@ -185,4 +212,8 @@ def generate_ir_for_external_function(code, func_t, context): # besides any kwarg handling func_common_ir = IRnode.from_list(["seq", body, exit_], source_pos=getpos(code)) - return kwarg_handlers, func_common_ir + tag_frame_info(func_t, context) + + _adjust_gas_estimate(func_t, func_common_ir) + + return ExternalFuncIR(entry_points, func_common_ir) diff --git a/vyper/codegen/function_definitions/internal_function.py b/vyper/codegen/function_definitions/internal_function.py index cf01dbdab4..0cf9850b70 100644 --- a/vyper/codegen/function_definitions/internal_function.py +++ b/vyper/codegen/function_definitions/internal_function.py @@ -1,23 +1,25 @@ from vyper import ast as vy_ast -from vyper.codegen.context import Context -from vyper.codegen.function_definitions.utils import get_nonreentrant_lock +from vyper.codegen.function_definitions.common import ( + InternalFuncIR, + get_nonreentrant_lock, + initialize_context, + tag_frame_info, +) from vyper.codegen.ir_node import IRnode from vyper.codegen.stmt import parse_body -from vyper.semantics.types.function import ContractFunctionT def generate_ir_for_internal_function( - code: vy_ast.FunctionDef, func_t: ContractFunctionT, context: Context -) -> IRnode: + code: vy_ast.FunctionDef, module_ctx, is_ctor_context: bool +) -> InternalFuncIR: """ Parse a internal function (FuncDef), and produce full function body. :param func_t: the ContractFunctionT :param code: ast of function - :param context: current calling context + :param compilation_target: current calling context :return: function body in IR """ - # The calling convention is: # Caller fills in argument buffer # Caller provides return address, return buffer on the stack @@ -37,13 +39,19 @@ def generate_ir_for_internal_function( # situation like the following is easy to bork: # x: T[2] = [self.generate_T(), self.generate_T()] - # Get nonreentrant lock + func_t = code._metadata["func_type"] + + # sanity check + assert func_t.is_internal or func_t.is_constructor + + context = initialize_context(func_t, module_ctx, is_ctor_context) for arg in func_t.arguments: # allocate a variable for every arg, setting mutability # to True to allow internal function arguments to be mutable context.new_variable(arg.name, arg.typ, is_mutable=True) + # Get nonreentrant lock nonreentrant_pre, nonreentrant_post = get_nonreentrant_lock(func_t) function_entry_label = func_t._ir_info.internal_function_label(context.is_ctor_context) @@ -69,5 +77,13 @@ def generate_ir_for_internal_function( ] ir_node = IRnode.from_list(["seq", body, cleanup_routine]) + + # tag gas estimate and frame info + func_t._ir_info.gas_estimate = ir_node.gas + frame_info = tag_frame_info(func_t, context) + + # pass metadata through for venom pipeline: + ir_node.passthrough_metadata["frame_info"] = frame_info ir_node.passthrough_metadata["func_t"] = func_t - return ir_node + + return InternalFuncIR(ir_node) diff --git a/vyper/codegen/function_definitions/utils.py b/vyper/codegen/function_definitions/utils.py deleted file mode 100644 index f524ec6e88..0000000000 --- a/vyper/codegen/function_definitions/utils.py +++ /dev/null @@ -1,31 +0,0 @@ -from vyper.evm.opcodes import version_check -from vyper.semantics.types.function import StateMutability - - -def get_nonreentrant_lock(func_type): - if not func_type.nonreentrant: - return ["pass"], ["pass"] - - nkey = func_type.reentrancy_key_position.position - - LOAD, STORE = "sload", "sstore" - if version_check(begin="cancun"): - LOAD, STORE = "tload", "tstore" - - if version_check(begin="berlin"): - # any nonzero values would work here (see pricing as of net gas - # metering); these values are chosen so that downgrading to the - # 0,1 scheme (if it is somehow necessary) is safe. - final_value, temp_value = 3, 2 - else: - final_value, temp_value = 0, 1 - - check_notset = ["assert", ["ne", temp_value, [LOAD, nkey]]] - - if func_type.mutability == StateMutability.VIEW: - return [check_notset], [["seq"]] - - else: - pre = ["seq", check_notset, [STORE, nkey, temp_value]] - post = [STORE, nkey, final_value] - return [pre], [post] diff --git a/vyper/codegen/module.py b/vyper/codegen/module.py index 98395a6a0c..fef4f23949 100644 --- a/vyper/codegen/module.py +++ b/vyper/codegen/module.py @@ -4,7 +4,10 @@ from vyper.codegen import core, jumptable_utils from vyper.codegen.core import shr -from vyper.codegen.function_definitions import generate_ir_for_function +from vyper.codegen.function_definitions import ( + generate_ir_for_external_function, + generate_ir_for_internal_function, +) from vyper.codegen.ir_node import IRnode from vyper.compiler.settings import _is_debug_mode from vyper.exceptions import CompilerPanic @@ -89,7 +92,7 @@ def _ir_for_fallback_or_ctor(func_ast, *args, **kwargs): callvalue_check = ["assert", ["iszero", "callvalue"]] ret.append(IRnode.from_list(callvalue_check, error_msg="nonpayable check")) - func_ir = generate_ir_for_function(func_ast, *args, **kwargs) + func_ir = generate_ir_for_external_function(func_ast, *args, **kwargs) assert len(func_ir.entry_points) == 1 # add a goto to make the function entry look like other functions @@ -101,7 +104,7 @@ def _ir_for_fallback_or_ctor(func_ast, *args, **kwargs): def _ir_for_internal_function(func_ast, *args, **kwargs): - return generate_ir_for_function(func_ast, *args, **kwargs).func_ir + return generate_ir_for_internal_function(func_ast, *args, **kwargs).func_ir def _generate_external_entry_points(external_functions, module_ctx): @@ -109,7 +112,7 @@ def _generate_external_entry_points(external_functions, module_ctx): sig_of = {} # reverse map from method ids to abi sig for code in external_functions: - func_ir = generate_ir_for_function(code, module_ctx) + func_ir = generate_ir_for_external_function(code, module_ctx) for abi_sig, entry_point in func_ir.entry_points.items(): method_id = method_id_int(abi_sig) assert abi_sig not in entry_points @@ -424,12 +427,13 @@ def _selector_section_linear(external_functions, module_ctx): # take a ModuleT, and generate the runtime and deploy IR def generate_ir_for_module(module_ctx: ModuleT) -> tuple[IRnode, IRnode]: + # XXX: rename `module_ctx` to `compilation_target` # order functions so that each function comes after all of its callees function_defs = _topsort(module_ctx.function_defs) reachable = _globally_reachable_functions(module_ctx.function_defs) runtime_functions = [f for f in function_defs if not _is_constructor(f)] - init_function = next((f for f in function_defs if _is_constructor(f)), None) + init_function = next((f for f in module_ctx.function_defs if _is_constructor(f)), None) internal_functions = [f for f in runtime_functions if _is_internal(f)] @@ -475,24 +479,21 @@ def generate_ir_for_module(module_ctx: ModuleT) -> tuple[IRnode, IRnode]: deploy_code: List[Any] = ["seq"] immutables_len = module_ctx.immutable_section_bytes - if init_function: + if init_function is not None: # cleanly rerun codegen for internal functions with `is_ctor_ctx=True` init_func_t = init_function._metadata["func_type"] ctor_internal_func_irs = [] - internal_functions = [f for f in runtime_functions if _is_internal(f)] - for f in internal_functions: - func_t = f._metadata["func_type"] - if func_t not in init_func_t.reachable_internal_functions: - # unreachable code, delete it - continue - - func_ir = _ir_for_internal_function(f, module_ctx, is_ctor_context=True) + + reachable_from_ctor = init_func_t.reachable_internal_functions + for func_t in reachable_from_ctor: + fn_ast = func_t.ast_def + func_ir = _ir_for_internal_function(fn_ast, module_ctx, is_ctor_context=True) ctor_internal_func_irs.append(func_ir) # generate init_func_ir after callees to ensure they have analyzed # memory usage. # TODO might be cleaner to separate this into an _init_ir helper func - init_func_ir = _ir_for_fallback_or_ctor(init_function, module_ctx, is_ctor_context=True) + init_func_ir = _ir_for_fallback_or_ctor(init_function, module_ctx) # pass the amount of memory allocated for the init function # so that deployment does not clobber while preparing immutables diff --git a/vyper/codegen/stmt.py b/vyper/codegen/stmt.py index 7d4938f287..e6baea75f7 100644 --- a/vyper/codegen/stmt.py +++ b/vyper/codegen/stmt.py @@ -144,7 +144,7 @@ def parse_Call(self): return pop_dyn_array(darray, return_popped_item=False) if isinstance(func_type, ContractFunctionT): - if func_type.is_internal: + if func_type.is_internal or func_type.is_constructor: return self_call.ir_for_self_call(self.stmt, self.context) else: return external_call.ir_for_external_call(self.stmt, self.context) diff --git a/vyper/compiler/phases.py b/vyper/compiler/phases.py index 5b7decec7b..f7eccdf214 100644 --- a/vyper/compiler/phases.py +++ b/vyper/compiler/phases.py @@ -152,23 +152,18 @@ def vyper_module(self): return self._generate_ast @cached_property - def _annotated_module(self): - return generate_annotated_ast( - self.vyper_module, self.input_bundle, self.storage_layout_override - ) - - @property def annotated_vyper_module(self) -> vy_ast.Module: - module, storage_layout = self._annotated_module - return module + return generate_annotated_ast(self.vyper_module, self.input_bundle) - @property + @cached_property def storage_layout(self) -> StorageLayout: - module, storage_layout = self._annotated_module - return storage_layout + module_ast = self.annotated_vyper_module + return set_data_positions(module_ast, self.storage_layout_override) @property def global_ctx(self) -> ModuleT: + # ensure storage layout is computed + _ = self.storage_layout return self.annotated_vyper_module._metadata["type"] @cached_property @@ -243,11 +238,7 @@ def blueprint_bytecode(self) -> bytes: return deploy_bytecode + blueprint_bytecode -def generate_annotated_ast( - vyper_module: vy_ast.Module, - input_bundle: InputBundle, - storage_layout_overrides: StorageLayout = None, -) -> tuple[vy_ast.Module, StorageLayout]: +def generate_annotated_ast(vyper_module: vy_ast.Module, input_bundle: InputBundle) -> vy_ast.Module: """ Validates and annotates the Vyper AST. @@ -268,9 +259,7 @@ def generate_annotated_ast( # note: validate_semantics does type inference on the AST validate_semantics(vyper_module, input_bundle) - symbol_tables = set_data_positions(vyper_module, storage_layout_overrides) - - return vyper_module, symbol_tables + return vyper_module def generate_ir_nodes(global_ctx: ModuleT, optimize: OptimizationLevel) -> tuple[IRnode, IRnode]: diff --git a/vyper/evm/address_space.py b/vyper/evm/address_space.py index 85a75c3c23..fcbd4bcf63 100644 --- a/vyper/evm/address_space.py +++ b/vyper/evm/address_space.py @@ -28,14 +28,6 @@ class AddrSpace: # TODO maybe make positional instead of defaulting to None store_op: Optional[str] = None - @property - def word_addressable(self) -> bool: - return self.word_scale == 1 - - @property - def byte_addressable(self) -> bool: - return self.word_scale == 32 - # alternative: # class Memory(AddrSpace): diff --git a/vyper/exceptions.py b/vyper/exceptions.py index 04667aaa59..53ad6f7bb8 100644 --- a/vyper/exceptions.py +++ b/vyper/exceptions.py @@ -31,7 +31,7 @@ class _BaseVyperException(Exception): order to display source annotations in the error string. """ - def __init__(self, message="Error Message not found.", *items): + def __init__(self, message="Error Message not found.", *items, hint=None): """ Exception initializer. @@ -47,7 +47,9 @@ def __init__(self, message="Error Message not found.", *items): A single tuple of (lineno, col_offset) is also understood to support the old API, but new exceptions should not use this approach. """ - self.message = message + self._message = message + self._hint = hint + self.lineno = None self.col_offset = None self.annotations = None @@ -77,6 +79,13 @@ def with_annotation(self, *annotations): exc.annotations = annotations return exc + @property + def message(self): + msg = self._message + if self._hint: + msg += f"\n\n (hint: {self._hint})" + return msg + def __str__(self): from vyper import ast as vy_ast from vyper.utils import annotate_source_code @@ -131,7 +140,7 @@ def __str__(self): annotation_list.append(node_msg) annotation_msg = "\n".join(annotation_list) - return f"{self.message}\n{annotation_msg}" + return f"{self.message}\n\n{annotation_msg}" class VyperException(_BaseVyperException): @@ -252,6 +261,14 @@ class ImmutableViolation(VyperException): """Modifying an immutable variable, constant, or definition.""" +class InitializerException(VyperException): + """An issue with initializing/constructing a module""" + + +class BorrowException(VyperException): + """An issue with borrowing/using a module""" + + class StateAccessViolation(VyperException): """Violating the mutability of a function definition.""" @@ -369,7 +386,7 @@ def tag_exceptions(node, fallback_exception_type=CompilerPanic, note=None): except _BaseVyperException as e: if not e.annotations and not e.lineno: tb = e.__traceback__ - raise e.with_annotation(node).with_traceback(tb) + raise e.with_annotation(node).with_traceback(tb) from None raise e from None except Exception as e: tb = e.__traceback__ diff --git a/vyper/semantics/analysis/__init__.py b/vyper/semantics/analysis/__init__.py index 7b52a68e92..e23b2d2aa4 100644 --- a/vyper/semantics/analysis/__init__.py +++ b/vyper/semantics/analysis/__init__.py @@ -1,4 +1,4 @@ from .. import types # break a dependency cycle. -from .module import validate_semantics +from .global_ import validate_semantics __all__ = ["validate_semantics"] diff --git a/vyper/semantics/analysis/base.py b/vyper/semantics/analysis/base.py index bb6d9ad9f7..49b867aae5 100644 --- a/vyper/semantics/analysis/base.py +++ b/vyper/semantics/analysis/base.py @@ -1,84 +1,29 @@ import enum from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Dict, Optional, Union from vyper import ast as vy_ast from vyper.compiler.input_bundle import InputBundle -from vyper.exceptions import ( - CompilerPanic, - ImmutableViolation, - StateAccessViolation, - VyperInternalException, -) +from vyper.exceptions import CompilerPanic, StructureException from vyper.semantics.data_locations import DataLocation from vyper.semantics.types.base import VyperType +from vyper.utils import OrderedSet, StringEnum if TYPE_CHECKING: from vyper.semantics.types.module import InterfaceT, ModuleT -class _StringEnum(enum.Enum): - @staticmethod - def auto(): - return enum.auto() +class FunctionVisibility(StringEnum): + EXTERNAL = enum.auto() + INTERNAL = enum.auto() + DEPLOY = enum.auto() - # Must be first, or else won't work, specifies what .value is - def _generate_next_value_(name, start, count, last_values): - return name.lower() - # Override ValueError with our own internal exception - @classmethod - def _missing_(cls, value): - raise VyperInternalException(f"{value} is not a valid {cls.__name__}") - - @classmethod - def is_valid_value(cls, value: str) -> bool: - return value in set(o.value for o in cls) - - @classmethod - def options(cls) -> List["_StringEnum"]: - return list(cls) - - @classmethod - def values(cls) -> List[str]: - return [v.value for v in cls.options()] - - # Comparison operations - def __eq__(self, other: object) -> bool: - if not isinstance(other, self.__class__): - raise CompilerPanic("Can only compare like types.") - return self is other - - # Python normally does __ne__(other) ==> not self.__eq__(other) - - def __lt__(self, other: object) -> bool: - if not isinstance(other, self.__class__): - raise CompilerPanic("Can only compare like types.") - options = self.__class__.options() - return options.index(self) < options.index(other) # type: ignore - - def __le__(self, other: object) -> bool: - return self.__eq__(other) or self.__lt__(other) - - def __gt__(self, other: object) -> bool: - return not self.__le__(other) - - def __ge__(self, other: object) -> bool: - return self.__eq__(other) or self.__gt__(other) - - -class FunctionVisibility(_StringEnum): - # TODO: these can just be enum.auto() right? - EXTERNAL = _StringEnum.auto() - INTERNAL = _StringEnum.auto() - - -class StateMutability(_StringEnum): - # TODO: these can just be enum.auto() right? - PURE = _StringEnum.auto() - VIEW = _StringEnum.auto() - NONPAYABLE = _StringEnum.auto() - PAYABLE = _StringEnum.auto() +class StateMutability(StringEnum): + PURE = enum.auto() + VIEW = enum.auto() + NONPAYABLE = enum.auto() + PAYABLE = enum.auto() @classmethod def from_abi(cls, abi_dict: Dict) -> "StateMutability": @@ -103,71 +48,40 @@ def from_abi(cls, abi_dict: Dict) -> "StateMutability": # and variables) and Constancy (in codegen). context.Constancy can/should # probably be refactored away though as those kinds of checks should be done # during analysis. -class Modifiability(enum.IntEnum): - # is writeable/can result in arbitrary state or memory changes - MODIFIABLE = enum.auto() - - # could potentially add more fine-grained here as needed, like - # CONSTANT_AFTER_DEPLOY, TX_CONSTANT, BLOCK_CONSTANT, etc. +class Modifiability(StringEnum): + # compile-time / always constant + CONSTANT = enum.auto() # things that are constant within the current message call, including # block.*, msg.*, tx.* and immutables RUNTIME_CONSTANT = enum.auto() - # compile-time / always constant - CONSTANT = enum.auto() - - -class DataPosition: - _location: DataLocation - - -class CalldataOffset(DataPosition): - __slots__ = ("dynamic_offset", "static_offset") - _location = DataLocation.CALLDATA - - def __init__(self, static_offset, dynamic_offset=None): - self.static_offset = static_offset - self.dynamic_offset = dynamic_offset - - def __repr__(self): - if self.dynamic_offset is not None: - return f"" - else: - return f"" - - -class MemoryOffset(DataPosition): - __slots__ = ("offset",) - _location = DataLocation.MEMORY - - def __init__(self, offset): - self.offset = offset - - def __repr__(self): - return f"" - - -class StorageSlot(DataPosition): - __slots__ = ("position",) - _location = DataLocation.STORAGE + # could potentially add more fine-grained here as needed, like + # CONSTANT_AFTER_DEPLOY, TX_CONSTANT, BLOCK_CONSTANT, etc. - def __init__(self, position): - self.position = position + # is writeable/can result in arbitrary state or memory changes + MODIFIABLE = enum.auto() - def __repr__(self): - return f"" + @classmethod + def from_state_mutability(cls, mutability: StateMutability): + if mutability == StateMutability.PURE: + return cls.CONSTANT + if mutability == StateMutability.VIEW: + return cls.RUNTIME_CONSTANT + # sanity check in case more StateMutability levels are added in the future + assert mutability in (StateMutability.PAYABLE, StateMutability.NONPAYABLE) + return cls.MODIFIABLE -class CodeOffset(DataPosition): - __slots__ = ("offset",) - _location = DataLocation.CODE +@dataclass +class VarOffset: + position: int - def __init__(self, offset): - self.offset = offset - def __repr__(self): - return f"" +class ModuleOwnership(StringEnum): + NO_OWNERSHIP = enum.auto() # readable + USES = enum.auto() # writeable + INITIALIZES = enum.auto() # initializes # base class for things that are the "result" of analysis @@ -178,6 +92,9 @@ class AnalysisResult: @dataclass class ModuleInfo(AnalysisResult): module_t: "ModuleT" + alias: str + ownership: ModuleOwnership = ModuleOwnership.NO_OWNERSHIP + ownership_decl: Optional[vy_ast.VyperNode] = None @property def module_node(self): @@ -188,6 +105,16 @@ def module_node(self): def typ(self): return self.module_t + def set_ownership(self, module_ownership: ModuleOwnership, node: Optional[vy_ast.VyperNode]): + if self.ownership != ModuleOwnership.NO_OWNERSHIP: + raise StructureException( + f"ownership already set to `{self.ownership}`", node, self.ownership_decl + ) + self.ownership = module_ownership + + def __hash__(self): + return hash(id(self.module_t)) + @dataclass class ImportInfo(AnalysisResult): @@ -199,6 +126,21 @@ class ImportInfo(AnalysisResult): node: vy_ast.VyperNode +# analysis result of InitializesDecl +@dataclass +class InitializesInfo(AnalysisResult): + module_info: ModuleInfo + dependencies: list[ModuleInfo] + node: Optional[vy_ast.VyperNode] = None + + +# analysis result of UsesDecl +@dataclass +class UsesInfo(AnalysisResult): + used_modules: list[ModuleInfo] + node: Optional[vy_ast.VyperNode] = None + + @dataclass class VarInfo: """ @@ -221,22 +163,21 @@ def __hash__(self): return hash(id(self)) def __post_init__(self): + self.position = None self._modification_count = 0 - def set_position(self, position: DataPosition) -> None: - if hasattr(self, "position"): + def set_position(self, position: VarOffset) -> None: + if self.position is not None: raise CompilerPanic("Position was already assigned") - if self.location != position._location: - if self.location == DataLocation.UNSET: - self.location = position._location - elif self.is_transient and position._location == DataLocation.STORAGE: - # CMC 2023-12-31 - use same allocator for storage and transient - # for now, this should be refactored soon. - pass - else: - raise CompilerPanic("Incompatible locations") + assert isinstance(position, VarOffset) # sanity check self.position = position + def is_module_variable(self): + return self.location not in (DataLocation.UNSET, DataLocation.MEMORY) + + def get_size(self) -> int: + return self.typ.get_size_in(self.location) + @property def is_transient(self): return self.location == DataLocation.TRANSIENT @@ -252,6 +193,17 @@ def is_constant(self): return res +@dataclass(frozen=True) +class VarAccess: + variable: VarInfo + attrs: tuple[str, ...] + + def contains(self, other): + # VarAccess("v", ("a")) `contains` VarAccess("v", ("a", "b", "c")) + sub_attrs = other.attrs[: len(self.attrs)] + return self.variable == other.variable and sub_attrs == self.attrs + + @dataclass class ExprInfo: """ @@ -260,8 +212,10 @@ class ExprInfo: typ: VyperType var_info: Optional[VarInfo] = None + module_info: Optional[ModuleInfo] = None location: DataLocation = DataLocation.UNSET modifiability: Modifiability = Modifiability.MODIFIABLE + attr: Optional[str] = None def __post_init__(self): should_match = ("typ", "location", "modifiability") @@ -270,65 +224,35 @@ def __post_init__(self): if getattr(self.var_info, attr) != getattr(self, attr): raise CompilerPanic("Bad analysis: non-matching {attr}: {self}") + self._writes: OrderedSet[VarAccess] = OrderedSet() + self._reads: OrderedSet[VarAccess] = OrderedSet() + @classmethod - def from_varinfo(cls, var_info: VarInfo) -> "ExprInfo": + def from_varinfo(cls, var_info: VarInfo, **kwargs) -> "ExprInfo": return cls( var_info.typ, var_info=var_info, location=var_info.location, modifiability=var_info.modifiability, + **kwargs, ) @classmethod - def from_moduleinfo(cls, module_info: ModuleInfo) -> "ExprInfo": - return cls(module_info.module_t) + def from_moduleinfo(cls, module_info: ModuleInfo, **kwargs) -> "ExprInfo": + modifiability = Modifiability.RUNTIME_CONSTANT + if module_info.ownership >= ModuleOwnership.USES: + modifiability = Modifiability.MODIFIABLE - def copy_with_type(self, typ: VyperType) -> "ExprInfo": + return cls( + module_info.module_t, module_info=module_info, modifiability=modifiability, **kwargs + ) + + def copy_with_type(self, typ: VyperType, **kwargs) -> "ExprInfo": """ Return a copy of the ExprInfo but with the type set to something else """ to_copy = ("location", "modifiability") fields = {k: getattr(self, k) for k in to_copy} - return self.__class__(typ=typ, **fields) - - def validate_modification(self, node: vy_ast.VyperNode, mutability: StateMutability) -> None: - """ - Validate an attempt to modify this value. - - Raises if the value is a constant or involves an invalid operation. - - Arguments - --------- - node : Assign | AugAssign | Call - Vyper ast node of the modifying action. - mutability: StateMutability - The mutability of the context (e.g., pure function) we are currently in - """ - if mutability <= StateMutability.VIEW and self.location == DataLocation.STORAGE: - raise StateAccessViolation( - f"Cannot modify storage in a {mutability.value} function", node - ) - - if self.location == DataLocation.CALLDATA: - raise ImmutableViolation("Cannot write to calldata", node) - - if self.modifiability == Modifiability.RUNTIME_CONSTANT: - if self.location == DataLocation.CODE: - if node.get_ancestor(vy_ast.FunctionDef).get("name") != "__init__": - raise ImmutableViolation("Immutable value cannot be written to", node) - - # special handling for immutable variables in the ctor - # TODO: we probably want to remove this restriction. - if self.var_info._modification_count: # type: ignore - raise ImmutableViolation( - "Immutable value cannot be modified after assignment", node - ) - self.var_info._modification_count += 1 # type: ignore - else: - raise ImmutableViolation("Environment variable cannot be written to", node) - - if self.modifiability == Modifiability.CONSTANT: - raise ImmutableViolation("Constant value cannot be written to", node) - - if isinstance(node, vy_ast.AugAssign): - self.typ.validate_numeric_op(node) + for t in to_copy: + assert t not in kwargs + return self.__class__(typ=typ, **fields, **kwargs) diff --git a/vyper/semantics/analysis/constant_folding.py b/vyper/semantics/analysis/constant_folding.py index bfcc473d09..3522383167 100644 --- a/vyper/semantics/analysis/constant_folding.py +++ b/vyper/semantics/analysis/constant_folding.py @@ -113,7 +113,7 @@ def visit_Attribute(self, node) -> vy_ast.ExprNode: varinfo = module_t.get_member(node.attr, node) return varinfo.decl_node.value.get_folded_value() - except (VyperException, AttributeError): + except (VyperException, AttributeError, KeyError): raise UnfoldableNode("not a module") def visit_UnaryOp(self, node): diff --git a/vyper/semantics/analysis/data_positions.py b/vyper/semantics/analysis/data_positions.py index 88679a4b09..604bc6b594 100644 --- a/vyper/semantics/analysis/data_positions.py +++ b/vyper/semantics/analysis/data_positions.py @@ -1,11 +1,12 @@ -# TODO this module doesn't really belong in "validation" -from typing import Dict, List +from collections import defaultdict +from typing import Generic, TypeVar from vyper import ast as vy_ast -from vyper.exceptions import StorageLayoutException -from vyper.semantics.analysis.base import CodeOffset, StorageSlot +from vyper.evm.opcodes import version_check +from vyper.exceptions import CompilerPanic, StorageLayoutException +from vyper.semantics.analysis.base import VarOffset +from vyper.semantics.data_locations import DataLocation from vyper.typing import StorageLayout -from vyper.utils import ceil32 def set_data_positions( @@ -20,24 +21,76 @@ def set_data_positions( vyper_module : vy_ast.Module Top-level Vyper AST node that has already been annotated with type data. """ - code_offsets = set_code_offsets(vyper_module) - storage_slots = ( - set_storage_slots_with_overrides(vyper_module, storage_layout_overrides) - if storage_layout_overrides is not None - else set_storage_slots(vyper_module) - ) + if storage_layout_overrides is not None: + # extract code layout with no overrides + code_offsets = _allocate_layout_r(vyper_module, immutables_only=True)["code_layout"] + storage_slots = set_storage_slots_with_overrides(vyper_module, storage_layout_overrides) + return {"storage_layout": storage_slots, "code_layout": code_offsets} - return {"storage_layout": storage_slots, "code_layout": code_offsets} + ret = _allocate_layout_r(vyper_module) + assert isinstance(ret, defaultdict) + return dict(ret) # convert back to dict -class StorageAllocator: +_T = TypeVar("_T") +_K = TypeVar("_K") + + +class InsertableOnceDict(Generic[_T, _K], dict[_T, _K]): + def __setitem__(self, k, v): + if k in self: + raise ValueError(f"{k} is already in dict!") + super().__setitem__(k, v) + + +class SimpleAllocator: + def __init__(self, max_slot: int = 2**256, starting_slot: int = 0): + # Allocate storage slots from 0 + # note storage is word-addressable, not byte-addressable + self._slot = starting_slot + self._max_slot = max_slot + + def allocate_slot(self, n, var_name, node=None): + ret = self._slot + if self._slot + n >= self._max_slot: + raise StorageLayoutException( + f"Invalid storage slot, tried to allocate" + f" slots {self._slot} through {self._slot + n}", + node, + ) + self._slot += n + return ret + + +class Allocators: + storage_allocator: SimpleAllocator + transient_storage_allocator: SimpleAllocator + immutables_allocator: SimpleAllocator + + def __init__(self): + self.storage_allocator = SimpleAllocator(max_slot=2**256) + self.transient_storage_allocator = SimpleAllocator(max_slot=2**256) + self.immutables_allocator = SimpleAllocator(max_slot=0x6000) + + def get_allocator(self, location: DataLocation): + if location == DataLocation.STORAGE: + return self.storage_allocator + if location == DataLocation.TRANSIENT: + return self.transient_storage_allocator + if location == DataLocation.CODE: + return self.immutables_allocator + + raise CompilerPanic("unreachable") # pragma: nocover + + +class OverridingStorageAllocator: """ Keep track of which storage slots have been used. If there is a collision of storage slots, this will raise an error and fail to compile """ def __init__(self): - self.occupied_slots: Dict[int, str] = {} + self.occupied_slots: dict[int, str] = {} def reserve_slot_range(self, first_slot: int, n_slots: int, var_name: str) -> None: """ @@ -48,7 +101,7 @@ def reserve_slot_range(self, first_slot: int, n_slots: int, var_name: str) -> No list_to_check = [x + first_slot for x in range(n_slots)] self._reserve_slots(list_to_check, var_name) - def _reserve_slots(self, slots: List[int], var_name: str) -> None: + def _reserve_slots(self, slots: list[int], var_name: str) -> None: for slot in slots: self._reserve_slot(slot, var_name) @@ -70,12 +123,13 @@ def set_storage_slots_with_overrides( vyper_module: vy_ast.Module, storage_layout_overrides: StorageLayout ) -> StorageLayout: """ - Parse module-level Vyper AST to calculate the layout of storage variables. + Set storage layout given a layout override file. Returns the layout as a dict of variable name -> variable info + (Doesn't handle modules, or transient storage) """ - ret: Dict[str, Dict] = {} - reserved_slots = StorageAllocator() + ret: InsertableOnceDict[str, dict] = InsertableOnceDict() + reserved_slots = OverridingStorageAllocator() # Search through function definitions to find non-reentrant functions for node in vyper_module.get_children(vy_ast.FunctionDef): @@ -90,7 +144,7 @@ def set_storage_slots_with_overrides( # re-entrant key was already identified if variable_name in ret: _slot = ret[variable_name]["slot"] - type_.set_reentrancy_key_position(StorageSlot(_slot)) + type_.set_reentrancy_key_position(VarOffset(_slot)) continue # Expect to find this variable within the storage layout override @@ -100,7 +154,7 @@ def set_storage_slots_with_overrides( # from using the same slot reserved_slots.reserve_slot_range(reentrant_slot, 1, variable_name) - type_.set_reentrancy_key_position(StorageSlot(reentrant_slot)) + type_.set_reentrancy_key_position(VarOffset(reentrant_slot)) ret[variable_name] = {"type": "nonreentrant lock", "slot": reentrant_slot} else: @@ -125,7 +179,7 @@ def set_storage_slots_with_overrides( # Ensure that all required storage slots are reserved, and prevents other variables # from using these slots reserved_slots.reserve_slot_range(var_slot, storage_length, node.target.id) - varinfo.set_position(StorageSlot(var_slot)) + varinfo.set_position(VarOffset(var_slot)) ret[node.target.id] = {"type": str(varinfo.typ), "slot": var_slot} else: @@ -138,105 +192,108 @@ def set_storage_slots_with_overrides( return ret -class SimpleStorageAllocator: - def __init__(self, starting_slot: int = 0): - self._slot = starting_slot +def _get_allocatable(vyper_module: vy_ast.Module) -> list[vy_ast.VyperNode]: + allocable = (vy_ast.InitializesDecl, vy_ast.VariableDecl) + return [node for node in vyper_module.body if isinstance(node, allocable)] - def allocate_slot(self, n, var_name): - ret = self._slot - if self._slot + n >= 2**256: - raise StorageLayoutException( - f"Invalid storage slot for var {var_name}, tried to allocate" - f" slots {self._slot} through {self._slot + n}" - ) - self._slot += n - return ret +def get_reentrancy_key_location() -> DataLocation: + if version_check(begin="cancun"): + return DataLocation.TRANSIENT + return DataLocation.STORAGE -def set_storage_slots(vyper_module: vy_ast.Module) -> StorageLayout: + +_LAYOUT_KEYS = { + DataLocation.CODE: "code_layout", + DataLocation.TRANSIENT: "transient_storage_layout", + DataLocation.STORAGE: "storage_layout", +} + + +def _allocate_layout_r( + vyper_module: vy_ast.Module, allocators: Allocators = None, immutables_only=False +) -> StorageLayout: """ Parse module-level Vyper AST to calculate the layout of storage variables. Returns the layout as a dict of variable name -> variable info """ - # Allocate storage slots from 0 - # note storage is word-addressable, not byte-addressable - allocator = SimpleStorageAllocator() + if allocators is None: + allocators = Allocators() - ret: Dict[str, Dict] = {} + ret: defaultdict[str, InsertableOnceDict[str, dict]] = defaultdict(InsertableOnceDict) for node in vyper_module.get_children(vy_ast.FunctionDef): + if immutables_only: + break + type_ = node._metadata["func_type"] if type_.nonreentrant is None: continue variable_name = f"nonreentrant.{type_.nonreentrant}" + reentrancy_key_location = get_reentrancy_key_location() + layout_key = _LAYOUT_KEYS[reentrancy_key_location] # a nonreentrant key can appear many times in a module but it # only takes one slot. after the first time we see it, do not # increment the storage slot. - if variable_name in ret: - _slot = ret[variable_name]["slot"] - type_.set_reentrancy_key_position(StorageSlot(_slot)) + if variable_name in ret[layout_key]: + _slot = ret[layout_key][variable_name]["slot"] + type_.set_reentrancy_key_position(VarOffset(_slot)) continue # TODO use one byte - or bit - per reentrancy key # requires either an extra SLOAD or caching the value of the # location in memory at entrance - slot = allocator.allocate_slot(1, variable_name) + allocator = allocators.get_allocator(reentrancy_key_location) + slot = allocator.allocate_slot(1, variable_name, node) - type_.set_reentrancy_key_position(StorageSlot(slot)) + type_.set_reentrancy_key_position(VarOffset(slot)) # TODO this could have better typing but leave it untyped until # we nail down the format better - ret[variable_name] = {"type": "nonreentrant lock", "slot": slot} - - for node in vyper_module.get_children(vy_ast.VariableDecl): - # skip non-storage variables - if node.is_constant or node.is_immutable: + ret[layout_key][variable_name] = {"type": "nonreentrant lock", "slot": slot} + + for node in _get_allocatable(vyper_module): + if isinstance(node, vy_ast.InitializesDecl): + module_info = node._metadata["initializes_info"].module_info + module_layout = _allocate_layout_r(module_info.module_node, allocators) + module_alias = module_info.alias + for layout_key in module_layout.keys(): + assert layout_key in _LAYOUT_KEYS.values() + ret[layout_key][module_alias] = module_layout[layout_key] continue + assert isinstance(node, vy_ast.VariableDecl) + # skip non-storage variables varinfo = node.target._metadata["varinfo"] - type_ = varinfo.typ - - # CMC 2021-07-23 note that HashMaps get assigned a slot here. - # I'm not sure if it's safe to avoid allocating that slot - # for HashMaps because downstream code might use the slot - # ID as a salt. - n_slots = type_.storage_size_in_words - slot = allocator.allocate_slot(n_slots, node.target.id) - - varinfo.set_position(StorageSlot(slot)) - - # this could have better typing but leave it untyped until - # we understand the use case better - ret[node.target.id] = {"type": str(type_), "slot": slot} - - return ret - - -def set_calldata_offsets(fn_node: vy_ast.FunctionDef) -> None: - pass + if not varinfo.is_module_variable(): + continue + location = varinfo.location + if immutables_only and location != DataLocation.CODE: + continue -def set_memory_offsets(fn_node: vy_ast.FunctionDef) -> None: - pass + allocator = allocators.get_allocator(location) + size = varinfo.get_size() + # CMC 2021-07-23 note that HashMaps get assigned a slot here + # using the same allocator (even though there is not really + # any risk of physical overlap) + offset = allocator.allocate_slot(size, node.target.id, node) -def set_code_offsets(vyper_module: vy_ast.Module) -> Dict: - ret = {} - offset = 0 + varinfo.set_position(VarOffset(offset)) - for node in vyper_module.get_children(vy_ast.VariableDecl, filters={"is_immutable": True}): - varinfo = node.target._metadata["varinfo"] + layout_key = _LAYOUT_KEYS[location] type_ = varinfo.typ - varinfo.set_position(CodeOffset(offset)) - - len_ = ceil32(type_.size_in_bytes) - # this could have better typing but leave it untyped until # we understand the use case better - ret[node.target.id] = {"type": str(type_), "offset": offset, "length": len_} - - offset += len_ + if location == DataLocation.CODE: + item = {"type": str(type_), "length": size, "offset": offset} + elif location in (DataLocation.STORAGE, DataLocation.TRANSIENT): + item = {"type": str(type_), "slot": offset} + else: # pragma: nocover + raise CompilerPanic("unreachable") + ret[layout_key][node.target.id] = item return ret diff --git a/vyper/semantics/analysis/global_.py b/vyper/semantics/analysis/global_.py new file mode 100644 index 0000000000..92cdf35c5d --- /dev/null +++ b/vyper/semantics/analysis/global_.py @@ -0,0 +1,80 @@ +from collections import defaultdict + +from vyper.exceptions import ExceptionList, InitializerException +from vyper.semantics.analysis.base import InitializesInfo, UsesInfo +from vyper.semantics.analysis.import_graph import ImportGraph +from vyper.semantics.analysis.module import validate_module_semantics_r +from vyper.semantics.types.module import ModuleT + + +def validate_semantics(module_ast, input_bundle, is_interface=False) -> ModuleT: + ret = validate_module_semantics_r(module_ast, input_bundle, ImportGraph(), is_interface) + + _validate_global_initializes_constraint(ret) + + return ret + + +def _collect_used_modules_r(module_t): + ret: defaultdict[ModuleT, list[UsesInfo]] = defaultdict(list) + + for uses_decl in module_t.uses_decls: + for used_module in uses_decl._metadata["uses_info"].used_modules: + ret[used_module.module_t].append(uses_decl) + + # recurse + used_modules = _collect_used_modules_r(used_module.module_t) + for k, v in used_modules.items(): + ret[k].extend(v) + + # also recurse into modules used by initialized modules + for i in module_t.initialized_modules: + used_modules = _collect_used_modules_r(i.module_info.module_t) + for k, v in used_modules.items(): + ret[k].extend(v) + + return ret + + +def _collect_initialized_modules_r(module_t, seen=None): + seen: dict[ModuleT, InitializesInfo] = seen or {} + + # list of InitializedInfo + initialized_infos = module_t.initialized_modules + + for i in initialized_infos: + initialized_module_t = i.module_info.module_t + if initialized_module_t in seen: + seen_nodes = (i.node, seen[initialized_module_t].node) + raise InitializerException(f"`{i.module_info.alias}` initialized twice!", *seen_nodes) + seen[initialized_module_t] = i + + _collect_initialized_modules_r(initialized_module_t, seen) + + return seen + + +# validate that each module which is `used` in the import graph is +# `initialized`. +def _validate_global_initializes_constraint(module_t: ModuleT): + all_used_modules = _collect_used_modules_r(module_t) + all_initialized_modules = _collect_initialized_modules_r(module_t) + + err_list = ExceptionList() + + for u, uses in all_used_modules.items(): + if u not in all_initialized_modules: + found_module = module_t.find_module_info(u) + if found_module is not None: + hint = f"add `initializes: {found_module.alias}` to the top level of " + hint += "your main contract" + else: + # CMC 2024-02-06 is this actually reachable? + hint = f"ensure `{module_t}` is imported in your main contract!" + err_list.append( + InitializerException( + f"module `{u}` is used but never initialized!", *uses, hint=hint + ) + ) + + err_list.raise_if_not_empty() diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 91cc0ebdf8..39a1c59290 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -1,8 +1,12 @@ +# CMC 2024-02-03 TODO: split me into function.py and expr.py + +import contextlib from typing import Optional from vyper import ast as vy_ast from vyper.ast.validation import validate_call_args from vyper.exceptions import ( + CallViolation, ExceptionList, FunctionDeclarationException, ImmutableViolation, @@ -16,7 +20,13 @@ VariableDeclarationException, VyperException, ) -from vyper.semantics.analysis.base import Modifiability, VarInfo +from vyper.semantics.analysis.base import ( + Modifiability, + ModuleInfo, + ModuleOwnership, + VarAccess, + VarInfo, +) from vyper.semantics.analysis.common import VyperNodeVisitorBase from vyper.semantics.analysis.utils import ( get_common_types, @@ -54,20 +64,34 @@ def validate_functions(vy_module: vy_ast.Module) -> None: """Analyzes a vyper ast and validates the function bodies""" - err_list = ExceptionList() - namespace = get_namespace() + for node in vy_module.get_children(vy_ast.FunctionDef): - with namespace.enter_scope(): - try: - analyzer = FunctionNodeVisitor(vy_module, node, namespace) - analyzer.analyze() - except VyperException as e: - err_list.append(e) + _validate_function_r(vy_module, node, err_list) err_list.raise_if_not_empty() +def _validate_function_r( + vy_module: vy_ast.Module, node: vy_ast.FunctionDef, err_list: ExceptionList +): + func_t = node._metadata["func_type"] + + for call_t in func_t.called_functions: + if isinstance(call_t, ContractFunctionT): + assert isinstance(call_t.ast_def, vy_ast.FunctionDef) # help mypy + _validate_function_r(vy_module, call_t.ast_def, err_list) + + namespace = get_namespace() + + try: + with namespace.enter_scope(): + analyzer = FunctionAnalyzer(vy_module, node, namespace) + analyzer.analyze() + except VyperException as e: + err_list.append(e) + + # finds the terminus node for a list of nodes. # raises an exception if any nodes are unreachable def find_terminating_node(node_list: list) -> Optional[vy_ast.VyperNode]: @@ -97,36 +121,6 @@ def find_terminating_node(node_list: list) -> Optional[vy_ast.VyperNode]: return ret -def _check_iterator_modification( - target_node: vy_ast.VyperNode, search_node: vy_ast.VyperNode -) -> Optional[vy_ast.VyperNode]: - similar_nodes = [ - n - for n in search_node.get_descendants(type(target_node)) - if vy_ast.compare_nodes(target_node, n) - ] - - for node in similar_nodes: - # raise if the node is the target of an assignment statement - assign_node = node.get_ancestor((vy_ast.Assign, vy_ast.AugAssign)) - # note the use of get_descendants() blocks statements like - # self.my_array[i] = x - if assign_node and node in assign_node.target.get_descendants(include_self=True): - return node - - attr_node = node.get_ancestor(vy_ast.Attribute) - # note the use of get_descendants() blocks statements like - # self.my_array[i].append(x) - if ( - attr_node is not None - and node in attr_node.value.get_descendants(include_self=True) - and attr_node.attr in ("append", "pop", "extend") - ): - return node - - return None - - # helpers def _validate_address_code(node: vy_ast.Attribute, value_type: VyperType) -> None: if isinstance(value_type, AddressT) and node.attr == "code": @@ -181,7 +175,63 @@ def _validate_self_reference(node: vy_ast.Name) -> None: raise StateAccessViolation("not allowed to query self in pure functions", node) -class FunctionNodeVisitor(VyperNodeVisitorBase): +# analyse the variable access for the attribute chain for a node +# e.x. `x` will return varinfo for `x` +# `module.foo` will return VarAccess for `module.foo` +# `self.my_struct.x.y` will return VarAccess for `self.my_struct.x.y` +def _get_variable_access(node: vy_ast.ExprNode) -> Optional[VarAccess]: + attrs: list[str] = [] + info = get_expr_info(node) + + while info.var_info is None: + if not isinstance(node, (vy_ast.Subscript, vy_ast.Attribute)): + # it's something like a literal + return None + + if isinstance(node, vy_ast.Subscript): + # Subscript is an analysis barrier + # we cannot analyse if `x.y[ix1].z` overlaps with `x.y[ix2].z`. + attrs.clear() + + if (attr := info.attr) is not None: + attrs.append(attr) + + assert isinstance(node, (vy_ast.Subscript, vy_ast.Attribute)) # help mypy + node = node.value + info = get_expr_info(node) + + # ignore `self.` as it interferes with VarAccess comparison across modules + if len(attrs) > 0 and attrs[-1] == "self": + attrs.pop() + attrs.reverse() + + return VarAccess(info.var_info, tuple(attrs)) + + +# get the chain of modules, e.g. +# mod1.mod2.x.y -> [ModuleInfo(mod1), ModuleInfo(mod2)] +# CMC 2024-02-12 note that the Attribute/Subscript traversal in this and +# _get_variable_access() are a bit gross and could probably +# be refactored into data on ExprInfo. +def _get_module_chain(node: vy_ast.ExprNode) -> list[ModuleInfo]: + ret: list[ModuleInfo] = [] + info = get_expr_info(node) + + while True: + if info.module_info is not None: + ret.append(info.module_info) + + if not isinstance(node, (vy_ast.Subscript, vy_ast.Attribute)): + break + + node = node.value + info = get_expr_info(node) + + ret.reverse() + return ret + + +class FunctionAnalyzer(VyperNodeVisitorBase): ignored_types = (vy_ast.Pass,) scope_name = "function" @@ -192,9 +242,18 @@ def __init__( self.fn_node = fn_node self.namespace = namespace self.func = fn_node._metadata["func_type"] - self.expr_visitor = ExprVisitor(self.func) + self.expr_visitor = ExprVisitor(self) + + self.loop_variables: list[Optional[VarAccess]] = [] def analyze(self): + if self.func.analysed: + return + + # mark seen before analysing, if analysis throws an exception which + # gets caught, we don't want to analyse again. + self.func.mark_analysed() + # allow internal function params to be mutable if self.func.is_internal: location, modifiability = (DataLocation.MEMORY, Modifiability.MODIFIABLE) @@ -223,6 +282,14 @@ def analyze(self): for kwarg in self.func.keyword_args: self.expr_visitor.visit(kwarg.default_value, kwarg.typ) + @contextlib.contextmanager + def enter_for_loop(self, varaccess: Optional[VarAccess]): + self.loop_variables.append(varaccess) + try: + yield + finally: + self.loop_variables.pop() + def visit(self, node): super().visit(node) @@ -270,21 +337,91 @@ def _assign_helper(self, node): raise StructureException("Right-hand side of assignment cannot be a tuple", node.value) target = get_expr_info(node.target) - if isinstance(target.typ, HashMapT): - raise StructureException( - "Left-hand side of assignment cannot be a HashMap without a key", node - ) - target.validate_modification(node, self.func.mutability) + # check mutability of the function + self._handle_modification(node.target) self.expr_visitor.visit(node.value, target.typ) self.expr_visitor.visit(node.target, target.typ) + def _handle_modification(self, target: vy_ast.ExprNode): + if isinstance(target, vy_ast.Tuple): + for item in target.elements: + self._handle_modification(item) + return + + # check a modification of `target`. validate the modification is + # valid, and log the modification in relevant data structures. + func_t = self.func + info = get_expr_info(target) + + if isinstance(info.typ, HashMapT): + raise StructureException( + "Left-hand side of assignment cannot be a HashMap without a key" + ) + + if ( + info.location in (DataLocation.STORAGE, DataLocation.TRANSIENT) + and func_t.mutability <= StateMutability.VIEW + ): + raise StateAccessViolation( + f"Cannot modify {info.location} variable in a {func_t.mutability} function" + ) + + if info.location == DataLocation.CALLDATA: + raise ImmutableViolation("Cannot write to calldata") + + if info.modifiability == Modifiability.RUNTIME_CONSTANT: + if info.location == DataLocation.CODE: + if not func_t.is_constructor: + raise ImmutableViolation("Immutable value cannot be written to") + + # handle immutables + if info.var_info is not None: # don't handle complex (struct,array) immutables + # special handling for immutable variables in the ctor + # TODO: maybe we want to remove this restriction. + if info.var_info._modification_count != 0: + raise ImmutableViolation( + "Immutable value cannot be modified after assignment" + ) + info.var_info._modification_count += 1 + else: + raise ImmutableViolation("Environment variable cannot be written to") + + if info.modifiability == Modifiability.CONSTANT: + raise ImmutableViolation("Constant value cannot be written to.") + + var_access = _get_variable_access(target) + assert var_access is not None + + info._writes.add(var_access) + + def _check_module_use(self, target: vy_ast.ExprNode): + module_infos = _get_module_chain(target) + + if len(module_infos) == 0: + return + + for module_info in module_infos: + if module_info.ownership < ModuleOwnership.USES: + msg = f"Cannot access `{module_info.alias}` state!" + hint = f"add `uses: {module_info.alias}` or " + hint += f"`initializes: {module_info.alias}` as " + hint += "a top-level statement to your contract" + raise ImmutableViolation(msg, hint=hint) + + # the leftmost- referenced module + root_module_info = module_infos[0] + + # log the access + self.func.mark_used_module(root_module_info) + def visit_Assign(self, node): self._assign_helper(node) def visit_AugAssign(self, node): self._assign_helper(node) + node.target._expr_info.typ.validate_numeric_op(node) def visit_Break(self, node): for_node = node.get_ancestor(vy_ast.For) @@ -309,35 +446,13 @@ def visit_Expr(self, node): raise StructureException("Expressions without assignment are disallowed", node) fn_type = get_exact_type_from_node(node.value.func) + if is_type_t(fn_type, EventT): raise StructureException("To call an event you must use the `log` statement", node) if is_type_t(fn_type, StructT): raise StructureException("Struct creation without assignment is disallowed", node) - if isinstance(fn_type, ContractFunctionT): - if ( - fn_type.mutability > StateMutability.VIEW - and self.func.mutability <= StateMutability.VIEW - ): - raise StateAccessViolation( - f"Cannot call a mutating function from a {self.func.mutability.value} function", - node, - ) - - if ( - self.func.mutability == StateMutability.PURE - and fn_type.mutability != StateMutability.PURE - ): - raise StateAccessViolation( - "Cannot call non-pure function from a pure function", node - ) - - if isinstance(fn_type, MemberFunctionT) and fn_type.is_modifying: - # it's a dotted function call like dynarray.pop() - expr_info = get_expr_info(node.value.func.value) - expr_info.validate_modification(node, self.func.mutability) - # NOTE: fetch_call_return validates call args. return_value = map_void(fn_type.fetch_call_return(node.value)) if ( @@ -350,96 +465,68 @@ def visit_Expr(self, node): ) self.expr_visitor.visit(node.value, return_value) + def _analyse_range_iter(self, iter_node, target_type): + # iteration via range() + if iter_node.get("func.id") != "range": + raise IteratorException("Cannot iterate over the result of a function call", iter_node) + _validate_range_call(iter_node) + + args = iter_node.args + kwargs = [s.value for s in iter_node.keywords] + for arg in (*args, *kwargs): + self.expr_visitor.visit(arg, target_type) + + def _analyse_list_iter(self, iter_node, target_type): + # iteration over a variable or literal list + iter_val = iter_node + if iter_val.has_folded_value: + iter_val = iter_val.get_folded_value() + + if isinstance(iter_val, vy_ast.List): + len_ = len(iter_val.elements) + if len_ == 0: + raise StructureException("For loop must have at least 1 iteration", iter_node) + iter_type = SArrayT(target_type, len_) + else: + try: + iter_type = get_exact_type_from_node(iter_node) + except (InvalidType, StructureException): + raise InvalidType("Not an iterable type", iter_node) + + # CMC 2024-02-09 TODO: use validate_expected_type once we have DArrays + # with generic length. + if not isinstance(iter_type, (DArrayT, SArrayT)): + raise InvalidType("Not an iterable type", iter_node) + + self.expr_visitor.visit(iter_node, iter_type) + + # get the root varinfo from iter_val in case we need to peer + # through folded constants + return _get_variable_access(iter_val) + def visit_For(self, node): if not isinstance(node.target.target, vy_ast.Name): raise StructureException("Invalid syntax for loop iterator", node.target.target) target_type = type_from_annotation(node.target.annotation, DataLocation.MEMORY) + iter_var = None if isinstance(node.iter, vy_ast.Call): - # iteration via range() - if node.iter.get("func.id") != "range": - raise IteratorException( - "Cannot iterate over the result of a function call", node.iter - ) - _validate_range_call(node.iter) - + self._analyse_range_iter(node.iter, target_type) else: - # iteration over a variable or literal list - iter_val = node.iter.get_folded_value() if node.iter.has_folded_value else node.iter - if isinstance(iter_val, vy_ast.List) and len(iter_val.elements) == 0: - raise StructureException("For loop must have at least 1 iteration", node.iter) - - if not any( - isinstance(i, (DArrayT, SArrayT)) for i in get_possible_types_from_node(node.iter) - ): - raise InvalidType("Not an iterable type", node.iter) - - if isinstance(node.iter, (vy_ast.Name, vy_ast.Attribute)): - # check for references to the iterated value within the body of the loop - assign = _check_iterator_modification(node.iter, node) - if assign: - raise ImmutableViolation("Cannot modify array during iteration", assign) - - # Check if `iter` is a storage variable. get_descendants` is used to check for - # nested `self` (e.g. structs) - # NOTE: this analysis will be borked once stateful modules are allowed! - iter_is_storage_var = ( - isinstance(node.iter, vy_ast.Attribute) - and len(node.iter.get_descendants(vy_ast.Name, {"id": "self"})) > 0 - ) + iter_var = self._analyse_list_iter(node.iter, target_type) - if iter_is_storage_var: - # check if iterated value may be modified by function calls inside the loop - iter_name = node.iter.attr - for call_node in node.get_descendants(vy_ast.Call, {"func.value.id": "self"}): - fn_name = call_node.func.attr - - fn_node = self.vyper_module.get_children(vy_ast.FunctionDef, {"name": fn_name})[0] - if _check_iterator_modification(node.iter, fn_node): - # check for direct modification - raise ImmutableViolation( - f"Cannot call '{fn_name}' inside for loop, it potentially " - f"modifies iterated storage variable '{iter_name}'", - call_node, - ) - - for reachable_t in ( - self.namespace["self"].typ.members[fn_name].reachable_internal_functions - ): - # check for indirect modification - name = reachable_t.name - fn_node = self.vyper_module.get_children(vy_ast.FunctionDef, {"name": name})[0] - if _check_iterator_modification(node.iter, fn_node): - raise ImmutableViolation( - f"Cannot call '{fn_name}' inside for loop, it may call to '{name}' " - f"which potentially modifies iterated storage variable '{iter_name}'", - call_node, - ) - - target_name = node.target.target.id - with self.namespace.enter_scope(): + with self.namespace.enter_scope(), self.enter_for_loop(iter_var): + target_name = node.target.target.id + # maybe we should introduce a new Modifiability: LOOP_VARIABLE self.namespace[target_name] = VarInfo( target_type, modifiability=Modifiability.RUNTIME_CONSTANT ) + self.expr_visitor.visit(node.target.target, target_type) for stmt in node.body: self.visit(stmt) - self.expr_visitor.visit(node.target.target, target_type) - - if isinstance(node.iter, vy_ast.List): - len_ = len(node.iter.elements) - self.expr_visitor.visit(node.iter, SArrayT(target_type, len_)) - elif isinstance(node.iter, vy_ast.Call) and node.iter.func.id == "range": - args = node.iter.args - kwargs = [s.value for s in node.iter.keywords] - for arg in (*args, *kwargs): - self.expr_visitor.visit(arg, target_type) - else: - iter_type = get_exact_type_from_node(node.iter) - self.expr_visitor.visit(node.iter, iter_type) - def visit_If(self, node): self.expr_visitor.visit(node.test, BoolT()) with self.namespace.enter_scope(): @@ -457,7 +544,7 @@ def visit_Log(self, node): raise StructureException("Value is not an event", node.value) if self.func.mutability <= StateMutability.VIEW: raise StructureException( - f"Cannot emit logs from {self.func.mutability.value.lower()} functions", node + f"Cannot emit logs from {self.func.mutability} functions", node ) t = map_void(f.fetch_call_return(node.value)) # CMC 2024-02-05 annotate the event type for codegen usage @@ -493,10 +580,20 @@ def visit_Return(self, node): class ExprVisitor(VyperNodeVisitorBase): - scope_name = "function" + def __init__(self, function_analyzer: Optional[FunctionAnalyzer] = None): + self.function_analyzer = function_analyzer + + @property + def func(self): + if self.function_analyzer is None: + return None + return self.function_analyzer.func - def __init__(self, fn_node: Optional[ContractFunctionT] = None): - self.func = fn_node + @property + def scope_name(self): + if self.func is not None: + return "function" + return "module" def visit(self, node, typ): if typ is not VOID_TYPE and not isinstance(typ, TYPE_T): @@ -509,6 +606,38 @@ def visit(self, node, typ): # annotate node._metadata["type"] = typ + if not isinstance(typ, TYPE_T): + info = get_expr_info(node) # get_expr_info fills in node._expr_info + + # log variable accesses. + # (note writes will get logged as both read+write) + var_access = _get_variable_access(node) + if var_access is not None: + info._reads.add(var_access) + + if self.function_analyzer: + for s in self.function_analyzer.loop_variables: + if s is None: + continue + + for v in info._writes: + if not v.contains(s): + continue + + msg = "Cannot modify loop variable" + var = s.variable + if var.decl_node is not None: + msg += f" `{var.decl_node.target.id}`" + raise ImmutableViolation(msg, var.decl_node, node) + + variable_accesses = info._writes | info._reads + for s in variable_accesses: + if s.variable.is_module_variable(): + self.function_analyzer._check_module_use(node) + + self.func.mark_variable_writes(info._writes) + self.func.mark_variable_reads(info._reads) + # validate and annotate folded value if node.has_folded_value: folded_node = node.get_folded_value() @@ -547,45 +676,82 @@ def visit_BoolOp(self, node: vy_ast.BoolOp, typ: VyperType) -> None: for value in node.values: self.visit(value, BoolT()) + def _check_call_mutability(self, call_mutability: StateMutability): + # note: payable can be called from nonpayable functions + ok = ( + call_mutability <= self.func.mutability + or self.func.mutability >= StateMutability.NONPAYABLE + ) + if not ok: + msg = f"Cannot call a {call_mutability} function from a {self.func.mutability} function" + raise StateAccessViolation(msg) + def visit_Call(self, node: vy_ast.Call, typ: VyperType) -> None: - call_type = get_exact_type_from_node(node.func) - self.visit(node.func, call_type) + func_info = get_expr_info(node.func, is_callable=True) + func_type = func_info.typ - if isinstance(call_type, ContractFunctionT): + if isinstance(func_type, ContractFunctionT): # function calls - if self.func and call_type.is_internal: - self.func.called_functions.add(call_type) - for arg, typ in zip(node.args, call_type.argument_types): + + if not func_type.from_interface: + for s in func_type.get_variable_writes(): + if s.variable.is_module_variable(): + func_info._writes.add(s) + for s in func_type.get_variable_reads(): + if s.variable.is_module_variable(): + func_info._reads.add(s) + + if self.function_analyzer: + self._check_call_mutability(func_type.mutability) + + for s in func_type.get_variable_accesses(): + if s.variable.is_module_variable(): + self.function_analyzer._check_module_use(node.func) + + if func_type.is_deploy and not self.func.is_deploy: + raise CallViolation( + f"Cannot call an @{func_type.visibility} function from " + f"an @{self.func.visibility} function!", + node, + ) + + for arg, typ in zip(node.args, func_type.argument_types): self.visit(arg, typ) for kwarg in node.keywords: # We should only see special kwargs - typ = call_type.call_site_kwargs[kwarg.arg].typ + typ = func_type.call_site_kwargs[kwarg.arg].typ self.visit(kwarg.value, typ) - elif is_type_t(call_type, EventT): + elif is_type_t(func_type, EventT): # events have no kwargs - expected_types = call_type.typedef.arguments.values() + expected_types = func_type.typedef.arguments.values() # type: ignore for arg, typ in zip(node.args, expected_types): self.visit(arg, typ) - elif is_type_t(call_type, StructT): + elif is_type_t(func_type, StructT): # struct ctors # ctors have no kwargs - expected_types = call_type.typedef.members.values() + expected_types = func_type.typedef.members.values() # type: ignore for value, arg_type in zip(node.args[0].values, expected_types): self.visit(value, arg_type) - elif isinstance(call_type, MemberFunctionT): - assert len(node.args) == len(call_type.arg_types) - for arg, arg_type in zip(node.args, call_type.arg_types): + elif isinstance(func_type, MemberFunctionT): + if func_type.is_modifying and self.function_analyzer is not None: + # TODO refactor this + assert isinstance(node.func, vy_ast.Attribute) # help mypy + self.function_analyzer._handle_modification(node.func.value) + assert len(node.args) == len(func_type.arg_types) + for arg, arg_type in zip(node.args, func_type.arg_types): self.visit(arg, arg_type) else: # builtin functions - arg_types = call_type.infer_arg_types(node, expected_return_typ=typ) + arg_types = func_type.infer_arg_types(node, expected_return_typ=typ) # type: ignore for arg, arg_type in zip(node.args, arg_types): self.visit(arg, arg_type) - kwarg_types = call_type.infer_kwarg_types(node) + kwarg_types = func_type.infer_kwarg_types(node) # type: ignore for kwarg in node.keywords: self.visit(kwarg.value, kwarg_types[kwarg.arg]) + self.visit(node.func, func_type) + def visit_Compare(self, node: vy_ast.Compare, typ: VyperType) -> None: if isinstance(node.op, (vy_ast.In, vy_ast.NotIn)): # membership in list literal - `x in [a, b, c]` @@ -638,8 +804,10 @@ def visit_List(self, node: vy_ast.List, typ: VyperType) -> None: self.visit(element, typ.value_type) def visit_Name(self, node: vy_ast.Name, typ: VyperType) -> None: - if self.func and self.func.mutability == StateMutability.PURE: - _validate_self_reference(node) + if self.func: + # TODO: refactor to use expr_info mutability + if self.func.mutability == StateMutability.PURE: + _validate_self_reference(node) def visit_Subscript(self, node: vy_ast.Subscript, typ: VyperType) -> None: if isinstance(typ, TYPE_T): diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index a83c2f3b7d..a7d8300083 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -8,38 +8,50 @@ from vyper.compiler.input_bundle import ABIInput, FileInput, FilesystemInputBundle, InputBundle from vyper.evm.opcodes import version_check from vyper.exceptions import ( + BorrowException, CallViolation, DuplicateImport, ExceptionList, + ImmutableViolation, + InitializerException, InvalidLiteral, InvalidType, ModuleNotFound, NamespaceCollision, StateAccessViolation, StructureException, - SyntaxException, + UndeclaredDefinition, VariableDeclarationException, VyperException, ) -from vyper.semantics.analysis.base import ImportInfo, Modifiability, ModuleInfo, VarInfo +from vyper.semantics.analysis.base import ( + ImportInfo, + InitializesInfo, + Modifiability, + ModuleInfo, + ModuleOwnership, + UsesInfo, + VarInfo, +) from vyper.semantics.analysis.common import VyperNodeVisitorBase from vyper.semantics.analysis.constant_folding import constant_fold from vyper.semantics.analysis.import_graph import ImportGraph from vyper.semantics.analysis.local import ExprVisitor, validate_functions -from vyper.semantics.analysis.utils import check_modifiability, get_exact_type_from_node +from vyper.semantics.analysis.utils import ( + check_modifiability, + get_exact_type_from_node, + get_expr_info, +) from vyper.semantics.data_locations import DataLocation from vyper.semantics.namespace import Namespace, get_namespace, override_global_namespace from vyper.semantics.types import EventT, FlagT, InterfaceT, StructT from vyper.semantics.types.function import ContractFunctionT from vyper.semantics.types.module import ModuleT from vyper.semantics.types.utils import type_from_annotation +from vyper.utils import OrderedSet -def validate_semantics(module_ast, input_bundle, is_interface=False) -> ModuleT: - return validate_semantics_r(module_ast, input_bundle, ImportGraph(), is_interface) - - -def validate_semantics_r( +def validate_module_semantics_r( module_ast: vy_ast.Module, input_bundle: InputBundle, import_graph: ImportGraph, @@ -49,6 +61,11 @@ def validate_semantics_r( Analyze a Vyper module AST node, add all module-level objects to the namespace, type-check/validate semantics and annotate with type and analysis info """ + if "type" in module_ast._metadata: + # we don't need to analyse again, skip out + assert isinstance(module_ast._metadata["type"], ModuleT) + return module_ast._metadata["type"] + validate_literal_nodes(module_ast) # validate semantics and annotate AST with type/semantics information @@ -64,6 +81,8 @@ def validate_semantics_r( # in `ContractFunction.from_vyi()` if not is_interface: validate_functions(module_ast) + analyzer.validate_initialized_modules() + analyzer.validate_used_modules() return ret @@ -121,11 +140,8 @@ def __init__( def analyze(self) -> ModuleT: # generate a `ModuleT` from the top-level node # note: also validates unique method ids - if "type" in self.ast._metadata: - assert isinstance(self.ast._metadata["type"], ModuleT) - # we don't need to analyse again, skip out - self.module_t = self.ast._metadata["type"] - return self.module_t + + assert "type" not in self.ast._metadata to_visit = self.ast.body.copy() @@ -138,6 +154,11 @@ def analyze(self) -> ModuleT: self.visit(node) to_visit.remove(node) + ownership_decls = self.ast.get_children((vy_ast.UsesDecl, vy_ast.InitializesDecl)) + for node in ownership_decls: + self.visit(node) + to_visit.remove(node) + # we can resolve constants after imports are handled. constant_fold(self.ast) @@ -179,6 +200,7 @@ def analyze(self) -> ModuleT: def analyze_call_graph(self): # get list of internal function calls made by each function + # CMC 2024-02-03 note: this could be cleaner in analysis/local.py function_defs = self.module_t.function_defs for func in function_defs: @@ -195,7 +217,9 @@ def analyze_call_graph(self): # we just want to be able to construct the call graph. continue - if isinstance(call_t, ContractFunctionT) and call_t.is_internal: + if isinstance(call_t, ContractFunctionT) and ( + call_t.is_internal or call_t.is_constructor + ): fn_t.called_functions.add(call_t) for func in function_defs: @@ -204,6 +228,106 @@ def analyze_call_graph(self): # compute reachable set and validate the call graph _compute_reachable_set(fn_t) + def validate_used_modules(self): + # check all `uses:` modules are actually used + should_use = {} + + module_t = self.ast._metadata["type"] + uses_decls = module_t.uses_decls + for decl in uses_decls: + info = decl._metadata["uses_info"] + for m in info.used_modules: + should_use[m.module_t] = (m, info) + + initialized_modules = {t.module_info.module_t: t for t in module_t.initialized_modules} + + all_used_modules = OrderedSet() + + for f in module_t.functions.values(): + for u in f.get_used_modules(): + all_used_modules.add(u.module_t) + + for used_module in all_used_modules: + if used_module in initialized_modules: + continue + + if used_module in should_use: + del should_use[used_module] + + if len(should_use) > 0: + err_list = ExceptionList() + for used_module_info, uses_info in should_use.values(): + msg = f"`{used_module_info.alias}` is declared as used, but " + msg += f"it is not actually used in {module_t}!" + hint = f"delete `uses: {used_module_info.alias}`" + err_list.append(BorrowException(msg, uses_info.node, hint=hint)) + + err_list.raise_if_not_empty() + + def validate_initialized_modules(self): + # check all `initializes:` modules have `__init__()` called exactly once + module_t = self.ast._metadata["type"] + should_initialize = {t.module_info.module_t: t for t in module_t.initialized_modules} + # don't call `__init__()` for modules which don't have + # `__init__()` function + for m in should_initialize.copy(): + for f in m.functions.values(): + if f.is_constructor: + break + else: + del should_initialize[m] + + init_calls = [] + for f in self.ast.get_children(vy_ast.FunctionDef): + if f._metadata["func_type"].is_constructor: + init_calls = f.get_descendants(vy_ast.Call) + break + + seen_initializers = {} + for call_node in init_calls: + expr_info = call_node.func._expr_info + if expr_info is None: + # this can happen for range() calls; CMC 2024-02-05 try to + # refactor so that range() is properly tagged. + continue + + call_t = call_node.func._expr_info.typ + + if not isinstance(call_t, ContractFunctionT): + continue + + if not call_t.is_constructor: + continue + + # XXX: check this works as expected for nested attributes + initialized_module = call_node.func.value._expr_info.module_info + + if initialized_module.module_t in seen_initializers: + seen_location = seen_initializers[initialized_module.module_t] + msg = f"tried to initialize `{initialized_module.alias}`, " + msg += "but its __init__() function was already called!" + raise InitializerException(msg, call_node.func, seen_location) + + if initialized_module.module_t not in should_initialize: + msg = f"tried to initialize `{initialized_module.alias}`, " + msg += "but it is not in initializer list!" + hint = f"add `initializes: {initialized_module.alias}` " + hint += "as a top-level statement to your contract" + raise InitializerException(msg, call_node.func, hint=hint) + + del should_initialize[initialized_module.module_t] + seen_initializers[initialized_module.module_t] = call_node.func + + if len(should_initialize) > 0: + err_list = ExceptionList() + for s in should_initialize.values(): + msg = "not initialized!" + hint = f"add `{s.module_info.alias}.__init__()` to " + hint += "your `__init__()` function" + err_list.append(InitializerException(msg, s.node, hint=hint)) + + err_list.raise_if_not_empty() + def _ast_from_file(self, file: FileInput) -> vy_ast.Module: # cache ast if we have seen it before. # this gives us the additional property of object equality on @@ -218,10 +342,100 @@ def visit_ImplementsDecl(self, node): type_ = type_from_annotation(node.annotation) if not isinstance(type_, InterfaceT): - raise StructureException("Invalid interface name", node.annotation) + raise StructureException("not an interface!", node.annotation) type_.validate_implements(node) + def visit_UsesDecl(self, node): + # TODO: check duplicate uses declarations, e.g. + # uses: x + # ... + # uses: x + items = vy_ast.as_tuple(node.annotation) + + used_modules = [] + + for item in items: + module_info = get_expr_info(item).module_info + if module_info is None: + raise StructureException("not a valid module!", item) + + # note: try to refactor - not a huge fan of mutating the + # ModuleInfo after it's constructed + module_info.set_ownership(ModuleOwnership.USES, item) + + used_modules.append(module_info) + + node._metadata["uses_info"] = UsesInfo(used_modules, node) + + def visit_InitializesDecl(self, node): + module_ref = node.annotation + dependencies_ast = () + if isinstance(module_ref, vy_ast.Subscript): + dependencies_ast = vy_ast.as_tuple(module_ref.slice) + module_ref = module_ref.value + + # postcondition of InitializesDecl.validates() + assert isinstance(module_ref, (vy_ast.Name, vy_ast.Attribute)) + + module_info = get_expr_info(module_ref).module_info + if module_info is None: + raise StructureException("Not a module!", module_ref) + + used_modules = {i.module_t: i for i in module_info.module_t.used_modules} + + dependencies = [] + for named_expr in dependencies_ast: + assert isinstance(named_expr, vy_ast.NamedExpr) + + rhs_module = get_expr_info(named_expr.value).module_info + + with module_info.module_node.namespace(): + # lhs of the named_expr is evaluated in the namespace of the + # initialized module! + try: + lhs_module = get_expr_info(named_expr.target).module_info + except VyperException as e: + # try to report a common problem - user names the module in + # the current namespace instead of the initialized module + # namespace. + + # search for the module in the initialized module + found_module = module_info.module_t.find_module_info(rhs_module.module_t) + if found_module is not None: + msg = f"unknown module `{named_expr.target.id}`" + hint = f"did you mean `{found_module.alias} := {rhs_module.alias}`?" + raise UndeclaredDefinition(msg, named_expr.target, hint=hint) + + raise e from None + + if lhs_module.module_t != rhs_module.module_t: + raise StructureException( + f"{lhs_module.alias} is not {rhs_module.alias}!", named_expr + ) + dependencies.append(lhs_module) + + if lhs_module.module_t not in used_modules: + raise InitializerException( + f"`{module_info.alias}` is initialized with `{lhs_module.alias}`, " + f"but `{module_info.alias}` does not use `{lhs_module.alias}`!", + named_expr, + ) + + del used_modules[lhs_module.module_t] + + if len(used_modules) > 0: + item = next(iter(used_modules.values())) # just pick one + msg = f"`{module_info.alias}` uses `{item.alias}`, but it is not " + msg += f"initialized with `{item.alias}`" + hint = f"add `{item.alias}` to its initializer list" + raise InitializerException(msg, node, hint=hint) + + # note: try to refactor. not a huge fan of mutating the + # ModuleInfo after it's constructed + module_info.set_ownership(ModuleOwnership.INITIALIZES, node) + node._metadata["initializes_info"] = InitializesInfo(module_info, dependencies, node) + def visit_VariableDecl(self, node): name = node.get("target.id") if name is None: @@ -250,7 +464,7 @@ def visit_VariableDecl(self, node): if len(wrong_self_attribute) > 0 else "Immutable definition requires an assignment in the constructor" ) - raise SyntaxException(message, node.node_source_code, node.lineno, node.col_offset) + raise ImmutableViolation(message, node) data_loc = ( DataLocation.CODE @@ -364,11 +578,10 @@ def visit_Import(self, node): # don't handle things like `import x.y` if "." in alias: + msg = "import requires an accompanying `as` statement" suggested_alias = node.name[node.name.rfind(".") :] - suggestion = f"hint: try `import {node.name} as {suggested_alias}`" - raise StructureException( - f"import requires an accompanying `as` statement ({suggestion})", node - ) + hint = f"try `import {node.name} as {suggested_alias}`" + raise StructureException(msg, node, hint=hint) self._add_import(node, 0, node.name, alias) @@ -436,14 +649,14 @@ def _load_import_helper( module_ast = self._ast_from_file(file) with override_global_namespace(Namespace()): - module_t = validate_semantics_r( + module_t = validate_module_semantics_r( module_ast, self.input_bundle, import_graph=self._import_graph, is_interface=False, ) - return ModuleInfo(module_t) + return ModuleInfo(module_t, alias) except FileNotFoundError as e: # escape `e` from the block scope, it can make things @@ -456,7 +669,7 @@ def _load_import_helper( module_ast = self._ast_from_file(file) with override_global_namespace(Namespace()): - validate_semantics_r( + validate_module_semantics_r( module_ast, self.input_bundle, import_graph=self._import_graph, @@ -481,7 +694,7 @@ def _load_import_helper( raise ModuleNotFound(module_str, node) from err -def _parse_and_fold_ast(file: FileInput) -> vy_ast.VyperNode: +def _parse_and_fold_ast(file: FileInput) -> vy_ast.Module: ret = vy_ast.parse_to_ast( file.source_code, source_id=file.source_id, @@ -542,5 +755,7 @@ def _load_builtin_import(level: int, module_str: str) -> InterfaceT: interface_ast = _parse_and_fold_ast(file) with override_global_namespace(Namespace()): - module_t = validate_semantics(interface_ast, input_bundle, is_interface=True) + module_t = validate_module_semantics_r( + interface_ast, input_bundle, ImportGraph(), is_interface=True + ) return module_t.interface diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index abbf6a68cc..034cd8c46e 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -61,8 +61,8 @@ class _ExprAnalyser: def __init__(self): self.namespace = get_namespace() - def get_expr_info(self, node: vy_ast.VyperNode) -> ExprInfo: - t = self.get_exact_type_from_node(node) + def get_expr_info(self, node: vy_ast.VyperNode, is_callable: bool = False) -> ExprInfo: + t = self.get_exact_type_from_node(node, include_type_exprs=is_callable) # if it's a Name, we have varinfo for it if isinstance(node, vy_ast.Name): @@ -74,33 +74,29 @@ def get_expr_info(self, node: vy_ast.VyperNode) -> ExprInfo: if isinstance(info, ModuleInfo): return ExprInfo.from_moduleinfo(info) - raise CompilerPanic("unreachable!", node) + if isinstance(info, VyperType): + return ExprInfo(TYPE_T(info)) + + raise CompilerPanic(f"unreachable! {info}", node) if isinstance(node, vy_ast.Attribute): # if it's an Attr, we check the parent exprinfo and # propagate the parent exprinfo members down into the new expr # note: Attribute(expr value, identifier attr) - name = node.attr - info = self.get_expr_info(node.value) + info = self.get_expr_info(node.value, is_callable=is_callable) + attr = node.attr - t = info.typ.get_member(name, node) + t = info.typ.get_member(attr, node) # it's a top-level variable if isinstance(t, VarInfo): - return ExprInfo.from_varinfo(t) + return ExprInfo.from_varinfo(t, attr=attr) - # it's something else, like my_struct.foo - return info.copy_with_type(t) + if isinstance(t, ModuleInfo): + return ExprInfo.from_moduleinfo(t, attr=attr) - if isinstance(node, vy_ast.Tuple): - # always use the most restrictive location re: modification - # kludge! for validate_modification in local analysis of Assign - types = [self.get_expr_info(n) for n in node.elements] - location = sorted((i.location for i in types), key=lambda k: k.value)[-1] - modifiability = sorted((i.modifiability for i in types), key=lambda k: k.value)[-1] - - return ExprInfo(t, location=location, modifiability=modifiability) + return info.copy_with_type(t, attr=attr) # If it's a Subscript, propagate the subscriptable varinfo if isinstance(node, vy_ast.Subscript): @@ -184,6 +180,7 @@ def _find_fn(self, node): def types_from_Attribute(self, node): is_self_reference = node.get("value.id") == "self" + # variable attribute, e.g. `foo.bar` t = self.get_exact_type_from_node(node.value, include_type_exprs=True) name = node.attr @@ -476,8 +473,10 @@ def get_exact_type_from_node(node): return _ExprAnalyser().get_exact_type_from_node(node, include_type_exprs=True) -def get_expr_info(node: vy_ast.VyperNode) -> ExprInfo: - return _ExprAnalyser().get_expr_info(node) +def get_expr_info(node: vy_ast.ExprNode, is_callable: bool = False) -> ExprInfo: + if node._expr_info is None: + node._expr_info = _ExprAnalyser().get_expr_info(node, is_callable) + return node._expr_info def get_common_types(*nodes: vy_ast.VyperNode, filter_fn: Callable = None) -> List: @@ -639,7 +638,7 @@ def validate_unique_method_ids(functions: List) -> None: seen.add(method_id) -def check_modifiability(node: vy_ast.VyperNode, modifiability: Modifiability) -> bool: +def check_modifiability(node: vy_ast.ExprNode, modifiability: Modifiability) -> bool: """ Check if the given node is not more modifiable than the given modifiability. """ @@ -665,5 +664,5 @@ def check_modifiability(node: vy_ast.VyperNode, modifiability: Modifiability) -> if hasattr(call_type, "check_modifiability_for_call"): return call_type.check_modifiability_for_call(node, modifiability) - value_type = get_expr_info(node) - return value_type.modifiability >= modifiability + info = get_expr_info(node) + return info.modifiability <= modifiability diff --git a/vyper/semantics/data_locations.py b/vyper/semantics/data_locations.py index cecea35a60..06245aa90d 100644 --- a/vyper/semantics/data_locations.py +++ b/vyper/semantics/data_locations.py @@ -1,10 +1,12 @@ import enum +from vyper.utils import StringEnum -class DataLocation(enum.Enum): - UNSET = 0 - MEMORY = 1 - STORAGE = 2 - CALLDATA = 3 - CODE = 4 - TRANSIENT = 5 + +class DataLocation(StringEnum): + UNSET = enum.auto() + MEMORY = enum.auto() + STORAGE = enum.auto() + CALLDATA = enum.auto() + CODE = enum.auto() + TRANSIENT = enum.auto() diff --git a/vyper/semantics/environment.py b/vyper/semantics/environment.py index 38bac0a63d..94a26157af 100644 --- a/vyper/semantics/environment.py +++ b/vyper/semantics/environment.py @@ -1,7 +1,7 @@ from typing import Dict from vyper.semantics.analysis.base import Modifiability, VarInfo -from vyper.semantics.types import AddressT, BytesT, VyperType +from vyper.semantics.types import AddressT, BytesT, SelfT, VyperType from vyper.semantics.types.shortcuts import BYTES32_T, UINT256_T @@ -57,7 +57,7 @@ def get_constant_vars() -> Dict: return result -MUTABLE_ENVIRONMENT_VARS: Dict[str, type] = {"self": AddressT} +MUTABLE_ENVIRONMENT_VARS: Dict[str, type] = {"self": SelfT} def get_mutable_vars() -> Dict: diff --git a/vyper/semantics/types/__init__.py b/vyper/semantics/types/__init__.py index a04632b96f..59a20dd99f 100644 --- a/vyper/semantics/types/__init__.py +++ b/vyper/semantics/types/__init__.py @@ -3,7 +3,7 @@ from .bytestrings import BytesT, StringT, _BytestringT from .function import MemberFunctionT from .module import InterfaceT -from .primitives import AddressT, BoolT, BytesM_T, DecimalT, IntegerT +from .primitives import AddressT, BoolT, BytesM_T, DecimalT, IntegerT, SelfT from .subscriptable import DArrayT, HashMapT, SArrayT, TupleT from .user import EventT, FlagT, StructT diff --git a/vyper/semantics/types/base.py b/vyper/semantics/types/base.py index d659276ee0..c5e10b52be 100644 --- a/vyper/semantics/types/base.py +++ b/vyper/semantics/types/base.py @@ -13,6 +13,7 @@ UnknownAttribute, ) from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions +from vyper.semantics.data_locations import DataLocation # Some fake type with an overridden `compare_type` which accepts any RHS @@ -25,7 +26,11 @@ def __init__(self, type_): self.type_ = type_ def compare_type(self, other): - return isinstance(other, self.type_) or self == other + if isinstance(other, self.type_): + return True + # compare two GenericTypeAcceptors -- they are the same if the base + # type is the same + return isinstance(other, self.__class__) and other.type_ == self.type_ class VyperType: @@ -91,6 +96,8 @@ def __hash__(self): return hash(self._get_equality_attrs()) def __eq__(self, other): + if self is other: + return True return ( type(self) is type(other) and self._get_equality_attrs() == other._get_equality_attrs() ) @@ -118,6 +125,16 @@ def abi_type(self) -> ABIType: """ raise CompilerPanic("Method must be implemented by the inherited class") + def get_size_in(self, location: DataLocation): + if location in (DataLocation.STORAGE, DataLocation.TRANSIENT): + return self.storage_size_in_words + if location == DataLocation.MEMORY: + return self.memory_bytes_required + if location == DataLocation.CODE: + return self.memory_bytes_required + + raise CompilerPanic("unreachable: invalid location {location}") # pragma: nocover + @property def memory_bytes_required(self) -> int: # alias for API compatibility with codegen @@ -341,8 +358,10 @@ def map_void(typ: Optional[VyperType]) -> VyperType: # A type type. Used internally for types which can live in expression # position, ex. constructors (events, interfaces and structs), and also # certain builtins which take types as parameters -class TYPE_T: +class TYPE_T(VyperType): def __init__(self, typedef): + super().__init__() + self.typedef = typedef def __repr__(self): diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index 2d92370b9d..705470a798 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -19,8 +19,10 @@ from vyper.semantics.analysis.base import ( FunctionVisibility, Modifiability, + ModuleInfo, StateMutability, - StorageSlot, + VarAccess, + VarOffset, ) from vyper.semantics.analysis.utils import ( check_modifiability, @@ -90,6 +92,7 @@ def __init__( return_type: Optional[VyperType], function_visibility: FunctionVisibility, state_mutability: StateMutability, + from_interface: bool = False, nonreentrant: Optional[str] = None, ast_def: Optional[vy_ast.VyperNode] = None, ) -> None: @@ -102,9 +105,12 @@ def __init__( self.visibility = function_visibility self.mutability = state_mutability self.nonreentrant = nonreentrant + self.from_interface = from_interface self.ast_def = ast_def + self._analysed = False + # a list of internal functions this function calls. # to be populated during analysis self.called_functions: OrderedSet[ContractFunctionT] = OrderedSet() @@ -112,10 +118,52 @@ def __init__( # recursively reachable from this function self.reachable_internal_functions: OrderedSet[ContractFunctionT] = OrderedSet() + # writes to variables from this function + self._variable_writes: OrderedSet[VarAccess] = OrderedSet() + + # reads of variables from this function + self._variable_reads: OrderedSet[VarAccess] = OrderedSet() + + # list of modules used (accessed state) by this function + self._used_modules: OrderedSet[ModuleInfo] = OrderedSet() + # to be populated during codegen self._ir_info: Any = None self._function_id: Optional[int] = None + def mark_analysed(self): + assert not self._analysed + self._analysed = True + + @property + def analysed(self): + return self._analysed + + def get_variable_reads(self): + return self._variable_reads + + def get_variable_writes(self): + return self._variable_writes + + def get_variable_accesses(self): + return self._variable_reads | self._variable_writes + + def get_used_modules(self): + return self._used_modules + + def mark_used_module(self, module_info): + self._used_modules.add(module_info) + + def mark_variable_writes(self, var_infos): + self._variable_writes.update(var_infos) + + def mark_variable_reads(self, var_infos): + self._variable_reads.update(var_infos) + + @property + def modifiability(self): + return Modifiability.from_state_mutability(self.mutability) + @cached_property def call_site_kwargs(self): # special kwargs that are allowed in call site @@ -170,6 +218,7 @@ def from_abi(cls, abi: dict) -> "ContractFunctionT": positional_args, [], return_type, + from_interface=True, function_visibility=FunctionVisibility.EXTERNAL, state_mutability=StateMutability.from_abi(abi), ) @@ -229,6 +278,7 @@ def from_InterfaceDef(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": return_type, function_visibility, state_mutability, + from_interface=True, nonreentrant=None, ast_def=funcdef, ) @@ -269,9 +319,11 @@ def from_vyi(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": if len(funcdef.body) != 1 or not isinstance(funcdef.body[0].get("value"), vy_ast.Ellipsis): raise FunctionDeclarationException( - "function body in an interface can only be ...!", funcdef + "function body in an interface can only be `...`!", funcdef ) + assert function_visibility is not None # mypy hint + return cls( funcdef.name, positional_args, @@ -279,6 +331,7 @@ def from_vyi(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": return_type, function_visibility, state_mutability, + from_interface=True, nonreentrant=nonreentrant_key, ast_def=funcdef, ) @@ -314,13 +367,19 @@ def from_FunctionDef(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": "Default function may not receive any arguments", funcdef.args.args[0] ) + if function_visibility == FunctionVisibility.DEPLOY and funcdef.name != "__init__": + raise FunctionDeclarationException( + "Only constructors can be marked as `@deploy`!", funcdef + ) if funcdef.name == "__init__": - if ( - state_mutability in (StateMutability.PURE, StateMutability.VIEW) - or function_visibility == FunctionVisibility.INTERNAL - ): + if state_mutability in (StateMutability.PURE, StateMutability.VIEW): raise FunctionDeclarationException( - "Constructor cannot be marked as `@pure`, `@view` or `@internal`", funcdef + "Constructor cannot be marked as `@pure` or `@view`", funcdef + ) + if function_visibility != FunctionVisibility.DEPLOY: + raise FunctionDeclarationException( + f"Constructor must be marked as `@deploy`, not `@{function_visibility}`", + funcdef, ) if return_type is not None: raise FunctionDeclarationException( @@ -333,6 +392,9 @@ def from_FunctionDef(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": "Constructor may not use default arguments", funcdef.args.defaults[0] ) + # sanity check + assert function_visibility is not None + return cls( funcdef.name, positional_args, @@ -340,18 +402,16 @@ def from_FunctionDef(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": return_type, function_visibility, state_mutability, + from_interface=False, nonreentrant=nonreentrant_key, ast_def=funcdef, ) - def set_reentrancy_key_position(self, position: StorageSlot) -> None: + def set_reentrancy_key_position(self, position: VarOffset) -> None: if hasattr(self, "reentrancy_key_position"): raise CompilerPanic("Position was already assigned") if self.nonreentrant is None: raise CompilerPanic(f"No reentrant key {self}") - # sanity check even though implied by the type - if position._location != DataLocation.STORAGE: - raise CompilerPanic("Non-storage reentrant key") self.reentrancy_key_position = position @classmethod @@ -383,6 +443,7 @@ def getter_from_VariableDecl(cls, node: vy_ast.VariableDecl) -> "ContractFunctio args, [], return_type, + from_interface=False, function_visibility=FunctionVisibility.EXTERNAL, state_mutability=StateMutability.VIEW, ast_def=node, @@ -456,6 +517,14 @@ def is_external(self) -> bool: def is_internal(self) -> bool: return self.visibility == FunctionVisibility.INTERNAL + @property + def is_deploy(self) -> bool: + return self.visibility == FunctionVisibility.DEPLOY + + @property + def is_constructor(self) -> bool: + return self.name == "__init__" + @property def is_mutable(self) -> bool: return self.mutability > StateMutability.VIEW @@ -464,10 +533,6 @@ def is_mutable(self) -> bool: def is_payable(self) -> bool: return self.mutability == StateMutability.PAYABLE - @property - def is_constructor(self) -> bool: - return self.name == "__init__" - @property def is_fallback(self) -> bool: return self.name == "__default__" @@ -535,20 +600,14 @@ def fetch_call_return(self, node: vy_ast.Call) -> Optional[VyperType]: modified_line = re.sub( kwarg_pattern, kwarg.value.node_source_code, node.node_source_code ) - error_suggestion = ( - f"\n(hint: Try removing the kwarg: `{modified_line}`)" - if modified_line != node.node_source_code - else "" - ) - raise ArgumentException( - ( - "Usage of kwarg in Vyper is restricted to " - + ", ".join([f"{k}=" for k in self.call_site_kwargs.keys()]) - + f". {error_suggestion}" - ), - kwarg, - ) + msg = "Usage of kwarg in Vyper is restricted to " + msg += ", ".join([f"{k}=" for k in self.call_site_kwargs.keys()]) + + hint = None + if modified_line != node.node_source_code: + hint = f"Try removing the kwarg: `{modified_line}`" + raise ArgumentException(msg, kwarg, hint=hint) return self.return_type @@ -601,7 +660,7 @@ def _parse_return_type(funcdef: vy_ast.FunctionDef) -> Optional[VyperType]: def _parse_decorators( funcdef: vy_ast.FunctionDef, -) -> tuple[FunctionVisibility, StateMutability, Optional[str]]: +) -> tuple[Optional[FunctionVisibility], StateMutability, Optional[str]]: function_visibility = None state_mutability = None nonreentrant_key = None @@ -632,7 +691,9 @@ def _parse_decorators( if FunctionVisibility.is_valid_value(decorator.id): if function_visibility is not None: raise FunctionDeclarationException( - f"Visibility is already set to: {function_visibility}", funcdef + f"Visibility is already set to: {function_visibility}", + decorator, + hint="only one visibility decorator is allowed per function", ) function_visibility = FunctionVisibility(decorator.id) @@ -748,6 +809,10 @@ def __init__( self.return_type = return_type self.is_modifying = is_modifying + @property + def modifiability(self): + return Modifiability.MODIFIABLE if self.is_modifying else Modifiability.RUNTIME_CONSTANT + def __repr__(self): return f"{self.underlying_type._id} member function '{self.name}'" diff --git a/vyper/semantics/types/module.py b/vyper/semantics/types/module.py index ee1da22a87..86840f4f91 100644 --- a/vyper/semantics/types/module.py +++ b/vyper/semantics/types/module.py @@ -1,5 +1,5 @@ from functools import cached_property -from typing import Optional +from typing import TYPE_CHECKING, Optional from vyper import ast as vy_ast from vyper.abi_types import ABI_Address, ABIType @@ -16,12 +16,16 @@ validate_expected_type, validate_unique_method_ids, ) +from vyper.semantics.data_locations import DataLocation from vyper.semantics.namespace import get_namespace from vyper.semantics.types.base import TYPE_T, VyperType from vyper.semantics.types.function import ContractFunctionT from vyper.semantics.types.primitives import AddressT from vyper.semantics.types.user import EventT, StructT, _UserType +if TYPE_CHECKING: + from vyper.semantics.analysis.base import ModuleInfo + class InterfaceT(_UserType): _type_members = {"address": AddressT()} @@ -234,7 +238,7 @@ def from_ModuleT(cls, module_t: "ModuleT") -> "InterfaceT": for node in module_t.function_defs: func_t = node._metadata["func_type"] - if not func_t.is_external: + if not (func_t.is_external or func_t.is_constructor): continue funcs.append((node.name, func_t)) @@ -276,6 +280,12 @@ def from_InterfaceDef(cls, node: vy_ast.InterfaceDef) -> "InterfaceT": # Datatype to store all module information. class ModuleT(VyperType): _attribute_in_annotation = True + _invalid_locations = ( + DataLocation.CALLDATA, + DataLocation.CODE, + DataLocation.MEMORY, + DataLocation.TRANSIENT, + ) def __init__(self, module: vy_ast.Module, name: Optional[str] = None): super().__init__() @@ -307,7 +317,6 @@ def __init__(self, module: vy_ast.Module, name: Optional[str] = None): for i in self.interface_defs: # add the type of the interface so it can be used in call position self.add_member(i.name, TYPE_T(i._metadata["interface_type"])) # type: ignore - self._helper.add_member(i.name, TYPE_T(i._metadata["interface_type"])) # type: ignore for v in self.variable_decls: self.add_member(v.target.id, v.target._metadata["varinfo"]) @@ -316,6 +325,13 @@ def __init__(self, module: vy_ast.Module, name: Optional[str] = None): import_info = i._metadata["import_info"] self.add_member(import_info.alias, import_info.typ) + if hasattr(import_info.typ, "module_t"): + self._helper.add_member(import_info.alias, TYPE_T(import_info.typ)) + + for name, interface_t in self.interfaces.items(): + # can access interfaces in type position + self._helper.add_member(name, TYPE_T(interface_t)) + # __eq__ is very strict on ModuleT - object equality! this is because we # don't want to reason about where a module came from (i.e. input bundle, # search path, symlinked vs normalized path, etc.) @@ -345,27 +361,97 @@ def struct_defs(self): def interface_defs(self): return self._module.get_children(vy_ast.InterfaceDef) + @cached_property + def interfaces(self) -> dict[str, InterfaceT]: + ret = {} + for i in self.interface_defs: + assert i.name not in ret # precondition + ret[i.name] = i._metadata["interface_type"] + + for i in self.import_stmts: + import_info = i._metadata["import_info"] + if isinstance(import_info.typ, InterfaceT): + assert import_info.alias not in ret # precondition + ret[import_info.alias] = import_info.typ + + return ret + @property def import_stmts(self): return self._module.get_children((vy_ast.Import, vy_ast.ImportFrom)) + @cached_property + def imported_modules(self) -> dict[str, "ModuleInfo"]: + ret = {} + for s in self.import_stmts: + info = s._metadata["import_info"] + module_info = info.typ + if isinstance(module_info, InterfaceT): + continue + ret[info.alias] = module_info + return ret + + def find_module_info(self, needle: "ModuleT") -> Optional["ModuleInfo"]: + for s in self.imported_modules.values(): + if s.module_t == needle: + return s + return None + @property def variable_decls(self): return self._module.get_children(vy_ast.VariableDecl) + @property + def uses_decls(self): + return self._module.get_children(vy_ast.UsesDecl) + + @property + def initializes_decls(self): + return self._module.get_children(vy_ast.InitializesDecl) + + @cached_property + def used_modules(self): + # modules which are written to + ret = [] + for node in self.uses_decls: + for used_module in node._metadata["uses_info"].used_modules: + ret.append(used_module) + return ret + + @property + def initialized_modules(self): + # modules which are initialized to + ret = [] + for node in self.initializes_decls: + info = node._metadata["initializes_info"] + ret.append(info) + return ret + @cached_property def variables(self): # variables that this module defines, ex. # `x: uint256` is a private storage variable named x return {s.target.id: s.target._metadata["varinfo"] for s in self.variable_decls} + @cached_property + def functions(self): + return {f.name: f._metadata["func_type"] for f in self.function_defs} + @cached_property def immutables(self): return [t for t in self.variables.values() if t.is_immutable] @cached_property def immutable_section_bytes(self): - return sum([imm.typ.memory_bytes_required for imm in self.immutables]) + ret = 0 + for s in self.immutables: + ret += s.typ.memory_bytes_required + + for initializes_info in self.initialized_modules: + module_t = initializes_info.module_info.module_t + ret += module_t.immutable_section_bytes + + return ret @cached_property def interface(self): diff --git a/vyper/semantics/types/primitives.py b/vyper/semantics/types/primitives.py index 07d1a21a94..d383f72ab2 100644 --- a/vyper/semantics/types/primitives.py +++ b/vyper/semantics/types/primitives.py @@ -340,3 +340,13 @@ def validate_literal(self, node: vy_ast.Constant) -> None: f"address, the correct checksummed form is: {checksum_encode(addr)}", node, ) + + +# type for "self" +# refactoring note: it might be best for this to be a ModuleT actually +class SelfT(AddressT): + _id = "self" + + def compare_type(self, other): + # compares true to AddressT + return isinstance(other, type(self)) or isinstance(self, type(other)) diff --git a/vyper/semantics/types/utils.py b/vyper/semantics/types/utils.py index 5564570536..c6a4531df8 100644 --- a/vyper/semantics/types/utils.py +++ b/vyper/semantics/types/utils.py @@ -117,16 +117,16 @@ def _type_from_annotation(node: vy_ast.VyperNode) -> VyperType: if isinstance(node, vy_ast.Attribute): # ex. SomeModule.SomeStruct - # sanity check - we only allow modules/interfaces to be - # imported as `Name`s currently. - if not isinstance(node.value, vy_ast.Name): + if isinstance(node.value, vy_ast.Attribute): + module_or_interface = _type_from_annotation(node.value) + elif isinstance(node.value, vy_ast.Name): + try: + module_or_interface = namespace[node.value.id] # type: ignore + except UndeclaredDefinition: + raise InvalidType(err_msg, node) from None + else: raise InvalidType(err_msg, node) - try: - module_or_interface = namespace[node.value.id] # type: ignore - except UndeclaredDefinition: - raise InvalidType(err_msg, node) from None - if hasattr(module_or_interface, "module_t"): # i.e., it's a ModuleInfo module_or_interface = module_or_interface.module_t diff --git a/vyper/utils.py b/vyper/utils.py index 2349731b97..b2284eaba0 100644 --- a/vyper/utils.py +++ b/vyper/utils.py @@ -1,6 +1,7 @@ import binascii import contextlib import decimal +import enum import functools import sys import time @@ -8,7 +9,7 @@ import warnings from typing import Generic, List, TypeVar, Union -from vyper.exceptions import DecimalOverrideException, InvalidLiteral +from vyper.exceptions import CompilerPanic, DecimalOverrideException, InvalidLiteral _T = TypeVar("_T") @@ -62,6 +63,59 @@ def copy(self): return self.__class__(super().copy()) +class StringEnum(enum.Enum): + # Must be first, or else won't work, specifies what .value is + def _generate_next_value_(name, start, count, last_values): + return name.lower() + + # Override ValueError with our own internal exception + @classmethod + def _missing_(cls, value): + raise CompilerPanic(f"{value} is not a valid {cls.__name__}") + + @classmethod + def is_valid_value(cls, value: str) -> bool: + return value in set(o.value for o in cls) + + @classmethod + def options(cls) -> List["StringEnum"]: + return list(cls) + + @classmethod + def values(cls) -> List[str]: + return [v.value for v in cls.options()] + + # Comparison operations + def __eq__(self, other: object) -> bool: + if not isinstance(other, self.__class__): + raise CompilerPanic(f"bad comparison: ({type(other)}, {type(self)})") + return self is other + + # Python normally does __ne__(other) ==> not self.__eq__(other) + + def __lt__(self, other: object) -> bool: + if not isinstance(other, self.__class__): + raise CompilerPanic(f"bad comparison: ({type(other)}, {type(self)})") + options = self.__class__.options() + return options.index(self) < options.index(other) # type: ignore + + def __le__(self, other: object) -> bool: + return self.__eq__(other) or self.__lt__(other) + + def __gt__(self, other: object) -> bool: + return not self.__le__(other) + + def __ge__(self, other: object) -> bool: + return not self.__lt__(other) + + def __str__(self) -> str: + return self.value + + def __hash__(self) -> int: + # let `dataclass` know that this class is not mutable + return super().__hash__() + + class DecimalContextOverride(decimal.Context): def __setattr__(self, name, value): if name == "prec":