Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix type of storage variables containing interface #2697

Merged
merged 8 commits into from
Mar 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions examples/factory/Exchange.vy
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,20 @@ interface Factory:
def register(): nonpayable


token: public(address)
factory: address
token: public(ERC20)
factory: Factory


@external
def __init__(_token: address, _factory: address):
def __init__(_token: ERC20, _factory: Factory):
self.token = _token
self.factory = _factory


@external
def initialize():
# Anyone can safely call this function because of EXTCODEHASH
Factory(self.factory).register()
self.factory.register()


# NOTE: This contract restricts trading to only be done by the factory.
Expand All @@ -28,13 +28,13 @@ def initialize():

@external
def receive(_from: address, _amt: uint256):
assert msg.sender == self.factory
success: bool = ERC20(self.token).transferFrom(_from, self, _amt)
assert msg.sender == self.factory.address
success: bool = self.token.transferFrom(_from, self, _amt)
assert success


@external
def transfer(_to: address, _amt: uint256):
assert msg.sender == self.factory
success: bool = ERC20(self.token).transfer(_to, _amt)
assert msg.sender == self.factory.address
success: bool = self.token.transfer(_to, _amt)
assert success
15 changes: 9 additions & 6 deletions examples/factory/Factory.vy
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from vyper.interfaces import ERC20

interface Exchange:
def token() -> address: view
def token() -> ERC20: view
def receive(_from: address, _amt: uint256): nonpayable
def transfer(_to: address, _amt: uint256): nonpayable


exchange_codehash: public(bytes32)
# Maps token addresses to exchange addresses
exchanges: public(HashMap[address, address])
exchanges: public(HashMap[ERC20, Exchange])


@external
Expand All @@ -30,12 +32,13 @@ def register():
# NOTE: Use exchange's token address because it should be globally unique
# NOTE: Should do checks that it hasn't already been set,
# which has to be rectified with any upgrade strategy.
self.exchanges[Exchange(msg.sender).token()] = msg.sender
exchange: Exchange = Exchange(msg.sender)
self.exchanges[exchange.token()] = exchange


@external
def trade(_token1: address, _token2: address, _amt: uint256):
def trade(_token1: ERC20, _token2: ERC20, _amt: uint256):
# Perform a straight exchange of token1 to token 2 (1:1 price)
# NOTE: Any practical implementation would need to solve the price oracle problem
Exchange(self.exchanges[_token1]).receive(msg.sender, _amt)
Exchange(self.exchanges[_token2]).transfer(msg.sender, _amt)
self.exchanges[_token1].receive(msg.sender, _amt)
self.exchanges[_token2].transfer(msg.sender, _amt)
17 changes: 17 additions & 0 deletions tests/parser/syntax/test_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,23 @@ def test() -> address:
def test():
b: address = self.a.address
""",
"""
interface MyInterface:
def some_func(): nonpayable

my_interface: MyInterface[3]
idx: uint256

@external
def __init__():
self.my_interface[self.idx] = MyInterface(ZERO_ADDRESS)
""",
"""
interface MyInterface:
def kick(): payable

