diff --git a/tests/parser/features/test_clampers.py b/tests/parser/features/test_clampers.py index 7a73f15f0b..48a4c39dc9 100644 --- a/tests/parser/features/test_clampers.py +++ b/tests/parser/features/test_clampers.py @@ -1,4 +1,15 @@ -def test_clamper_test_code(assert_tx_failed, get_contract_with_gas_estimation): +import pytest +from eth_utils import keccak + + +def _make_tx(w3, address, signature, values): + # helper function to broadcast transactions that fail clamping check + sig = keccak(signature.encode()).hex()[:8] + data = "".join(int(i).to_bytes(32, "big", signed=i < 0).hex() for i in values) + w3.eth.sendTransaction({"to": address, "data": f"0x{sig}{data}"}) + + +def test_bytes_clamper(assert_tx_failed, get_contract_with_gas_estimation): clamper_test_code = """ @public def foo(s: bytes[3]) -> bytes[3]: @@ -10,4 +21,184 @@ def foo(s: bytes[3]) -> bytes[3]: assert c.foo(b"cat") == b"cat" assert_tx_failed(lambda: c.foo(b"cate")) - print("Passed bytearray clamping test") + +def test_bytes_clamper_multiple_slots(assert_tx_failed, get_contract_with_gas_estimation): + clamper_test_code = """ +@public +def foo(s: bytes[40]) -> bytes[40]: + return s + """ + + data = b"this is exactly forty characters long!!!" + c = get_contract_with_gas_estimation(clamper_test_code) + + assert c.foo(data[:30]) == data[:30] + assert c.foo(data) == data + assert_tx_failed(lambda: c.foo(data + b"!")) + + +@pytest.mark.parametrize("value", [0, 1, -1, 2 ** 127 - 1, -(2 ** 127)]) +def test_int128_clamper_passing(w3, get_contract, value): + code = """ +@public +def foo(s: int128) -> int128: + return s + """ + + c = get_contract(code) + _make_tx(w3, c.address, "foo(int128)", [value]) + + +@pytest.mark.parametrize("value", [2 ** 127, -(2 ** 127) - 1, 2 ** 255 - 1, -(2 ** 255)]) +def test_int128_clamper_failing(w3, assert_tx_failed, get_contract, value): + code = """ +@public +def foo(s: int128) -> int128: + return s + """ + + c = get_contract(code) + assert_tx_failed(lambda: _make_tx(w3, c.address, "foo(int128)", [value])) + + +@pytest.mark.parametrize("value", [0, 1]) +def test_bool_clamper_passing(w3, get_contract, value): + code = """ +@public +def foo(s: bool) -> bool: + return s + """ + + c = get_contract(code) + _make_tx(w3, c.address, "foo(bool)", [value]) + + +@pytest.mark.parametrize("value", [2, 3, 4, 8, 16, 2 ** 256 - 1]) +def test_bool_clamper_failing(w3, assert_tx_failed, get_contract, value): + code = """ +@public +def foo(s: bool) -> bool: + return s + """ + + c = get_contract(code) + assert_tx_failed(lambda: _make_tx(w3, c.address, "foo(bool)", [value])) + + +@pytest.mark.parametrize("value", [0, 1, 2 ** 160 - 1]) +def test_address_clamper_passing(w3, get_contract, value): + code = """ +@public +def foo(s: address) -> address: + return s + """ + + c = get_contract(code) + _make_tx(w3, c.address, "foo(address)", [value]) + + +@pytest.mark.parametrize("value", [2 ** 160, 2 ** 256 - 1]) +def test_address_clamper_failing(w3, assert_tx_failed, get_contract, value): + code = """ +@public +def foo(s: address) -> address: + return s + """ + + c = get_contract(code) + assert_tx_failed(lambda: _make_tx(w3, c.address, "foo(address)", [value])) + + +@pytest.mark.parametrize("value", [0, 1, -1, 2 ** 127 - 1, -(2 ** 127)]) +def test_int128_array_clamper_passing(w3, get_contract, value): + code = """ +@public +def foo(a: uint256, b: int128[5], c: uint256) -> int128[5]: + return b + """ + + # on both ends of the array we place a `uint256` that would fail the clamp check, + # to ensure there are no off-by-one errors + values = [2 ** 127] + ([value] * 5) + [2 ** 127] + + c = get_contract(code) + _make_tx(w3, c.address, "foo(uint256,int128[5],uint256)", values) + + +@pytest.mark.parametrize("bad_value", [2 ** 127, -(2 ** 127) - 1, 2 ** 255 - 1, -(2 ** 255)]) +@pytest.mark.parametrize("idx", range(5)) +def test_int128_array_clamper_failing(w3, assert_tx_failed, get_contract, bad_value, idx): + # ensure the invalid value is detected at all locations in the array + code = """ +@public +def foo(b: int128[5]) -> int128[5]: + return b + """ + + values = [0] * 5 + values[idx] = bad_value + + c = get_contract(code) + assert_tx_failed(lambda: _make_tx(w3, c.address, "foo(int128[5])", values)) + + +@pytest.mark.parametrize("value", [0, 1, -1, 2 ** 127 - 1, -(2 ** 127)]) +def test_int128_array_looped_clamper_passing(w3, get_contract, value): + # when an array is > 5 items, the arg clamper runs in a loop to reduce bytecode size + code = """ +@public +def foo(a: uint256, b: int128[10], c: uint256) -> int128[10]: + return b + """ + + values = [2 ** 127] + ([value] * 10) + [2 ** 127] + + c = get_contract(code) + _make_tx(w3, c.address, "foo(uint256,int128[10],uint256)", values) + + +@pytest.mark.parametrize("bad_value", [2 ** 127, -(2 ** 127) - 1, 2 ** 255 - 1, -(2 ** 255)]) +@pytest.mark.parametrize("idx", range(10)) +def test_int128_array_looped_clamper_failing(w3, assert_tx_failed, get_contract, bad_value, idx): + code = """ +@public +def foo(b: int128[10]) -> int128[10]: + return b + """ + + values = [0] * 10 + values[idx] = bad_value + + c = get_contract(code) + assert_tx_failed(lambda: _make_tx(w3, c.address, "foo(int128[10])", values)) + + +@pytest.mark.parametrize("value", [0, 1, -1, 2 ** 127 - 1, -(2 ** 127)]) +def test_multidimension_array_clamper_passing(w3, get_contract, value): + code = """ +@public +def foo(a: uint256, b: int128[6][3][1][8], c: uint256) -> int128[6][3][1][8]: + return b + """ + + # 6 * 3 * 1 * 8 = 144, the total number of values in our multidimensional array + values = [2 ** 127] + ([value] * 144) + [2 ** 127] + + c = get_contract(code) + _make_tx(w3, c.address, "foo(uint256,int128[6][3][1][8],uint256)", values) + + +@pytest.mark.parametrize("bad_value", [2 ** 127, -(2 ** 127) - 1, 2 ** 255 - 1, -(2 ** 255)]) +@pytest.mark.parametrize("idx", range(12)) +def test_multidimension_array_clamper_failing(w3, assert_tx_failed, get_contract, bad_value, idx): + code = """ +@public +def foo(b: int128[6][1][2]) -> int128[6][1][2]: + return b + """ + + values = [0] * 12 + values[idx] = bad_value + + c = get_contract(code) + assert_tx_failed(lambda: _make_tx(w3, c.address, "foo(int128[6][1][2]])", values)) diff --git a/vyper/parser/arg_clamps.py b/vyper/parser/arg_clamps.py index c97c5e162f..585b20d946 100644 --- a/vyper/parser/arg_clamps.py +++ b/vyper/parser/arg_clamps.py @@ -22,6 +22,22 @@ def _mk_codecopy_copier(pos, sz, mempos): def make_arg_clamper(datapos, mempos, typ, is_init=False): """ Clamps argument to type limits. + + Arguments + --------- + datapos : int | LLLnode + Calldata offset of the value being clamped + mempos : int | LLLnode + Memory offset that the value is stored at during clamping + typ : vyper.types.types.BaseType + Type of the value + is_init : bool, optional + Boolean indicating if we are generating init bytecode + + Returns + ------- + LLLnode + Arg clamper LLL """ if not is_init: @@ -68,31 +84,45 @@ def make_arg_clamper(datapos, mempos, typ, is_init=False): # Lists: recurse elif isinstance(typ, ListType): if typ.count > 5 or (type(datapos) is list and type(mempos) is list): - subtype_size = get_size_of_type(typ.subtype) - i_incr = subtype_size * 32 + # find ultimate base type + subtype = typ.subtype + while hasattr(subtype, "subtype"): + subtype = subtype.subtype + + # make arg clamper for the base type + offset = MemoryPositions.FREE_LOOP_INDEX + clamper = make_arg_clamper( + ["add", datapos, ["mload", offset]], + ["add", mempos, ["mload", offset]], + subtype, + is_init, + ) + if clamper.value == "pass": + # no point looping if the base type doesn't require clamping + return clamper + + # loop the entire array at once, even if it's multidimensional + type_size = get_size_of_type(typ) + i_incr = get_size_of_type(subtype) * 32 - mem_to = subtype_size * 32 * (typ.count - 1) + mem_to = type_size * 32 loop_label = f"_check_list_loop_{str(uuid.uuid4())}" - offset = 288 - o = [ + lll_node = [ ["mstore", offset, 0], # init loop ["label", loop_label], - make_arg_clamper( - ["add", datapos, ["mload", offset]], - ["add", mempos, ["mload", offset]], - typ.subtype, - is_init, - ), + clamper, ["mstore", offset, ["add", ["mload", offset], i_incr]], ["if", ["lt", ["mload", offset], mem_to], ["goto", loop_label]], ] else: - o = [] + lll_node = [] for i in range(typ.count): offset = get_size_of_type(typ.subtype) * 32 * i - o.append(make_arg_clamper(datapos + offset, mempos + offset, typ.subtype, is_init)) - return LLLnode.from_list(["seq"] + o, typ=None, annotation="checking list input") + lll_node.append( + make_arg_clamper(datapos + offset, mempos + offset, typ.subtype, is_init) + ) + return LLLnode.from_list(["seq"] + lll_node, typ=None, annotation="checking list input") # Otherwise don't make any checks else: return LLLnode.from_list("pass")