Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow any ABI-encodable type as args #2398

Closed
wants to merge 11 commits into from
7 changes: 5 additions & 2 deletions docs/built-in-functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -630,7 +630,7 @@ Utilities
Once this function has seen more use we provisionally plan to put it into the ``ethereum.abi`` namespace.

* ``*args``: Arbitrary arguments
* ``ensure_tuple``: If set to True, ensures that even a single argument is encoded as a tuple. In other words, ``bytes`` gets encoded as ``(bytes,)``. This is the calling convention for Vyper and Solidity functions. Except for very specific use cases, this should be set to True. Must be a literal.
* ``ensure_tuple``: If set to True, ensures that even a single argument is encoded as a tuple. In other words, ``bytes`` gets encoded as ``(bytes,)``, and ``(bytes,)`` gets encoded as ``((bytes,),)`` This is the calling convention for Vyper and Solidity functions. Except for very specific use cases, this should be set to True. Must be a literal.

Returns a bytestring whose max length is determined by the arguments. For example, encoding a ``Bytes[32]`` results in a ``Bytes[64]`` (first word is the length of the bytestring variable).

Expand All @@ -646,4 +646,7 @@ Utilities
.. code-block:: python

>>> ExampleContract.foo().hex()
"0000000000000000000000000000000000000000000000000000000000000001000000000000000000000000000000000000000000000000000000000000004000000000000000000000000000000000000000000000000000000000000000033233340000000000000000000000000000000000000000000000000000000000"
"0000000000000000000000000000000000000000000000000000000000000001"
"0000000000000000000000000000000000000000000000000000000000000040"
"0000000000000000000000000000000000000000000000000000000000000003"
"3233340000000000000000000000000000000000000000000000000000000000"
24 changes: 8 additions & 16 deletions tests/functional/codegen/test_struct_return.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,34 +5,26 @@ def test_nested_tuple(get_contract):
code = """
struct Animal:
location: address
fur: uint256
fur: String[32]

struct Human:
location: address
height: uint256
animal: Animal

@external
def return_nested_tuple() -> (Animal, Human):
animal: Animal = Animal({
location: 0x1234567890123456789012345678901234567890,
fur: 123
})
human: Human = Human({
location: 0x1234567890123456789012345678900000000000,
height: 456
})
def modify_nested_tuple(_human: Human) -> Human:
human: Human = _human

# do stuff, edit the structs
animal.fur += 1
human.height += 1
human.animal.fur = slice(concat(human.animal.fur, " is great"), 0, 32)

return animal, human
return human
"""
c = get_contract(code)
addr1 = "0x1234567890123456789012345678901234567890"
addr2 = "0x1234567890123456789012345678900000000000"
assert c.return_nested_tuple() == [(addr1, 124), (addr2, 457)]

#assert c.modify_nested_tuple([addr1, 123], [addr2, 456]) == [[addr1, 124], [addr2, 457]]
assert c.modify_nested_tuple({"location": addr1, "animal": {"location": addr2, "fur": "wool"}}) == [(addr1, (addr2, "wool is great"))]

