diff --git a/tests/parser/features/external_contracts/test_external_contract_calls.py b/tests/parser/features/external_contracts/test_external_contract_calls.py index 58f2e8ae5f..d12e6b6bc7 100644 --- a/tests/parser/features/external_contracts/test_external_contract_calls.py +++ b/tests/parser/features/external_contracts/test_external_contract_calls.py @@ -2324,3 +2324,62 @@ def foo(_addr: address, _addr2: address) -> int128: assert c2.foo(c1.address, c1.address) == 123 assert_tx_failed(lambda: c2.foo(c1.address, "0x1234567890123456789012345678901234567890")) + + +def test_default_override(get_contract, assert_tx_failed): + bad_erc20_code = """ +@external +def transfer(receiver: address, amount: uint256): + pass + """ + + code = """ +from vyper.interfaces import ERC20 +@external +def safeTransfer(erc20: ERC20, receiver: address, amount: uint256): + assert erc20.transfer(receiver, amount, default_return_value=True) + +@external +def transferBorked(erc20: ERC20, receiver: address, amount: uint256): + assert erc20.transfer(receiver, amount) + """ + bad_erc20 = get_contract(bad_erc20_code) + c = get_contract(code) + + # demonstrate transfer failing + assert_tx_failed(lambda: c.transferBorked(bad_erc20.address, c.address, 0)) + # would fail without default_return_value + c.safeTransfer(bad_erc20.address, c.address, 0) + + +def test_default_override2(get_contract, assert_tx_failed): + bad_code_1 = """ +@external +def return_64_bytes() -> bool: + return True + """ + + bad_code_2 = """ +@external +def return_64_bytes(): + pass + """ + + code = """ +struct BoolPair: + x: bool + y: bool +interface Foo: + def return_64_bytes() -> BoolPair: nonpayable +@external +def bar(foo: Foo): + t: BoolPair = foo.return_64_bytes(default_return_value=BoolPair({x: True, y:True})) + assert t.x and t.y + """ + bad_1 = get_contract(bad_code_1) + bad_2 = get_contract(bad_code_2) + c = get_contract(code) + + # fails due to returndatasize being nonzero but also lt 64 + assert_tx_failed(lambda: c.bar(bad_1.address)) + c.bar(bad_2.address) diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index 15d9356fd9..11b8e5ba29 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -937,6 +937,7 @@ def parse_Call(self): return arg_ir elif isinstance(self.expr.func, vy_ast.Attribute) and self.expr.func.attr == "pop": + # TODO consider moving this to builtins darray = Expr(self.expr.func.value, self.context).ir_node assert len(self.expr.args) == 0 assert isinstance(darray.typ, DArrayType) @@ -946,10 +947,12 @@ def parse_Call(self): ) elif ( + # TODO use expr.func.type.is_internal once + # type annotations are consistently available isinstance(self.expr.func, vy_ast.Attribute) and isinstance(self.expr.func.value, vy_ast.Name) and self.expr.func.value.id == "self" - ): # noqa: E501 + ): return self_call.ir_for_self_call(self.expr, self.context) else: return external_call.ir_for_external_call(self.expr, self.context) diff --git a/vyper/codegen/external_call.py b/vyper/codegen/external_call.py index c70c851f5b..f764296369 100644 --- a/vyper/codegen/external_call.py +++ b/vyper/codegen/external_call.py @@ -1,3 +1,5 @@ +from dataclasses import dataclass + import vyper.utils as util from vyper.address_space import MEMORY from vyper.codegen.abi_encoder import abi_encode @@ -8,25 +10,38 @@ dummy_node_for_type, make_setter, needs_clamp, + unwrap_location, + wrap_value_for_external_return, ) from vyper.codegen.ir_node import Encoding, IRnode from vyper.codegen.types import InterfaceType, TupleType, get_type_for_exact_size from vyper.codegen.types.convert import new_type_to_old_type -from vyper.exceptions import StateAccessViolation, TypeCheckFailure +from vyper.exceptions import TypeCheckFailure +from vyper.semantics.types.function import StateMutability + + +@dataclass +class _CallKwargs: + value: IRnode + gas: IRnode + skip_contract_check: bool + default_return_value: IRnode -def _pack_arguments(contract_sig, args, context): +def _pack_arguments(fn_type, args, context): # abi encoding just treats all args as a big tuple args_tuple_t = TupleType([x.typ for x in args]) args_as_tuple = IRnode.from_list(["multi"] + [x for x in args], typ=args_tuple_t) args_abi_t = args_tuple_t.abi_type # sanity typecheck - make sure the arguments can be assigned - dst_tuple_t = TupleType([arg.typ for arg in contract_sig.args][: len(args)]) + dst_tuple_t = TupleType( + [new_type_to_old_type(typ) for typ in fn_type.arguments.values()][: len(args)] + ) check_assign(dummy_node_for_type(dst_tuple_t), args_as_tuple) - if contract_sig.return_type is not None: - return_abi_t = calculate_type_for_external_return(contract_sig.return_type).abi_type + if fn_type.return_type is not None: + return_abi_t = calculate_type_for_external_return(fn_type.return_type).abi_type # we use the same buffer for args and returndata, # so allocate enough space here for the returndata too. @@ -42,7 +57,7 @@ def _pack_arguments(contract_sig, args, context): args_ofst = buf + 28 args_len = args_abi_t.size_bound() + 4 - abi_signature = contract_sig.name + dst_tuple_t.abi_type.selector_name() + abi_signature = fn_type.name + dst_tuple_t.abi_type.selector_name() # layout: # 32 bytes | args @@ -51,31 +66,26 @@ def _pack_arguments(contract_sig, args, context): # if we were only targeting constantinople, we could align # to buf (and also keep code size small) by using # (mstore buf (shl signature.method_id 224)) - mstore_method_id = [["mstore", buf, util.abi_method_id(abi_signature)]] + pack_args = ["seq"] + pack_args.append(["mstore", buf, util.abi_method_id(abi_signature)]) - if len(args) == 0: - encode_args = ["pass"] - else: - encode_args = abi_encode(buf + 32, args_as_tuple, context, bufsz=buflen) + if len(args) != 0: + pack_args.append(abi_encode(buf + 32, args_as_tuple, context, bufsz=buflen)) - return buf, mstore_method_id + [encode_args], args_ofst, args_len + return buf, pack_args, args_ofst, args_len -def _unpack_returndata(buf, contract_sig, skip_contract_check, context, expr): - # expr.func._metadata["type"].return_type is more accurate - # than contract_sig.return_type in the case of JSON interfaces. - ast_return_t = expr.func._metadata["type"].return_type +def _unpack_returndata(buf, fn_type, call_kwargs, context, expr): + ast_return_t = fn_type.return_type if ast_return_t is None: return ["pass"], 0, 0 - # sanity check return_t = new_type_to_old_type(ast_return_t) - check_assign(dummy_node_for_type(return_t), dummy_node_for_type(contract_sig.return_type)) - return_t = calculate_type_for_external_return(return_t) + wrapped_return_t = calculate_type_for_external_return(return_t) - abi_return_t = return_t.abi_type + abi_return_t = wrapped_return_t.abi_type min_return_size = abi_return_t.min_size() max_return_size = abi_return_t.size_bound() @@ -84,139 +94,125 @@ def _unpack_returndata(buf, contract_sig, skip_contract_check, context, expr): ret_ofst = buf ret_len = max_return_size - # revert when returndatasize is not in bounds - ret = [] - # runtime: min_return_size <= returndatasize - if not skip_contract_check: - ret += [["assert", ["ge", "returndatasize", min_return_size]]] - encoding = Encoding.ABI buf = IRnode.from_list( buf, - typ=return_t, + typ=wrapped_return_t, location=MEMORY, encoding=encoding, annotation=f"{expr.node_source_code} returndata buffer", ) - assert isinstance(return_t, TupleType) + unpacker = ["seq"] + + # revert when returndatasize is not in bounds + # (except when return_override is provided.) + if not call_kwargs.skip_contract_check: + unpacker.append(["assert", ["ge", "returndatasize", min_return_size]]) + + assert isinstance(wrapped_return_t, TupleType) + # unpack strictly - if needs_clamp(return_t, encoding): - buf2 = IRnode.from_list( - context.new_internal_variable(return_t), typ=return_t, location=MEMORY - ) + if needs_clamp(wrapped_return_t, encoding): + return_buf = context.new_internal_variable(wrapped_return_t) + return_buf = IRnode.from_list(return_buf, typ=wrapped_return_t, location=MEMORY) - ret.append(make_setter(buf2, buf)) - ret.append(buf2) + # note: make_setter does ABI decoding and clamps + unpacker.append(make_setter(return_buf, buf)) else: - ret.append(buf) + return_buf = buf - return ret, ret_ofst, ret_len + if call_kwargs.default_return_value is not None: + # if returndatasize == 0: + # copy return override to buf + # else: + # do the other stuff + override_value = wrap_value_for_external_return(call_kwargs.default_return_value) + stomp_return_buffer = make_setter(return_buf, override_value) + unpacker = ["if", ["eq", "returndatasize", 0], stomp_return_buffer, unpacker] -def _external_call_helper( - contract_address, - contract_sig, - args_ir, - context, - value=None, - gas=None, - skip_contract_check=None, - expr=None, -): + unpacker = ["seq", unpacker, return_buf] - if value is None: - value = 0 - if gas is None: - gas = "gas" - if skip_contract_check is None: - skip_contract_check = False + return unpacker, ret_ofst, ret_len - # sanity check - assert len(contract_sig.base_args) <= len(args_ir) <= len(contract_sig.args) - if context.is_constant() and contract_sig.mutability not in ("view", "pure"): - # TODO is this already done in type checker? - raise StateAccessViolation( - f"May not call state modifying function '{contract_sig.name}' " - f"within {context.pp_constancy()}.", - expr, - ) +def _parse_kwargs(call_expr, context): + from vyper.codegen.expr import Expr # TODO rethink this circular import - sub = ["seq"] + def _bool(x): + assert x.value in (0, 1), "type checker missed this" + return bool(x.value) - buf, arg_packer, args_ofst, args_len = _pack_arguments(contract_sig, args_ir, context) + # note: codegen for kwarg values in AST order + call_kwargs = {kw.arg: Expr(kw.value, context).ir_node for kw in call_expr.keywords} - ret_unpacker, ret_ofst, ret_len = _unpack_returndata( - buf, contract_sig, skip_contract_check, context, expr + ret = _CallKwargs( + value=unwrap_location(call_kwargs.pop("value", IRnode(0))), + gas=unwrap_location(call_kwargs.pop("gas", IRnode("gas"))), + skip_contract_check=_bool(call_kwargs.pop("skip_contract_check", IRnode(0))), + default_return_value=call_kwargs.pop("default_return_value", None), ) - sub += arg_packer + if len(call_kwargs) != 0: + raise TypeCheckFailure(f"Unexpected keyword arguments: {call_kwargs}") - if contract_sig.return_type is None and not skip_contract_check: - # if we do not expect return data, check that a contract exists at the - # target address. we must perform this check BEFORE the call because - # the contract might selfdestruct. on the other hand we can omit this - # when we _do_ expect return data because we later check - # `returndatasize` (that check works even if the contract - # selfdestructs). - sub.append(["assert", ["extcodesize", contract_address]]) + return ret - if context.is_constant() or contract_sig.mutability in ("view", "pure"): - call_op = ["staticcall", gas, contract_address, args_ofst, args_len, ret_ofst, ret_len] - else: - call_op = ["call", gas, contract_address, value, args_ofst, args_len, ret_ofst, ret_len] - sub.append(check_external_call(call_op)) +def ir_for_external_call(call_expr, context): + from vyper.codegen.expr import Expr # TODO rethink this circular import - if contract_sig.return_type is not None: - sub += ret_unpacker + contract_address = Expr.parse_value_expr(call_expr.func.value, context) + call_kwargs = _parse_kwargs(call_expr, context) + args_ir = [Expr(x, context).ir_node for x in call_expr.args] - return IRnode.from_list(sub, typ=contract_sig.return_type, location=MEMORY) + assert isinstance(contract_address.typ, InterfaceType) + # expr.func._metadata["type"].return_type is more accurate + # than fn_sig.return_type in the case of JSON interfaces. + fn_type = call_expr.func._metadata["type"] -def _get_special_kwargs(stmt_expr, context): - from vyper.codegen.expr import Expr # TODO rethink this circular import + # sanity check + assert fn_type.min_arg_count <= len(args_ir) <= fn_type.max_arg_count - value, gas, skip_contract_check = None, None, None - for kw in stmt_expr.keywords: - if kw.arg == "gas": - gas = Expr.parse_value_expr(kw.value, context) - elif kw.arg == "value": - value = Expr.parse_value_expr(kw.value, context) - elif kw.arg == "skip_contract_check": - skip_contract_check = kw.value.value - assert isinstance(skip_contract_check, bool), "type checker missed this" - else: - raise TypeCheckFailure("Unexpected keyword argument") + ret = ["seq"] - # TODO maybe return a small dataclass to reduce verbosity - return value, gas, skip_contract_check + buf, arg_packer, args_ofst, args_len = _pack_arguments(fn_type, args_ir, context) + ret_unpacker, ret_ofst, ret_len = _unpack_returndata( + buf, fn_type, call_kwargs, context, call_expr + ) -def ir_for_external_call(stmt_expr, context): - from vyper.codegen.expr import Expr # TODO rethink this circular import + ret += arg_packer - contract_address = Expr.parse_value_expr(stmt_expr.func.value, context) - value, gas, skip_contract_check = _get_special_kwargs(stmt_expr, context) - args_ir = [Expr(x, context).ir_node for x in stmt_expr.args] + if fn_type.return_type is None and not call_kwargs.skip_contract_check: + # if we do not expect return data, check that a contract exists at the + # target address. we must perform this check BEFORE the call because + # the contract might selfdestruct. on the other hand we can omit this + # when we _do_ expect return data because we later check + # `returndatasize` (that check works even if the contract + # selfdestructs). + ret.append(["assert", ["extcodesize", contract_address]]) - assert isinstance(contract_address.typ, InterfaceType) - contract_name = contract_address.typ.name - method_name = stmt_expr.func.attr - contract_sig = context.sigs[contract_name][method_name] - - ret = _external_call_helper( - contract_address, - contract_sig, - args_ir, - context, - value=value, - gas=gas, - skip_contract_check=skip_contract_check, - expr=stmt_expr, - ) - ret.annotation = stmt_expr.get("node_source_code") + gas = call_kwargs.gas + value = call_kwargs.value - return ret + use_staticcall = fn_type.mutability in (StateMutability.VIEW, StateMutability.PURE) + if context.is_constant(): + assert use_staticcall, "typechecker missed this" + + if use_staticcall: + call_op = ["staticcall", gas, contract_address, args_ofst, args_len, buf, ret_len] + else: + call_op = ["call", gas, contract_address, value, args_ofst, args_len, buf, ret_len] + + ret.append(check_external_call(call_op)) + + return_t = None + if fn_type.return_type is not None: + return_t = new_type_to_old_type(fn_type.return_type) + ret.append(ret_unpacker) + + return IRnode.from_list(ret, typ=return_t, location=MEMORY) diff --git a/vyper/codegen/stmt.py b/vyper/codegen/stmt.py index a970efc3f3..14d86aadbd 100644 --- a/vyper/codegen/stmt.py +++ b/vyper/codegen/stmt.py @@ -46,6 +46,7 @@ def __init__(self, node: vy_ast.VyperNode, context: Context) -> None: self.ir_node.source_pos = getpos(self.stmt) def parse_Expr(self): + # TODO: follow analysis modules and dispatch down to expr.py return Stmt(self.stmt.value, self.context).ir_node def parse_Pass(self): @@ -128,6 +129,8 @@ def parse_Log(self): return events.ir_node_for_log(self.stmt, event, topic_ir, data_ir, self.context) def parse_Call(self): + # TODO use expr.func.type.is_internal once type annotations + # are consistently available. is_self_function = ( (isinstance(self.stmt.func, vy_ast.Attribute)) and isinstance(self.stmt.func.value, vy_ast.Name) @@ -142,6 +145,7 @@ def parse_Call(self): "append", "pop", ): + # TODO: consider moving this to builtins darray = Expr(self.stmt.func.value, self.context).ir_node args = [Expr(x, self.context).ir_node for x in self.stmt.args] if self.stmt.func.attr == "append": diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index 0bce57c856..5323f48b77 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -311,7 +311,7 @@ def from_FunctionDef( namespace = get_namespace() for arg, value in zip(node.args.args, defaults): - if arg.arg in ("gas", "value", "skip_contract_check"): + if arg.arg in ("gas", "value", "skip_contract_check", "default_return_value"): raise ArgumentException( f"Cannot use '{arg.arg}' as a variable name in a function input", arg, @@ -463,7 +463,7 @@ def fetch_call_return(self, node: vy_ast.Call) -> Optional[BaseTypeDefinition]: # for external calls, include gas and value as optional kwargs kwarg_keys = self.kwarg_keys.copy() if node.get("func.value.id") != "self": - kwarg_keys += ["gas", "value", "skip_contract_check"] + kwarg_keys += ["gas", "value", "skip_contract_check", "default_return_value"] validate_call_args(node, (self.min_arg_count, self.max_arg_count), kwarg_keys) if self.mutability < StateMutability.PAYABLE: @@ -477,10 +477,12 @@ def fetch_call_return(self, node: vy_ast.Call) -> Optional[BaseTypeDefinition]: for kwarg in node.keywords: if kwarg.arg in ("gas", "value"): validate_expected_type(kwarg.value, Uint256Definition()) - elif kwarg.arg in ("skip_contract_check"): + elif kwarg.arg in ("skip_contract_check",): validate_expected_type(kwarg.value, BoolDefinition()) if not isinstance(kwarg.value, vy_ast.NameConstant): raise InvalidType("skip_contract_check must be literal bool", kwarg.value) + elif kwarg.arg in ("default_return_value",): + validate_expected_type(kwarg.value, self.return_type) else: # Generate the modified source code string with the kwarg removed # as a suggestion to the user.