diff --git a/tests/parser/functions/test_raw_call.py b/tests/parser/functions/test_raw_call.py index ffa3c3f33a..b97d51e1c4 100644 --- a/tests/parser/functions/test_raw_call.py +++ b/tests/parser/functions/test_raw_call.py @@ -1,5 +1,6 @@ import pytest +from hexbytes import HexBytes from vyper import compiler from vyper.exceptions import ArgumentException, StateAccessViolation from vyper.functions import get_create_forwarder_to_bytecode @@ -63,15 +64,15 @@ def create_and_return_forwarder(inp: address) -> address: assert c2.create_and_call_returnten(c.address) == 10 c2.create_and_call_returnten(c.address, transact={}) - expected_forwarder_code_mask = get_create_forwarder_to_bytecode()[12:] + _, preamble, callcode = get_create_forwarder_to_bytecode() c3 = c2.create_and_return_forwarder(c.address, call={}) c2.create_and_return_forwarder(c.address, transact={}) c3_contract_code = w3.toBytes(w3.eth.getCode(c3)) - assert c3_contract_code[:14] == expected_forwarder_code_mask[:14] - assert c3_contract_code[35:] == expected_forwarder_code_mask[35:] + assert c3_contract_code[:10] == HexBytes(preamble) + assert c3_contract_code[-15:] == HexBytes(callcode) print("Passed forwarder test") # TODO: This one is special diff --git a/vyper/functions/functions.py b/vyper/functions/functions.py index 22528e0ac0..7bcc379906 100644 --- a/vyper/functions/functions.py +++ b/vyper/functions/functions.py @@ -1426,54 +1426,53 @@ def build_LLL(self, expr, context): def get_create_forwarder_to_bytecode(): - from vyper.compile_lll import assembly_to_evm, num_to_bytearray + from vyper.compile_lll import assembly_to_evm - code_a = [ + loader_asm = [ "PUSH1", - 0x33, + 0x2D, + "RETURNDATASIZE", + "DUP2", "PUSH1", - 0x0C, - "PUSH1", - 0x00, + 0x09, + "RETURNDATASIZE", "CODECOPY", - "PUSH1", - 0x33, - "PUSH1", - 0x00, "RETURN", + ] + forwarder_pre_asm = [ "CALLDATASIZE", - "PUSH1", - 0x00, - "PUSH1", - 0x00, + "RETURNDATASIZE", + "RETURNDATASIZE", "CALLDATACOPY", - "PUSH2", - num_to_bytearray(0x1000), - "PUSH1", - 0x00, + "RETURNDATASIZE", + "RETURNDATASIZE", + "RETURNDATASIZE", "CALLDATASIZE", - "PUSH1", - 0x00, + "RETURNDATASIZE", "PUSH20", # [address to delegate to] ] - code_b = [ + forwarder_post_asm = [ "GAS", "DELEGATECALL", + "RETURNDATASIZE", + "DUP3", + "DUP1", + "RETURNDATACOPY", + "SWAP1", + "RETURNDATASIZE", + "SWAP2", "PUSH1", - 0x2C, # jumpdest of whole program. + 0x2B, # jumpdest of whole program. "JUMPI", - "PUSH1", - 0x0, - "DUP1", "REVERT", "JUMPDEST", - "PUSH2", - num_to_bytearray(0x1000), - "PUSH1", - 0x00, "RETURN", ] - return assembly_to_evm(code_a)[0] + (b"\x00" * 20) + assembly_to_evm(code_b)[0] + return ( + assembly_to_evm(loader_asm)[0], + assembly_to_evm(forwarder_pre_asm)[0], + assembly_to_evm(forwarder_post_asm)[0], + ) class CreateForwarderTo(_SimpleBuiltinFunction): @@ -1492,17 +1491,28 @@ def build_LLL(self, expr, args, kwargs, context): ) placeholder = context.new_internal_variable(ByteArrayType(96)) - kode = get_create_forwarder_to_bytecode() - high = bytes_to_int(kode[:32]) - low = bytes_to_int((kode + b"\x00" * 32)[47:79]) + loader_evm, forwarder_pre_evm, forwarder_post_evm = get_create_forwarder_to_bytecode() + # Adjust to 32-byte boundaries + preamble_length = len(loader_evm) + len(forwarder_pre_evm) + forwarder_preamble = bytes_to_int( + loader_evm + forwarder_pre_evm + b"\x00" * (32 - preamble_length) + ) + forwarder_post = bytes_to_int(forwarder_post_evm + b"\x00" * (32 - len(forwarder_post_evm))) + + if args[0].typ.is_literal: + target_address = args[0].value * 2 ** 96 + elif version_check(begin="constantinople"): + target_address = ["shl", 96, args[0]] + else: + target_address = ["mul", args[0], 2 ** 96] return LLLnode.from_list( [ "seq", - ["mstore", placeholder, high], - ["mstore", ["add", placeholder, 27], ["mul", args[0], 2 ** 96]], - ["mstore", ["add", placeholder, 47], low], - ["clamp_nonzero", ["create", value, placeholder, 96]], + ["mstore", placeholder, forwarder_preamble], + ["mstore", ["add", placeholder, preamble_length], target_address], + ["mstore", ["add", placeholder, preamble_length + 20], forwarder_post], + ["create", value, placeholder, preamble_length + 20 + len(forwarder_post_evm)], ], typ=BaseType("address"), pos=getpos(expr),