Skip to content

Commit

Permalink
fix: allow interfaces to make external calls directly from mappings
Browse files Browse the repository at this point in the history
  • Loading branch information
fubuloubu committed Mar 23, 2021
1 parent c7089a0 commit 13458fe
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 13 deletions.
7 changes: 6 additions & 1 deletion vyper/parser/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,8 +309,13 @@ def parse_Name(self):

# x.y or x[5]
def parse_Attribute(self):
# x.address
if self.expr.attr == "address":
addr = Expr.parse_value_expr(self.expr.value, self.context)
if is_base_type(addr.typ, "address"):
return addr
# x.balance: balance of address x
if self.expr.attr == "balance":
elif self.expr.attr == "balance":
addr = Expr.parse_value_expr(self.expr.value, self.context)
if is_base_type(addr.typ, "address"):
if (
Expand Down
50 changes: 50 additions & 0 deletions vyper/parser/external_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
BaseType,
ByteArrayLike,
ListType,
MappingType,
TupleLike,
get_size_of_type,
get_static_size_of_type,
Expand Down Expand Up @@ -241,5 +242,54 @@ def make_external_call(stmt_expr, context):
gas=gas,
)

elif (
isinstance(stmt_expr.func.value, vy_ast.Name)
and stmt_expr.func.value.id in context.vars
and context.vars[stmt_expr.func.value.id].typ.typ == "address"
and context.vars[stmt_expr.func.value.id].typ.contract_type
):

var = context.vars[stmt_expr.func.value.id]
contract_address = unwrap_location(
LLLnode.from_list(
var.pos,
typ=var.typ,
location=var.location,
pos=getpos(stmt_expr),
annotation=stmt_expr.func.value.id,
)
)

return external_call(
stmt_expr,
context,
var.typ.contract_type,
contract_address,
pos=getpos(stmt_expr),
value=value,
gas=gas,
)

elif (
isinstance(stmt_expr.func.value, vy_ast.Subscript)
and isinstance(stmt_expr.func.value.value, vy_ast.Attribute)
and stmt_expr.func.value.value.attr in context.globals
and isinstance(context.globals[stmt_expr.func.value.value.attr].typ, MappingType)
and context.globals[stmt_expr.func.value.value.attr].typ.valuetype.typ == "address"
and context.globals[stmt_expr.func.value.value.attr].typ.valuetype.contract_type
):

contract_address = Expr.parse_value_expr(stmt_expr.func.value, context)

return external_call(
stmt_expr,
context,
context.globals[stmt_expr.func.value.value.attr].typ.valuetype.contract_type,
contract_address,
pos=getpos(stmt_expr),
value=value,
gas=gas,
)

else:
raise StructureException("Unsupported operator.", stmt_expr)
11 changes: 9 additions & 2 deletions vyper/parser/global_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,12 @@ def add_globals_and_events(self, item):
if isinstance(item.annotation.args[0], vy_ast.Name) and item_name in self._contracts:
typ = InterfaceType(item_name)
else:
typ = parse_type(item.annotation.args[0], "storage", custom_structs=self._structs,)
typ = parse_type(
item.annotation.args[0],
"storage",
sigs=self._contracts,
custom_structs=self._structs,
)
self._globals[item.target.id] = VariableRecord(
item.target.id, len(self._globals), typ, True,
)
Expand All @@ -221,7 +226,9 @@ def add_globals_and_events(self, item):
self._globals[item.target.id] = VariableRecord(
item.target.id,
len(self._globals),
parse_type(item.annotation, "storage", custom_structs=self._structs,),
parse_type(
item.annotation, "storage", sigs=self._contracts, custom_structs=self._structs,
),
True,
)
else:
Expand Down
15 changes: 14 additions & 1 deletion vyper/parser/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,10 @@ def parse_Name(self):

def parse_AnnAssign(self):
typ = parse_type(
self.stmt.annotation, location="memory", custom_structs=self.context.structs,
self.stmt.annotation,
location="memory",
sigs=self.context.sigs,
custom_structs=self.context.structs,
)
varname = self.stmt.target.id
pos = self.context.new_variable(varname, typ, pos=self.stmt)
Expand Down Expand Up @@ -422,6 +425,16 @@ def parse_Return(self):
if isinstance(sub.typ, BaseType):
sub = unwrap_location(sub)

# HACK: For whatever reason, this disconnect is difficult to match up
if (
self.context.return_type != sub.typ
and self.context.return_type.typ == sub.typ.typ == "address"
and sub.typ.contract_type
and not self.context.return_type.contract_type
):
# Implicitly cast interface types to addresses
sub.typ.contract_type = None

if self.context.return_type != sub.typ and not sub.typ.is_literal:
return
elif sub.typ.is_literal and (
Expand Down
28 changes: 19 additions & 9 deletions vyper/types/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,15 @@ def eq(self, other: "NodeType") -> bool: # pragma: no cover

# Data structure for a type that represents a 32-byte object
class BaseType(NodeType):
def __init__(
self, typ, unit=False, positional=False, override_signature=False, is_literal=False
):
def __init__(self, typ, unit=False, positional=False, is_literal=False, contract_type=None):
self.typ = typ
if unit or positional:
raise CompilerPanic("Units are no longer supported")
self.override_signature = override_signature
self.contract_type = contract_type
self.is_literal = is_literal

def eq(self, other):
return self.typ == other.typ
return self.typ == other.typ and self.contract_type == other.contract_type

def __repr__(self):
return str(self.typ)
Expand Down Expand Up @@ -203,6 +201,8 @@ def parse_type(item, location, sigs=None, custom_structs=None):
return BaseType(item.id)
elif (custom_structs is not None) and (item.id in custom_structs):
return make_struct_type(item.id, location, custom_structs[item.id], custom_structs,)
elif sigs and item.id in sigs:
return BaseType("address", contract_type=item.id)
else:
raise InvalidType("Invalid base type: " + item.id, item)
# Units, e.g. num (1/sec) or contracts
Expand Down Expand Up @@ -233,13 +233,21 @@ def parse_type(item, location, sigs=None, custom_structs=None):
# List
else:
return ListType(
parse_type(item.value, location, custom_structs=custom_structs,), n_val,
parse_type(item.value, location, sigs=sigs, custom_structs=custom_structs,),
n_val,
)
elif item.value.id in ("HashMap",) and isinstance(item.slice.value, vy_ast.Tuple):
keytype = parse_type(item.slice.value.elements[0], None, custom_structs=custom_structs,)
keytype = parse_type(
item.slice.value.elements[0], None, sigs=sigs, custom_structs=custom_structs,
)
return MappingType(
keytype,
parse_type(item.slice.value.elements[1], location, custom_structs=custom_structs,),
parse_type(
item.slice.value.elements[1],
location,
sigs=sigs,
custom_structs=custom_structs,
),
)
# Mappings, e.g. num[address]
else:
Expand All @@ -253,7 +261,9 @@ def parse_type(item, location, sigs=None, custom_structs=None):
)
raise InvalidType("Invalid type", item)
elif isinstance(item, vy_ast.Tuple):
members = [parse_type(x, location, custom_structs=custom_structs) for x in item.elements]
members = [
parse_type(x, location, sigs=sigs, custom_structs=custom_structs) for x in item.elements
]
return TupleType(members)
else:
raise InvalidType("Invalid type", item)
Expand Down

0 comments on commit 13458fe

Please sign in to comment.