@pytest.mark.parametrize("string", ["a", "abc", "abcde", "potato"])
def test_string_inside_tuple(get_contract, string):
Expand Down
25 changes: 0 additions & 25 deletions tests/parser/exceptions/test_argument_exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,31 +89,6 @@ def foo():
for i in range(1, 2, 3, 4):
pass
""",
"""
struct Foo:
a: Bytes[32]
@external
def foo(a: Foo):
pass
""",
"""
struct Foo:
a: String[32]
@external
def foo(a: Foo):
pass
""",
"""
struct Foo:
b: uint256
a: String[32]
@external
def foo(a: Foo):
pass
""",
]


Expand Down
5 changes: 3 additions & 2 deletions vyper/builtin_functions/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1491,7 +1491,7 @@ def build_LLL(self, expr, context):


def get_create_forwarder_to_bytecode():
# FLAG cyclic import?
# NOTE cyclic import?
from vyper.lll.compile_lll import assembly_to_evm

loader_asm = [
Expand Down Expand Up @@ -1746,7 +1746,8 @@ class ABIEncode(_SimpleBuiltinFunction):
# to handle varargs.)
# explanation of ensure_tuple:
# default is to force even a single value into a tuple,
# e.g. _abi_encode(bytes) -> abi_encode((bytes,))
# e.g. _abi_encode(bytes) -> _abi_encode((bytes,))
# _abi_encode((bytes,)) -> _abi_encode(((bytes,),))
# this follows the encoding convention for functions:
# ://docs.soliditylang.org/en/v0.8.6/abi-spec.html#function-selector-and-argument-encoding
# if this is turned off, then bytes will be encoded as bytes.
Expand Down
52 changes: 29 additions & 23 deletions vyper/old_codegen/abi.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,8 @@ def selector_name(self):
# Whether the type is a tuple at the ABI level.
# (This is important because if it does, it needs an offset.
# Compare the difference in encoding between `bytes` and `(bytes,)`.)
def is_tuple(self):
raise NotImplementedError("ABIType.is_tuple")

def is_complex_type(self):
raise NotImplementedError("ABIType.is_complex_type")

# uint<M>: unsigned integer type of M bits, 0 < M <= 256, M % 8 == 0. e.g. uint32, uint8, uint256.
# int<M>: two’s complement signed integer type of M bits, 0 < M <= 256, M % 8 == 0.
Expand All @@ -71,7 +70,7 @@ def static_size(self):
def selector_name(self):
return ("" if self.signed else "u") + f"int{self.m_bits}"

def is_tuple(self):
def is_complex_type(self):
return False


Expand Down Expand Up @@ -122,7 +121,7 @@ def static_size(self):
def selector_name(self):
return ("" if self.signed else "u") + "fixed{self.m_bits}x{self.n_places}"

def is_tuple(self):
def is_complex_type(self):
return False


Expand All @@ -143,7 +142,7 @@ def static_size(self):
def selector_name(self):
return f"bytes{self.m_bytes}"

def is_tuple(self):
def is_complex_type(self):
return False


Expand Down Expand Up @@ -178,7 +177,7 @@ def dynamic_size_bound(self):
def selector_name(self):
return f"{self.subtyp.selector_name()}[{self.m_elems}]"

def is_tuple(self):
def is_complex_type(self):
return True


Expand All @@ -204,7 +203,7 @@ def dynamic_size_bound(self):
def selector_name(self):
return "bytes"

def is_tuple(self):
def is_complex_type(self):
return False


Expand Down Expand Up @@ -233,7 +232,7 @@ def dynamic_size_bound(self):
def selector_name(self):
return f"{self.subtyp.selector_name()}[]"

def is_tuple(self):
def is_complex_type(self):
return False


Expand All @@ -250,7 +249,7 @@ def static_size(self):
def dynamic_size_bound(self):
return sum([t.dynamic_size_bound() for t in self.subtyps])

def is_tuple(self):
def is_complex_type(self):
return True


Expand Down Expand Up @@ -317,14 +316,6 @@ def abi_type_of2(t: vy.BasePrimitive) -> ABIType:
raise CompilerPanic(f"Unrecognized type {t}")


# there are a lot of places in the calling convention where a tuple
# must be passed, so here's a convenience function for that.
def ensure_tuple(abi_typ):
if not abi_typ.is_tuple():
return ABI_Tuple([abi_typ])
return abi_typ


# turn an lll node into a list, based on its type.
def o_list(lll_node, pos=None):
lll_t = lll_node.typ
Expand Down Expand Up @@ -382,16 +373,29 @@ def abi_encode(dst, lll_node, pos=None, bufsz=None, returns_len=False):
if bufsz is not None and bufsz < 32 * size_bound:
raise CompilerPanic("buffer provided to abi_encode not large enough")


# fastpath: if there is no dynamic data, we can optimize the
# encoding by using make_setter, since our memory encoding happens
# to be identical to the ABI encoding.
if not parent_abi_t.is_dynamic():
# cast the output buffer to something that make_setter accepts
dst = LLLnode(dst, typ=lll_node.typ, location="memory")
lll_ret = ["seq_unchecked", make_setter(dst, lll_node, "memory", pos)]
lll_ret.append(parent_abi_t.embedded_static_size())
return LLLnode.from_list(lll_ret, pos=pos, annotation=f"abi_encode {lll_node.typ}")


lll_ret = ["seq"]
dyn_ofst = "dyn_ofst" # current offset in the dynamic section
dst_begin = "dst" # pointer to beginning of buffer
dst_loc = "dst_loc" # pointer to write location in static section
os = o_list(lll_node, pos=pos)


for i, o in enumerate(os):
abi_t = abi_type_of(o.typ)

if parent_abi_t.is_tuple():
if parent_abi_t.is_complex_type():
if abi_t.is_dynamic():
lll_ret.append(["mstore", dst_loc, dyn_ofst])
# recurse
Expand All @@ -408,9 +412,11 @@ def abi_encode(dst, lll_node, pos=None, bufsz=None, returns_len=False):

elif isinstance(o.typ, BaseType):
d = LLLnode(dst_loc, typ=o.typ, location="memory")
# call into make_setter routine
lll_ret.append(make_setter(d, o, location=d.location, pos=pos))
elif isinstance(o.typ, ByteArrayLike):
d = LLLnode.from_list(dst_loc, typ=o.typ, location="memory")
# call into make_setter routinme
lll_ret.append(["seq", make_setter(d, o, location=d.location, pos=pos), zero_pad(d)])
else:
raise CompilerPanic(f"unreachable type: {o.typ}")
Expand All @@ -425,7 +431,7 @@ def abi_encode(dst, lll_node, pos=None, bufsz=None, returns_len=False):
if returns_len:
if not parent_abi_t.is_dynamic():
lll_ret.append(parent_abi_t.embedded_static_size())
elif parent_abi_t.is_tuple():
elif parent_abi_t.is_complex_type():
lll_ret.append("dyn_ofst")
elif isinstance(lll_node.typ, ByteArrayLike):
# for abi purposes, return zero-padded length
Expand All @@ -434,15 +440,15 @@ def abi_encode(dst, lll_node, pos=None, bufsz=None, returns_len=False):
else:
raise CompilerPanic("unknown type {lll_node.typ}")

if not (parent_abi_t.is_dynamic() and parent_abi_t.is_tuple()):
if not (parent_abi_t.is_dynamic() and parent_abi_t.is_complex_type()):
pass # optimize out dyn_ofst allocation if we don't need it
else:
dyn_section_start = parent_abi_t.static_size()
lll_ret = ["with", "dyn_ofst", dyn_section_start, lll_ret]

lll_ret = ["with", dst_begin, dst, ["with", dst_loc, dst_begin, lll_ret]]

return LLLnode.from_list(lll_ret, pos=pos)
return LLLnode.from_list(lll_ret, pos=pos, annotation=f"abi_encode {lll_node.typ}")


# lll_node is the destination LLL item, src is the input buffer.
Expand All @@ -456,7 +462,7 @@ def abi_decode(lll_node, src, pos=None):
for i, o in enumerate(os):
abi_t = abi_type_of(o.typ)
src_loc = LLLnode("src_loc", typ=o.typ, location=src.location)
if parent_abi_t.is_tuple():
if parent_abi_t.is_complex_type():
if abi_t.is_dynamic():
child_loc = ["add", "src", unwrap_location(src_loc)]
child_loc = LLLnode.from_list(child_loc, typ=o.typ, location=src.location)
Expand Down
5 changes: 5 additions & 0 deletions vyper/old_codegen/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,18 @@ def __init__(
# Not intended to be accessed directly
self.memory_allocator = memory_allocator

self._callee_frame_sizes = []

# Intermented values, used for internal IDs
self._internal_var_iter = 0
self._scope_id_iter = 0

def is_constant(self):
return self.constancy is Constancy.Constant or self.in_assertion or self.in_range_expr

def register_callee(frame_size):
self._callee_frame_sizes.append(frame_size)

#
# Context Managers
# - Context managers are used to ensure proper wrapping of scopes and context states.
Expand Down
9 changes: 5 additions & 4 deletions vyper/old_codegen/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -994,10 +994,11 @@ def parse_List(self):
return LLLnode.from_list(lll_node, typ=typ, pos=getpos(self.expr))

def parse_Tuple(self):
call_lll, multi_lll = parse_sequence(self.expr, self.expr.elements, self.context)
typ = TupleType([x.typ for x in multi_lll], is_literal=True)
multi_lll = LLLnode.from_list(["multi"] + multi_lll, typ=typ, pos=getpos(self.expr))
if not call_lll:
#call_lll, multi_lll = parse_sequence(self.expr, self.expr.elements, self.context)
tuple_elements = [Expr(x, self.context).lll_node for x in self.expr.elements]
typ = TupleType([x.typ for x in tuple_elements], is_literal=True)
multi_lll = LLLnode.from_list(["multi"] + tuple_elements, typ=typ, pos=getpos(self.expr))
if True: # if not call_lll:
return multi_lll

lll_node = ["seq_unchecked"] + call_lll + [multi_lll]
Expand Down
4 changes: 2 additions & 2 deletions vyper/old_codegen/function_definitions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .parse_function import ( # noqa
from .common import ( # noqa
is_default_func,
is_initializer,
parse_function,
generate_lll_for_function,
)
Loading