Skip to content

Commit

Permalink
fix type of storage variables containing interface (#2697)
Browse files Browse the repository at this point in the history
`global_context.py` had some old helper function called               
`get_item_name_and_attributes` which behaved incorrectly for complex  
types containing interfaces (specifically, an array of interfaces like
`MyInterface[3]`). Luckily, `parse_type` works better now and the code
seems to work after removing `get_item_name_and_attributes`. We also had    
a ContractRecord class which is now unused, so I removed it.          

Also now allow interfaces in interface signatures, and was able to
simplify some code in `lll_for_external_call`
  • Loading branch information
charles-cooper authored Mar 12, 2022
1 parent aa8affb commit d4b1e3f
Show file tree
Hide file tree
Showing 10 changed files with 60 additions and 114 deletions.
16 changes: 8 additions & 8 deletions examples/factory/Exchange.vy
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,20 @@ interface Factory:
def register(): nonpayable


token: public(address)
factory: address
token: public(ERC20)
factory: Factory


@external
def __init__(_token: address, _factory: address):
def __init__(_token: ERC20, _factory: Factory):
self.token = _token
self.factory = _factory


@external
def initialize():
# Anyone can safely call this function because of EXTCODEHASH
Factory(self.factory).register()
self.factory.register()


# NOTE: This contract restricts trading to only be done by the factory.
Expand All @@ -28,13 +28,13 @@ def initialize():

@external
def receive(_from: address, _amt: uint256):
assert msg.sender == self.factory
success: bool = ERC20(self.token).transferFrom(_from, self, _amt)
assert msg.sender == self.factory.address
success: bool = self.token.transferFrom(_from, self, _amt)
assert success


@external
def transfer(_to: address, _amt: uint256):
assert msg.sender == self.factory
success: bool = ERC20(self.token).transfer(_to, _amt)
assert msg.sender == self.factory.address
success: bool = self.token.transfer(_to, _amt)
assert success
15 changes: 9 additions & 6 deletions examples/factory/Factory.vy
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from vyper.interfaces import ERC20

interface Exchange:
def token() -> address: view
def token() -> ERC20: view
def receive(_from: address, _amt: uint256): nonpayable
def transfer(_to: address, _amt: uint256): nonpayable


exchange_codehash: public(bytes32)
# Maps token addresses to exchange addresses
exchanges: public(HashMap[address, address])
exchanges: public(HashMap[ERC20, Exchange])


@external
Expand All @@ -30,12 +32,13 @@ def register():
# NOTE: Use exchange's token address because it should be globally unique
# NOTE: Should do checks that it hasn't already been set,
# which has to be rectified with any upgrade strategy.
self.exchanges[Exchange(msg.sender).token()] = msg.sender
exchange: Exchange = Exchange(msg.sender)
self.exchanges[exchange.token()] = exchange


@external
def trade(_token1: address, _token2: address, _amt: uint256):
def trade(_token1: ERC20, _token2: ERC20, _amt: uint256):
# Perform a straight exchange of token1 to token 2 (1:1 price)
# NOTE: Any practical implementation would need to solve the price oracle problem
Exchange(self.exchanges[_token1]).receive(msg.sender, _amt)
Exchange(self.exchanges[_token2]).transfer(msg.sender, _amt)
self.exchanges[_token1].receive(msg.sender, _amt)
self.exchanges[_token2].transfer(msg.sender, _amt)
17 changes: 17 additions & 0 deletions tests/parser/syntax/test_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,23 @@ def test() -> address:
def test():
b: address = self.a.address
""",
"""
interface MyInterface:
def some_func(): nonpayable
my_interface: MyInterface[3]
idx: uint256
@external
def __init__():
self.my_interface[self.idx] = MyInterface(ZERO_ADDRESS)
""",
"""
interface MyInterface:
def kick(): payable
kickers: HashMap[address, MyInterface]
""",
]


Expand Down
5 changes: 0 additions & 5 deletions vyper/ast/signatures/function_signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,6 @@ def size(self):
return math.ceil(self.typ.memory_bytes_required / 32)


class ContractRecord(VariableRecord):
def __init__(self, *args):
super(ContractRecord, self).__init__(*args)


@dataclass
class FunctionArg:
name: str
Expand Down
4 changes: 2 additions & 2 deletions vyper/codegen/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1081,9 +1081,9 @@ def struct_literals(expr, name, context):
def parse_value_expr(cls, expr, context):
return unwrap_location(cls(expr, context).lll_node)

# Parse an expression that represents an address in memory/calldata or storage.
# Parse an expression that represents a pointer to memory/calldata or storage.
@classmethod
def parse_variable_location(cls, expr, context):
def parse_pointer_expr(cls, expr, context):
o = cls(expr, context).lll_node
if not o.location:
raise StructureException("Looking for a variable location, instead got a value", expr)
Expand Down
47 changes: 6 additions & 41 deletions vyper/codegen/external_call.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import vyper.utils as util
from vyper import ast as vy_ast
from vyper.codegen.abi_encoder import abi_encode
from vyper.codegen.core import (
calculate_type_for_external_return,
Expand All @@ -8,11 +7,10 @@
dummy_node_for_type,
get_element_ptr,
getpos,
unwrap_location,
)
from vyper.codegen.lll_node import Encoding, LLLnode
from vyper.codegen.types import TupleType, get_type_for_exact_size
from vyper.exceptions import StateAccessViolation, StructureException, TypeCheckFailure
from vyper.codegen.types import InterfaceType, TupleType, get_type_for_exact_size
from vyper.exceptions import StateAccessViolation, TypeCheckFailure


def _pack_arguments(contract_sig, args, context, pos):
Expand Down Expand Up @@ -206,46 +204,13 @@ def lll_for_external_call(stmt_expr, context):
from vyper.codegen.expr import Expr # TODO rethink this circular import

pos = getpos(stmt_expr)

contract_address = Expr.parse_value_expr(stmt_expr.func.value, context)
value, gas, skip_contract_check = _get_special_kwargs(stmt_expr, context)
args_lll = [Expr(x, context).lll_node for x in stmt_expr.args]

if isinstance(stmt_expr.func, vy_ast.Attribute) and isinstance(
stmt_expr.func.value, vy_ast.Call
):
# e.g. `Foo(address).bar()`

# sanity check
assert len(stmt_expr.func.value.args) == 1
contract_name = stmt_expr.func.value.func.id
contract_address = Expr.parse_value_expr(stmt_expr.func.value.args[0], context)

elif (
isinstance(stmt_expr.func.value, vy_ast.Attribute)
and stmt_expr.func.value.attr in context.globals
# TODO check for self?
and hasattr(context.globals[stmt_expr.func.value.attr].typ, "name")
):
# e.g. `self.foo.bar()`

# sanity check
assert stmt_expr.func.value.value.id == "self", stmt_expr

contract_name = context.globals[stmt_expr.func.value.attr].typ.name
type_ = stmt_expr.func.value._metadata["type"]
var = context.globals[stmt_expr.func.value.attr]
contract_address = unwrap_location(
LLLnode.from_list(
type_.position.position,
typ=var.typ,
location="storage",
pos=pos,
annotation="self." + stmt_expr.func.value.attr,
)
)
else:
# TODO catch this during type checking
raise StructureException("Unsupported operator.", stmt_expr)

assert isinstance(contract_address.typ, InterfaceType)
contract_name = contract_address.typ.name
method_name = stmt_expr.func.attr
contract_sig = context.sigs[contract_name][method_name]

Expand Down
65 changes: 14 additions & 51 deletions vyper/codegen/global_context.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import Optional

from vyper import ast as vy_ast
from vyper.ast.signatures.function_signature import ContractRecord, VariableRecord
from vyper.codegen.types import InterfaceType, parse_type
from vyper.ast.signatures.function_signature import VariableRecord
from vyper.codegen.types import parse_type
from vyper.exceptions import CompilerPanic, InvalidType, StructureException
from vyper.typing import InterfaceImports
from vyper.utils import cached_property
Expand Down Expand Up @@ -133,32 +133,6 @@ def make_contract(node: "vy_ast.InterfaceDef") -> list:
raise StructureException("Invalid contract reference", item)
return _defs

def get_item_name_and_attributes(self, item, attributes):
is_map_invocation = (
isinstance(item, vy_ast.Call) and isinstance(item.func, vy_ast.Name)
) and item.func.id == "HashMap"

if isinstance(item, vy_ast.Name):
return item.id, attributes
elif isinstance(item, vy_ast.AnnAssign):
return self.get_item_name_and_attributes(item.annotation, attributes)
elif isinstance(item, vy_ast.Subscript):
return self.get_item_name_and_attributes(item.value, attributes)
elif is_map_invocation:
if len(item.args) != 2:
raise StructureException(
"Map type expects two type arguments HashMap[type1, type2]", item.func
)
return self.get_item_name_and_attributes(item.args, attributes)
# elif ist
elif isinstance(item, vy_ast.Call) and isinstance(item.func, vy_ast.Name):
attributes[item.func.id] = True
# Raise for multiple args
if len(item.args) != 1:
raise StructureException(f"{item.func.id} expects one arg (the type)")
return self.get_item_name_and_attributes(item.args[0], attributes)
return None, attributes

@staticmethod
def get_call_func_name(item):
if isinstance(item.annotation, vy_ast.Call) and isinstance(
Expand All @@ -167,7 +141,6 @@ def get_call_func_name(item):
return item.annotation.func.id

def add_globals_and_events(self, item):
item_attributes = {"public": False}

if self._nonrentrant_counter:
raise CompilerPanic("Re-entrancy lock was set before all storage slots were defined")
Expand All @@ -180,29 +153,11 @@ def add_globals_and_events(self, item):
if self.get_call_func_name(item) == "constant":
return

item_name, item_attributes = self.get_item_name_and_attributes(item, item_attributes)

# references to `len(self._globals)` are remnants of deprecated code, retained
# to preserve existing interfaces while we complete a larger refactor. location
# and size of storage vars is handled in `vyper.context.validation.data_positions`
if item_name in self._contracts or item_name in self._interfaces:
if self.get_call_func_name(item) == "address":
raise StructureException(
f"Persistent address({item_name}) style contract declarations "
"are not support anymore."
f" Use {item.target.id}: {item_name} instead"
)
self._globals[item.target.id] = ContractRecord(
item.target.id,
len(self._globals),
InterfaceType(item_name),
True,
)
elif self.get_call_func_name(item) == "public":
if isinstance(item.annotation.args[0], vy_ast.Name) and item_name in self._contracts:
typ = InterfaceType(item_name)
else:
typ = self.parse_type(item.annotation.args[0])
if self.get_call_func_name(item) == "public":
typ = self.parse_type(item.annotation.args[0])
self._globals[item.target.id] = VariableRecord(
item.target.id,
len(self._globals),
Expand All @@ -220,19 +175,27 @@ def add_globals_and_events(self, item):
)

elif isinstance(item.annotation, (vy_ast.Name, vy_ast.Call, vy_ast.Subscript)):
typ = self.parse_type(item.annotation)
self._globals[item.target.id] = VariableRecord(
item.target.id,
len(self._globals),
self.parse_type(item.annotation),
typ,
True,
)
else:
raise InvalidType("Invalid global type specified", item)

@property
def interface_names(self):
"""
The set of names which are known to possibly be InterfaceType
"""
return set(self._contracts.keys()) | set(self._interfaces.keys())

def parse_type(self, ast_node):
return parse_type(
ast_node,
sigs=self._contracts,
sigs=self.interface_names,
custom_structs=self._structs,
)

Expand Down
1 change: 1 addition & 0 deletions vyper/codegen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def parse_external_interfaces(external_interfaces, global_ctx):
# Recognizes already-defined structs
sig = FunctionSignature.from_definition(
_def,
sigs=global_ctx.interface_names,
interface_def=True,
constant_override=constant,
custom_structs=global_ctx._structs,
Expand Down
2 changes: 1 addition & 1 deletion vyper/codegen/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,7 @@ def _get_target(self, target):
raise TypeCheckFailure(f"Failed constancy check\n{_dbg_expr}")
return target

target = Expr.parse_variable_location(target, self.context)
target = Expr.parse_pointer_expr(target, self.context)
if (target.location == "storage" and self.context.is_constant()) or not target.mutable:
raise TypeCheckFailure(f"Failed constancy check\n{_dbg_expr}")
return target
Expand Down
2 changes: 2 additions & 0 deletions vyper/codegen/types/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,8 @@ def make_struct_type(name, sigs, members, custom_structs):
# Parses an expression representing a type.
# TODO: rename me to "lll_type_from_annotation"
def parse_type(item, sigs, custom_structs):
# sigs: set of interface or contract names in scope
# custom_structs: struct definitions in scope
def _sanity_check(x):
assert x, "typechecker missed this"

Expand Down

0 comments on commit d4b1e3f

Please sign in to comment.