kickers: HashMap[address, MyInterface]
""",
]


Expand Down
5 changes: 0 additions & 5 deletions vyper/ast/signatures/function_signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,6 @@ def size(self):
return math.ceil(self.typ.memory_bytes_required / 32)


class ContractRecord(VariableRecord):
def __init__(self, *args):
super(ContractRecord, self).__init__(*args)


@dataclass
class FunctionArg:
name: str
Expand Down
4 changes: 2 additions & 2 deletions vyper/codegen/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1081,9 +1081,9 @@ def struct_literals(expr, name, context):
def parse_value_expr(cls, expr, context):
return unwrap_location(cls(expr, context).lll_node)

# Parse an expression that represents an address in memory/calldata or storage.
# Parse an expression that represents a pointer to memory/calldata or storage.
@classmethod
def parse_variable_location(cls, expr, context):
def parse_pointer_expr(cls, expr, context):
o = cls(expr, context).lll_node
if not o.location:
raise StructureException("Looking for a variable location, instead got a value", expr)
Expand Down
47 changes: 6 additions & 41 deletions vyper/codegen/external_call.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import vyper.utils as util
from vyper import ast as vy_ast
from vyper.codegen.abi_encoder import abi_encode
from vyper.codegen.core import (
calculate_type_for_external_return,
Expand All @@ -8,11 +7,10 @@
dummy_node_for_type,
get_element_ptr,
getpos,
unwrap_location,
)
from vyper.codegen.lll_node import Encoding, LLLnode
from vyper.codegen.types import TupleType, get_type_for_exact_size
from vyper.exceptions import StateAccessViolation, StructureException, TypeCheckFailure
from vyper.codegen.types import InterfaceType, TupleType, get_type_for_exact_size
from vyper.exceptions import StateAccessViolation, TypeCheckFailure


def _pack_arguments(contract_sig, args, context, pos):
Expand Down Expand Up @@ -206,46 +204,13 @@ def lll_for_external_call(stmt_expr, context):
from vyper.codegen.expr import Expr # TODO rethink this circular import

pos = getpos(stmt_expr)

contract_address = Expr.parse_value_expr(stmt_expr.func.value, context)
value, gas, skip_contract_check = _get_special_kwargs(stmt_expr, context)
args_lll = [Expr(x, context).lll_node for x in stmt_expr.args]

if isinstance(stmt_expr.func, vy_ast.Attribute) and isinstance(
stmt_expr.func.value, vy_ast.Call
):
# e.g. `Foo(address).bar()`

# sanity check
assert len(stmt_expr.func.value.args) == 1
contract_name = stmt_expr.func.value.func.id
contract_address = Expr.parse_value_expr(stmt_expr.func.value.args[0], context)

elif (
isinstance(stmt_expr.func.value, vy_ast.Attribute)
and stmt_expr.func.value.attr in context.globals
# TODO check for self?
and hasattr(context.globals[stmt_expr.func.value.attr].typ, "name")
):
# e.g. `self.foo.bar()`

# sanity check
assert stmt_expr.func.value.value.id == "self", stmt_expr

contract_name = context.globals[stmt_expr.func.value.attr].typ.name
type_ = stmt_expr.func.value._metadata["type"]
var = context.globals[stmt_expr.func.value.attr]
contract_address = unwrap_location(
LLLnode.from_list(
type_.position.position,
typ=var.typ,
location="storage",
pos=pos,
annotation="self." + stmt_expr.func.value.attr,
)
)
else:
# TODO catch this during type checking
raise StructureException("Unsupported operator.", stmt_expr)

assert isinstance(contract_address.typ, InterfaceType)
contract_name = contract_address.typ.name
method_name = stmt_expr.func.attr
contract_sig = context.sigs[contract_name][method_name]

Expand Down
65 changes: 14 additions & 51 deletions vyper/codegen/global_context.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import Optional

from vyper import ast as vy_ast
from vyper.ast.signatures.function_signature import ContractRecord, VariableRecord
from vyper.codegen.types import InterfaceType, parse_type
from vyper.ast.signatures.function_signature import VariableRecord
from vyper.codegen.types import parse_type
from vyper.exceptions import CompilerPanic, InvalidType, StructureException
from vyper.typing import InterfaceImports
from vyper.utils import cached_property
Expand Down Expand Up @@ -133,32 +133,6 @@ def make_contract(node: "vy_ast.InterfaceDef") -> list:
raise StructureException("Invalid contract reference", item)
return _defs

def get_item_name_and_attributes(self, item, attributes):
is_map_invocation = (
isinstance(item, vy_ast.Call) and isinstance(item.func, vy_ast.Name)
) and item.func.id == "HashMap"

if isinstance(item, vy_ast.Name):
return item.id, attributes
elif isinstance(item, vy_ast.AnnAssign):
return self.get_item_name_and_attributes(item.annotation, attributes)
elif isinstance(item, vy_ast.Subscript):
return self.get_item_name_and_attributes(item.value, attributes)
elif is_map_invocation:
if len(item.args) != 2:
raise StructureException(
"Map type expects two type arguments HashMap[type1, type2]", item.func
)
return self.get_item_name_and_attributes(item.args, attributes)
# elif ist
elif isinstance(item, vy_ast.Call) and isinstance(item.func, vy_ast.Name):
attributes[item.func.id] = True
# Raise for multiple args
if len(item.args) != 1:
raise StructureException(f"{item.func.id} expects one arg (the type)")
return self.get_item_name_and_attributes(item.args[0], attributes)
return None, attributes

@staticmethod
def get_call_func_name(item):
if isinstance(item.annotation, vy_ast.Call) and isinstance(
Expand All @@ -167,7 +141,6 @@ def get_call_func_name(item):
return item.annotation.func.id

def add_globals_and_events(self, item):
item_attributes = {"public": False}

if self._nonrentrant_counter:
raise CompilerPanic("Re-entrancy lock was set before all storage slots were defined")
Expand All @@ -180,29 +153,11 @@ def add_globals_and_events(self, item):
if self.get_call_func_name(item) == "constant":
return

item_name, item_attributes = self.get_item_name_and_attributes(item, item_attributes)

# references to `len(self._globals)` are remnants of deprecated code, retained
# to preserve existing interfaces while we complete a larger refactor. location
# and size of storage vars is handled in `vyper.context.validation.data_positions`
if item_name in self._contracts or item_name in self._interfaces:
if self.get_call_func_name(item) == "address":
raise StructureException(
f"Persistent address({item_name}) style contract declarations "
"are not support anymore."
f" Use {item.target.id}: {item_name} instead"
)
self._globals[item.target.id] = ContractRecord(
item.target.id,
len(self._globals),
InterfaceType(item_name),
True,
)
elif self.get_call_func_name(item) == "public":
if isinstance(item.annotation.args[0], vy_ast.Name) and item_name in self._contracts:
typ = InterfaceType(item_name)
else:
typ = self.parse_type(item.annotation.args[0])
if self.get_call_func_name(item) == "public":
typ = self.parse_type(item.annotation.args[0])
self._globals[item.target.id] = VariableRecord(
item.target.id,
len(self._globals),
Expand All @@ -220,19 +175,27 @@ def add_globals_and_events(self, item):
)

elif isinstance(item.annotation, (vy_ast.Name, vy_ast.Call, vy_ast.Subscript)):
typ = self.parse_type(item.annotation)
self._globals[item.target.id] = VariableRecord(
item.target.id,
len(self._globals),
self.parse_type(item.annotation),
typ,
True,
)
else:
raise InvalidType("Invalid global type specified", item)

@property
def interface_names(self):
"""
The set of names which are known to possibly be InterfaceType
"""
return set(self._contracts.keys()) | set(self._interfaces.keys())

def parse_type(self, ast_node):
return parse_type(
ast_node,
sigs=self._contracts,
sigs=self.interface_names,
custom_structs=self._structs,
)

Expand Down
1 change: 1 addition & 0 deletions vyper/codegen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def parse_external_interfaces(external_interfaces, global_ctx):
# Recognizes already-defined structs
sig = FunctionSignature.from_definition(
_def,
sigs=global_ctx.interface_names,
interface_def=True,
constant_override=constant,
custom_structs=global_ctx._structs,
Expand Down
2 changes: 1 addition & 1 deletion vyper/codegen/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,7 @@ def _get_target(self, target):
raise TypeCheckFailure(f"Failed constancy check\n{_dbg_expr}")
return target

target = Expr.parse_variable_location(target, self.context)
target = Expr.parse_pointer_expr(target, self.context)
if (target.location == "storage" and self.context.is_constant()) or not target.mutable:
raise TypeCheckFailure(f"Failed constancy check\n{_dbg_expr}")
return target
Expand Down
2 changes: 2 additions & 0 deletions vyper/codegen/types/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,8 @@ def make_struct_type(name, sigs, members, custom_structs):
# Parses an expression representing a type.
# TODO: rename me to "lll_type_from_annotation"
def parse_type(item, sigs, custom_structs):
# sigs: set of interface or contract names in scope
# custom_structs: struct definitions in scope
def _sanity_check(x):
assert x, "typechecker missed this"

Expand Down