From daeae11154e58bb255fcad50c91773b6abc1fbb8 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Fri, 17 Jun 2022 14:01:29 -0400 Subject: [PATCH] feat: arithmetic for new int types (#2843) this commit generalizes codegen for arithmetic ops to handle any integer type. to do so, it refactors codegen for arithmetic ops into its own module. it also slightly changes the codegen so that it targets optimizer rules better. this commit also changes DecimalInfo.divisor to be of `int` type. this is slightly less convenient, but it's clearer when used that it is not getting truncated. also, fix a pytest-split setting. the (default) `duration_based_chunks` splitting algorithm in pytest-split resulted in a 0-size chunk getting allocated to the final group, which caused tox to error out with InvocationError code 5. --- .github/workflows/test.yml | 4 +- tests/ast/nodes/test_evaluate_compare.py | 23 +- tests/compiler/ir/test_optimize_ir.py | 11 + tests/fuzzing/test_exponents.py | 2 +- tests/parser/types/numbers/test_decimals.py | 63 ++- tests/parser/types/numbers/test_int128.py | 212 ---------- tests/parser/types/numbers/test_int256.py | 316 --------------- .../parser/types/numbers/test_signed_ints.py | 329 ++++++++++++++++ tests/parser/types/numbers/test_uint256.py | 194 ---------- tests/parser/types/numbers/test_uint8.py | 206 ---------- .../types/numbers/test_unsigned_ints.py | 225 +++++++++++ vyper/builtin_functions/convert.py | 12 +- vyper/codegen/arithmetic.py | 352 +++++++++++++++++ vyper/codegen/core.py | 22 +- vyper/codegen/expr.py | 366 +----------------- vyper/codegen/types/types.py | 9 +- vyper/ir/optimizer.py | 33 +- vyper/utils.py | 1 + 18 files changed, 1069 insertions(+), 1311 deletions(-) delete mode 100644 tests/parser/types/numbers/test_int128.py delete mode 100644 tests/parser/types/numbers/test_int256.py create mode 100644 tests/parser/types/numbers/test_signed_ints.py delete mode 100644 tests/parser/types/numbers/test_uint256.py delete mode 100644 tests/parser/types/numbers/test_uint8.py create mode 100644 tests/parser/types/numbers/test_unsigned_ints.py create mode 100644 vyper/codegen/arithmetic.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 9f36e850df..c69d434059 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -131,10 +131,10 @@ jobs: # NOTE: if the tests get poorly distributed, run this and commit the resulting `.test_durations` file to the `vyper-test-durations` repo. # `TOXENV=fuzzing tox -r -- --store-durations --reruns 10 --reruns-delay 1 -r aR tests/` - name: Fetch test-durations - run: curl --location "https://raw.githubusercontent.com/vyperlang/vyper-test-durations/54ee7f6a09bd94192d01f8de5293483414295e45/test_durations" -o .test_durations + run: curl --location "https://raw.githubusercontent.com/vyperlang/vyper-test-durations/4d8398e581f183de986892c5a8d4ab3d05ccaab2/test_durations" -o .test_durations - name: Run Tox - run: TOXENV=fuzzing tox -r -- --splits 45 --group ${{ matrix.group }} --reruns 10 --reruns-delay 1 -r aR tests/ + run: TOXENV=fuzzing tox -r -- --splits 45 --group ${{ matrix.group }} --splitting-algorithm least_duration --reruns 10 --reruns-delay 1 -r aR tests/ - name: Upload Coverage uses: codecov/codecov-action@v1 diff --git a/tests/ast/nodes/test_evaluate_compare.py b/tests/ast/nodes/test_evaluate_compare.py index db07da0b13..8761cccac4 100644 --- a/tests/ast/nodes/test_evaluate_compare.py +++ b/tests/ast/nodes/test_evaluate_compare.py @@ -6,11 +6,12 @@ from vyper.exceptions import UnfoldableNode +# TODO expand to all signed types @pytest.mark.fuzzing @settings(max_examples=50, deadline=1000) @given(left=st.integers(), right=st.integers()) @pytest.mark.parametrize("op", ["==", "!=", "<", "<=", ">=", ">"]) -def test_compare_eq(get_contract, op, left, right): +def test_compare_eq_signed(get_contract, op, left, right): source = f""" @external def foo(a: int128, b: int128) -> bool: @@ -25,6 +26,26 @@ def foo(a: int128, b: int128) -> bool: assert contract.foo(left, right) == new_node.value +# TODO expand to all unsigned types +@pytest.mark.fuzzing +@settings(max_examples=50, deadline=1000) +@given(left=st.integers(min_value=0), right=st.integers(min_value=0)) +@pytest.mark.parametrize("op", ["==", "!=", "<", "<=", ">=", ">"]) +def test_compare_eq_unsigned(get_contract, op, left, right): + source = f""" +@external +def foo(a: uint128, b: uint128) -> bool: + return a {op} b + """ + contract = get_contract(source) + + vyper_ast = vy_ast.parse_to_ast(f"{left} {op} {right}") + old_node = vyper_ast.body[0].value + new_node = old_node.evaluate() + + assert contract.foo(left, right) == new_node.value + + @pytest.mark.fuzzing @settings(max_examples=20, deadline=500) @given(left=st.integers(), right=st.lists(st.integers(), min_size=1, max_size=16)) diff --git a/tests/compiler/ir/test_optimize_ir.py b/tests/compiler/ir/test_optimize_ir.py index 8d36aa0b7c..f283588750 100644 --- a/tests/compiler/ir/test_optimize_ir.py +++ b/tests/compiler/ir/test_optimize_ir.py @@ -82,6 +82,13 @@ (["mod", "x", 128], ["and", "x", 127]), (["sdiv", "x", 64], None), (["smod", "x", 64], None), + (["exp", 3, 5], [3 ** 5]), + (["exp", 3, 256], [(3 ** 256) % (2 ** 256)]), + (["exp", 2, 257], [0]), + (["exp", "x", 0], [1]), + (["exp", "x", 1], ["x"]), + (["exp", 1, "x"], [1]), + (["exp", 0, "x"], ["iszero", "x"]), # bitwise ops (["shr", 0, "x"], ["x"]), (["sar", 0, "x"], ["x"]), @@ -102,6 +109,7 @@ (["and", "x", 1], None), (["or", "x", 1], None), (["xor", 0, "x"], ["x"]), + (["xor", "x", "x"], [0]), (["iszero", ["or", "x", 1]], [0]), (["iszero", ["or", 2, "x"]], [0]), (["iszero", ["or", 1, ["sload", 0]]], None), @@ -113,6 +121,9 @@ (["eq", -1, ["add", 2 ** 255, 2 ** 255 - 1]], [1]), # test compile-time wrapping (["eq", -1, ["add", -(2 ** 255), 2 ** 255 - 1]], [1]), # test compile-time wrapping (["eq", -2, ["add", 2 ** 256 - 1, 2 ** 256 - 1]], [1]), # test compile-time wrapping + (["eq", "x", "x"], [1]), + (["eq", "callvalue", "callvalue"], None), + (["ne", "x", "x"], [0]), ] diff --git a/tests/fuzzing/test_exponents.py b/tests/fuzzing/test_exponents.py index 0ac1285b56..e937f023e4 100644 --- a/tests/fuzzing/test_exponents.py +++ b/tests/fuzzing/test_exponents.py @@ -2,7 +2,7 @@ from hypothesis import example, given, settings from hypothesis import strategies as st -from vyper.codegen.expr import calculate_largest_base, calculate_largest_power +from vyper.codegen.arithmetic import calculate_largest_base, calculate_largest_power @pytest.mark.fuzzing diff --git a/tests/parser/types/numbers/test_decimals.py b/tests/parser/types/numbers/test_decimals.py index b920a53160..46c38420c9 100644 --- a/tests/parser/types/numbers/test_decimals.py +++ b/tests/parser/types/numbers/test_decimals.py @@ -1,8 +1,9 @@ -from decimal import Decimal, getcontext +from decimal import ROUND_DOWN, Decimal, getcontext import pytest from vyper.exceptions import DecimalOverrideException, TypeMismatch +from vyper.utils import DECIMAL_EPSILON, SizeLimits def test_decimal_override(): @@ -11,6 +12,10 @@ def test_decimal_override(): getcontext().prec = 100 +def quantize(x: Decimal) -> Decimal: + return x.quantize(DECIMAL_EPSILON, rounding=ROUND_DOWN) + + def test_decimal_test(get_contract_with_gas_estimation): decimal_test = """ @external @@ -131,10 +136,60 @@ def _num_mul(x: decimal, y: decimal) -> decimal: c = get_contract_with_gas_estimation(mul_code) - NUM_1 = Decimal("85070591730234615865843651857942052864") - NUM_2 = Decimal("136112946768375385385349842973") + x = Decimal("85070591730234615865843651857942052864") + y = Decimal("136112946768375385385349842973") + + assert_tx_failed(lambda: c._num_mul(x, y)) + + x = SizeLimits.MAX_AST_DECIMAL + y = 1 + DECIMAL_EPSILON + + assert_tx_failed(lambda: c._num_mul(x, y)) + + assert c._num_mul(x, Decimal(1)) == x + + assert c._num_mul(x, 1 - DECIMAL_EPSILON) == quantize(x * (1 - DECIMAL_EPSILON)) + + x = SizeLimits.MIN_AST_DECIMAL + assert c._num_mul(x, 1 - DECIMAL_EPSILON) == quantize(x * (1 - DECIMAL_EPSILON)) + + +# division failure modes(!) +def test_div_overflow(get_contract, assert_tx_failed): + code = """ +@external +def foo(x: decimal, y: decimal) -> decimal: + return x / y + """ + + c = get_contract(code) + + x = SizeLimits.MIN_AST_DECIMAL + y = -DECIMAL_EPSILON + + assert_tx_failed(lambda: c.foo(x, y)) + assert_tx_failed(lambda: c.foo(x, Decimal(0))) + assert_tx_failed(lambda: c.foo(y, Decimal(0))) + + y = Decimal(1) - DECIMAL_EPSILON # 0.999999999 + assert_tx_failed(lambda: c.foo(x, y)) + + y = Decimal(-1) + assert_tx_failed(lambda: c.foo(x, y)) + + assert c.foo(x, Decimal(1)) == x + assert c.foo(x, 1 + DECIMAL_EPSILON) == quantize(x / (1 + DECIMAL_EPSILON)) + + x = SizeLimits.MAX_AST_DECIMAL + + assert_tx_failed(lambda: c.foo(x, DECIMAL_EPSILON)) + + y = Decimal(1) - DECIMAL_EPSILON + assert_tx_failed(lambda: c.foo(x, y)) + + assert c.foo(x, Decimal(1)) == x - assert_tx_failed(lambda: c._num_mul(NUM_1, NUM_2)) + assert c.foo(x, 1 + DECIMAL_EPSILON) == quantize(x / (1 + DECIMAL_EPSILON)) def test_decimal_min_max_literals(assert_tx_failed, get_contract_with_gas_estimation): diff --git a/tests/parser/types/numbers/test_int128.py b/tests/parser/types/numbers/test_int128.py deleted file mode 100644 index 2fe217f63d..0000000000 --- a/tests/parser/types/numbers/test_int128.py +++ /dev/null @@ -1,212 +0,0 @@ -from decimal import Decimal - -from vyper.exceptions import OverflowException - - -def test_exponent_base_zero(get_contract): - code = """ -@external -def foo(x: int128) -> int128: - return 0 ** x - """ - c = get_contract(code) - assert c.foo(0) == 1 - assert c.foo(1) == 0 - assert c.foo(-1) == 0 - assert c.foo(2 ** 127 - 1) == 0 - assert c.foo(-(2 ** 127)) == 0 - - -def test_exponent_base_one(get_contract): - code = """ -@external -def foo(x: int128) -> int128: - return 1 ** x - """ - c = get_contract(code) - assert c.foo(0) == 1 - assert c.foo(1) == 1 - assert c.foo(-1) == 1 - assert c.foo(2 ** 127 - 1) == 1 - assert c.foo(-(2 ** 127)) == 1 - - -def test_num_divided_by_num(get_contract_with_gas_estimation): - code = """ -@external -def foo(inp: int128) -> int128: - y: int128 = 5/inp - return y -""" - c = get_contract_with_gas_estimation(code) - assert c.foo(2) == 2 - assert c.foo(5) == 1 - assert c.foo(10) == 0 - assert c.foo(50) == 0 - - -def test_decimal_divided_by_num(get_contract_with_gas_estimation): - code = """ -@external -def foo(inp: decimal) -> decimal: - y: decimal = inp/5.0 - return y -""" - c = get_contract_with_gas_estimation(code) - assert c.foo(Decimal("1")) == Decimal("0.2") - assert c.foo(Decimal(".5")) == Decimal("0.1") - assert c.foo(Decimal(".2")) == Decimal(".04") - - -def test_negative_nums(get_contract_with_gas_estimation): - negative_nums_code = """ -@external -def _negative_num() -> int128: - return -1 - -@external -def _negative_exp() -> int128: - return -(1+2) - -@external -def _negative_exp_var() -> int128: - a: int128 = 2 - return -(a+2) - """ - - c = get_contract_with_gas_estimation(negative_nums_code) - assert c._negative_num() == -1 - assert c._negative_exp() == -3 - assert c._negative_exp_var() == -4 - - -def test_num_bound(assert_tx_failed, get_contract_with_gas_estimation): - num_bound_code = """ -@external -def _num(x: int128) -> int128: - return x - -@external -def _num_add(x: int128, y: int128) -> int128: - return x + y - -@external -def _num_sub(x: int128, y: int128) -> int128: - return x - y - -@external -def _num_add3(x: int128, y: int128, z: int128) -> int128: - return x + y + z - -@external -def _num_max() -> int128: - return 170141183460469231731687303715884105727 # 2**127 - 1 - -@external -def _num_min() -> int128: - return -170141183460469231731687303715884105728 # -2**127 - """ - - c = get_contract_with_gas_estimation(num_bound_code) - - NUM_MAX = 2 ** 127 - 1 - NUM_MIN = -(2 ** 127) - assert c._num_add(NUM_MAX, 0) == NUM_MAX - assert c._num_sub(NUM_MIN, 0) == NUM_MIN - assert c._num_add(NUM_MAX - 1, 1) == NUM_MAX - assert c._num_sub(NUM_MIN + 1, 1) == NUM_MIN - assert_tx_failed(lambda: c._num_add(NUM_MAX, 1)) - assert_tx_failed(lambda: c._num_sub(NUM_MIN, 1)) - assert_tx_failed(lambda: c._num_add(NUM_MAX - 1, 2)) - assert_tx_failed(lambda: c._num_sub(NUM_MIN + 1, 2)) - assert c._num_max() == NUM_MAX - assert c._num_min() == NUM_MIN - - assert_tx_failed(lambda: c._num_add3(NUM_MAX, 1, -1)) - assert c._num_add3(NUM_MAX, -1, 1) == NUM_MAX - - -def test_overflow_out_of_range(get_contract, assert_compile_failed): - code = """ -@external -def num_sub() -> int128: - return 1-2**256 - """ - - assert_compile_failed(lambda: get_contract(code), OverflowException) - - -def test_overflow_add(get_contract, assert_tx_failed): - code = """ -@external -def num_add(i: int128) -> int128: - return (2**127-1) + i - """ - c = get_contract(code) - - assert c.num_add(0) == 2 ** 127 - 1 - assert c.num_add(-1) == 2 ** 127 - 2 - - assert_tx_failed(lambda: c.num_add(1)) - assert_tx_failed(lambda: c.num_add(2)) - - -def test_overflow_add_vars(get_contract, assert_tx_failed): - code = """ -@external -def num_add(a: int128, b: int128) -> int128: - return a + b - """ - c = get_contract(code) - - assert_tx_failed(lambda: c.num_add(2 ** 127 - 1, 1)) - assert_tx_failed(lambda: c.num_add(1, 2 ** 127 - 1)) - - -def test_overflow_sub_vars(get_contract, assert_tx_failed): - code = """ -@external -def num_sub(a: int128, b: int128) -> int128: - return a - b - """ - - c = get_contract(code) - - assert c.num_sub(-(2 ** 127), -1) == (-(2 ** 127)) + 1 - assert_tx_failed(lambda: c.num_sub(-(2 ** 127), 1)) - - -def test_overflow_mul_vars(get_contract, assert_tx_failed): - code = """ -@external -def num_mul(a: int128, b: int128) -> int128: - return a * b - """ - - c = get_contract(code) - - assert c.num_mul(-(2 ** 127), 1) == -(2 ** 127) - assert_tx_failed(lambda: c.num_mul(2 ** 126, 2)) - - -def test_literal_int_division(get_contract): - code = """ -@external -def foo() -> int128: - z: int128 = 5 / 2 - return z - """ - - c = get_contract(code) - - assert c.foo() == 2 - - -def test_literal_int_division_return(get_contract, assert_compile_failed): - code = """ -@external -def test() -> decimal: - return 5 / 2 - """ - - assert_compile_failed(lambda: get_contract(code)) diff --git a/tests/parser/types/numbers/test_int256.py b/tests/parser/types/numbers/test_int256.py deleted file mode 100644 index 7689afc214..0000000000 --- a/tests/parser/types/numbers/test_int256.py +++ /dev/null @@ -1,316 +0,0 @@ -from vyper.exceptions import OverflowException - - -def test_exponent_base_zero(get_contract): - code = """ -@external -def foo(x: int256) -> int256: - return 0 ** x - """ - c = get_contract(code) - assert c.foo(0) == 1 - assert c.foo(1) == 0 - assert c.foo(-1) == 0 - assert c.foo(2 ** 255 - 1) == 0 - assert c.foo(-(2 ** 255)) == 0 - - -def test_exponent_base_one(get_contract): - code = """ -@external -def foo(x: int256) -> int256: - return 1 ** x - """ - c = get_contract(code) - assert c.foo(0) == 1 - assert c.foo(1) == 1 - assert c.foo(-1) == 1 - assert c.foo(2 ** 255 - 1) == 1 - assert c.foo(-(2 ** 255)) == 1 - - -def test_exponent(get_contract, assert_tx_failed): - code = """ -@external -def foo(x: int256) -> int256: - return 4 ** x - """ - c = get_contract(code) - assert c.foo(0) == 1 - assert c.foo(1) == 4 - assert c.foo(4) == 256 - assert c.foo(127) == 4 ** 127 - assert_tx_failed(lambda: c.foo(128)) - assert_tx_failed(lambda: c.foo(-1)) - assert_tx_failed(lambda: c.foo(-(2 ** 255))) - - -def test_num_divided_by_num(get_contract_with_gas_estimation): - code = """ -@external -def foo(inp: int256) -> int256: - y: int256 = 5/inp - return y -""" - c = get_contract_with_gas_estimation(code) - assert c.foo(2) == 2 - assert c.foo(5) == 1 - assert c.foo(10) == 0 - assert c.foo(50) == 0 - - -def test_negative_nums(get_contract_with_gas_estimation): - negative_nums_code = """ -@external -def _negative_num() -> int256: - return -1 - -@external -def _negative_exp() -> int256: - return -(1+2) - -@external -def _negative_exp_var() -> int256: - a: int256 = 2 - return -(a+2) - """ - - c = get_contract_with_gas_estimation(negative_nums_code) - assert c._negative_num() == -1 - assert c._negative_exp() == -3 - assert c._negative_exp_var() == -4 - - -def test_num_bound(assert_tx_failed, get_contract_with_gas_estimation): - num_bound_code = """ -@external -def _num(x: int256) -> int256: - return x - -@external -def _num_add(x: int256, y: int256) -> int256: - return x + y - -@external -def _num_sub(x: int256, y: int256) -> int256: - return x - y - -@external -def _num_add3(x: int256, y: int256, z: int256) -> int256: - return x + y + z - -@external -def _num_max() -> int256: - return 2 ** 255 -1 - -@external -def _num_min() -> int256: - return -2**255 - """ - - c = get_contract_with_gas_estimation(num_bound_code) - - NUM_MAX = 2 ** 255 - 1 - NUM_MIN = -(2 ** 255) - assert c._num_add(NUM_MAX, 0) == NUM_MAX - assert c._num_sub(NUM_MIN, 0) == NUM_MIN - assert c._num_add(NUM_MAX - 1, 1) == NUM_MAX - assert c._num_sub(NUM_MIN + 1, 1) == NUM_MIN - assert_tx_failed(lambda: c._num_add(NUM_MAX, 1)) - assert_tx_failed(lambda: c._num_sub(NUM_MIN, 1)) - assert_tx_failed(lambda: c._num_add(NUM_MAX - 1, 2)) - assert_tx_failed(lambda: c._num_sub(NUM_MIN + 1, 2)) - assert c._num_max() == NUM_MAX - assert c._num_min() == NUM_MIN - - assert_tx_failed(lambda: c._num_add3(NUM_MAX, 1, -1)) - assert c._num_add3(NUM_MAX, -1, 1) == NUM_MAX - - -def test_overflow_out_of_range(get_contract, assert_compile_failed): - code = """ -@external -def num_sub() -> int256: - return 1-2**256 - """ - - assert_compile_failed(lambda: get_contract(code), OverflowException) - - -def test_overflow_add(get_contract, assert_tx_failed): - code = """ -@external -def num_add(i: int256) -> int256: - return (2**255-1) + i - """ - c = get_contract(code) - - assert c.num_add(0) == 2 ** 255 - 1 - assert c.num_add(-1) == 2 ** 255 - 2 - - assert_tx_failed(lambda: c.num_add(1)) - assert_tx_failed(lambda: c.num_add(2)) - - -def test_overflow_add_vars(get_contract, assert_tx_failed): - code = """ -@external -def num_add(a: int256, b: int256) -> int256: - return a + b - """ - c = get_contract(code) - - assert_tx_failed(lambda: c.num_add(2 ** 255 - 1, 1)) - assert_tx_failed(lambda: c.num_add(1, 2 ** 255 - 1)) - - -def test_overflow_sub_vars(get_contract, assert_tx_failed): - code = """ -@external -def num_sub(a: int256, b: int256) -> int256: - return a - b - """ - - c = get_contract(code) - - assert c.num_sub(-(2 ** 255), -1) == (-(2 ** 255)) + 1 - assert_tx_failed(lambda: c.num_sub(-(2 ** 255), 1)) - - -def test_overflow_mul_vars(get_contract, assert_tx_failed): - code = """ -@external -def num_mul(a: int256, b: int256) -> int256: - return a * b - """ - - c = get_contract(code) - - assert c.num_mul(-(2 ** 255), 1) == -(2 ** 255) - assert c.num_mul(2 ** 255 - 1, -1) == -(2 ** 255) + 1 - assert c.num_mul(-1, 2 ** 255 - 1) == -(2 ** 255) + 1 - assert_tx_failed(lambda: c.num_mul(2 ** 254, 2)) - assert_tx_failed(lambda: c.num_mul(-(2 ** 255), -1)) - assert_tx_failed(lambda: c.num_mul(-1, -(2 ** 255))) - - -def test_overflow_mul_left_literal(get_contract, assert_tx_failed): - code = """ -@external -def num_mul(b: int256) -> int256: - return -1 * b - """ - - c = get_contract(code) - - assert c.num_mul(2 ** 255 - 1) == -(2 ** 255) + 1 - assert c.num_mul(-(2 ** 255) + 1) == 2 ** 255 - 1 - assert_tx_failed(lambda: c.num_mul(-(2 ** 255))) - - -def test_overflow_mul_right_literal(get_contract, assert_tx_failed): - code = """ -@external -def num_mul(a: int256) -> int256: - return a * -2**255 - """ - - c = get_contract(code) - - assert c.num_mul(1) == -(2 ** 255) - assert_tx_failed(lambda: c.num_mul(-1)) - - -def test_literal_int_division(get_contract): - code = """ -@external -def foo() -> int256: - z: int256 = 5 / 2 - return z - """ - - c = get_contract(code) - - assert c.foo() == 2 - - -def test_overflow_division(get_contract, assert_tx_failed): - code = """ -@external -def foo(a: int256, b: int256) -> int256: - return a / b - """ - - c = get_contract(code) - - assert c.foo(2 ** 255 - 1, -1) == -(2 ** 255) + 1 - assert c.foo(-(2 ** 255), 1) == -(2 ** 255) - assert_tx_failed(lambda: c.foo(-(2 ** 255), -1)) - - -def test_overflow_division_left_literal(get_contract, assert_tx_failed): - code = """ -@external -def foo(b: int256) -> int256: - return -2**255 / b - """ - - c = get_contract(code) - - assert c.foo(1) == -(2 ** 255) - assert_tx_failed(lambda: c.foo(-1)) - - -def test_overflow_division_right_literal(get_contract, assert_tx_failed): - code = """ -@external -def foo(a: int256) -> int256: - return a / -1 - """ - - c = get_contract(code) - - assert c.foo(2 ** 255 - 1) == -(2 ** 255) + 1 - assert_tx_failed(lambda: c.foo(-(2 ** 255))) - - -def test_negation(get_contract, assert_tx_failed): - code = """ -@external -def foo(a: int256) -> int256: - return -a - """ - - c = get_contract(code) - - assert c.foo(2 ** 255 - 1) == -(2 ** 255) + 1 - assert c.foo(-1) == 1 - assert c.foo(1) == -1 - assert c.foo(0) == 0 - assert_tx_failed(lambda: c.foo(-(2 ** 255))) - - -def test_literal_negative_int(get_contract, assert_tx_failed): - code = """ -@external -def addition(a: int256) -> int256: - return a + -1 - -@external -def subtraction(a: int256) -> int256: - return a - -1 - """ - - c = get_contract(code) - - assert c.addition(23) == 22 - assert c.subtraction(23) == 24 - - assert c.addition(-23) == -24 - assert c.subtraction(-23) == -22 - - assert c.addition(-(2 ** 255) + 1) == -(2 ** 255) - assert c.subtraction(2 ** 255 - 2) == 2 ** 255 - 1 - - assert_tx_failed(lambda: c.addition(-(2 ** 255))) - assert_tx_failed(lambda: c.subtraction(2 ** 255 - 1)) diff --git a/tests/parser/types/numbers/test_signed_ints.py b/tests/parser/types/numbers/test_signed_ints.py new file mode 100644 index 0000000000..2bfcfa9901 --- /dev/null +++ b/tests/parser/types/numbers/test_signed_ints.py @@ -0,0 +1,329 @@ +import itertools +import operator +import random + +import pytest + +from vyper.codegen.types.types import SIGNED_INTEGER_TYPES, parse_integer_typeinfo +from vyper.exceptions import InvalidType, OverflowException, ZeroDivisionException +from vyper.utils import SizeLimits, evm_div, evm_mod, int_bounds + +PARAMS = [] +for t in sorted(SIGNED_INTEGER_TYPES): + info = parse_integer_typeinfo(t) + lo, hi = int_bounds(bits=info.bits, signed=info.is_signed) + PARAMS.append((t, lo, hi, info.bits)) + + +@pytest.mark.parametrize("typ,lo,hi,bits", PARAMS) +def test_exponent_base_zero(get_contract, typ, lo, hi, bits): + code = f""" +@external +def foo(x: {typ}) -> {typ}: + return 0 ** x + """ + c = get_contract(code) + assert c.foo(0) == 1 + assert c.foo(1) == 0 + assert c.foo(-1) == 0 + + assert c.foo(lo) == 0 + assert c.foo(hi) == 0 + + +@pytest.mark.parametrize("typ,lo,hi,bits", PARAMS) +def test_exponent_base_one(get_contract, typ, lo, hi, bits): + code = f""" +@external +def foo(x: {typ}) -> {typ}: + return 1 ** x + """ + c = get_contract(code) + assert c.foo(0) == 1 + assert c.foo(1) == 1 + assert c.foo(-1) == 1 + assert c.foo(lo) == 1 + assert c.foo(hi) == 1 + + +@pytest.mark.parametrize("typ,lo,hi,bits", PARAMS) +def test_exponent(get_contract, assert_tx_failed, typ, lo, hi, bits): + code = f""" +@external +def foo(x: {typ}) -> {typ}: + return 4 ** x + """ + c = get_contract(code) + + test_cases = [0, 1, 3, 4, 126, 127, -1, lo, hi] + for x in test_cases: + if x * 2 >= bits or x < 0: # out of bounds + assert_tx_failed(lambda: c.foo(x)) + else: + assert c.foo(x) == 4 ** x + + +@pytest.mark.parametrize("typ,lo,hi,bits", PARAMS) +def test_negative_nums(get_contract_with_gas_estimation, typ, lo, hi, bits): + negative_nums_code = f""" +@external +def negative_one() -> {typ}: + return -1 + +@external +def negative_three() -> {typ}: + return -(1+2) + +@external +def negative_four() -> {typ}: + a: {typ} = 2 + return -(a+2) + """ + + c = get_contract_with_gas_estimation(negative_nums_code) + assert c.negative_one() == -1 + assert c.negative_three() == -3 + assert c.negative_four() == -4 + + +@pytest.mark.parametrize("typ,lo,hi,bits", PARAMS) +def test_num_bound(assert_tx_failed, get_contract_with_gas_estimation, typ, lo, hi, bits): + num_bound_code = f""" +@external +def _num(x: {typ}) -> {typ}: + return x + +@external +def _num_add(x: {typ}, y: {typ}) -> {typ}: + return x + y + +@external +def _num_sub(x: {typ}, y: {typ}) -> {typ}: + return x - y + +@external +def _num_add3(x: {typ}, y: {typ}, z: {typ}) -> {typ}: + return x + y + z + +@external +def _num_max() -> {typ}: + return {hi} + +@external +def _num_min() -> {typ}: + return {lo} + """ + + c = get_contract_with_gas_estimation(num_bound_code) + + assert c._num_add(hi, 0) == hi + assert c._num_sub(lo, 0) == lo + assert c._num_add(hi - 1, 1) == hi + assert c._num_sub(lo + 1, 1) == lo + assert_tx_failed(lambda: c._num_add(hi, 1)) + assert_tx_failed(lambda: c._num_sub(lo, 1)) + assert_tx_failed(lambda: c._num_add(hi - 1, 2)) + assert_tx_failed(lambda: c._num_sub(lo + 1, 2)) + assert c._num_max() == hi + assert c._num_min() == lo + + assert_tx_failed(lambda: c._num_add3(hi, 1, -1)) + assert c._num_add3(hi, -1, 1) == hi - 1 + 1 + assert_tx_failed(lambda: c._num_add3(lo, -1, 1)) + assert c._num_add3(lo, 1, -1) == lo + 1 - 1 + + +@pytest.mark.parametrize("typ,lo,hi,bits", PARAMS) +def test_overflow_out_of_range(get_contract, assert_compile_failed, typ, lo, hi, bits): + code = f""" +@external +def num_sub() -> {typ}: + return 1-2**{bits} + """ + + if bits == 256: + assert_compile_failed(lambda: get_contract(code), OverflowException) + else: + assert_compile_failed(lambda: get_contract(code), InvalidType) + + +ARITHMETIC_OPS = { + "+": operator.add, + "-": operator.sub, + "*": operator.mul, + "/": evm_div, + "%": evm_mod, +} + + +@pytest.mark.parametrize("op", sorted(ARITHMETIC_OPS.keys())) +@pytest.mark.parametrize("typ,lo,hi,bits", PARAMS) +@pytest.mark.fuzzing +def test_arithmetic_thorough( + get_contract, assert_tx_failed, assert_compile_failed, op, typ, lo, hi, bits +): + # both variables + code_1 = f""" +@external +def foo(x: {typ}, y: {typ}) -> {typ}: + return x {op} y + """ + # right is literal + code_2_template = """ +@external +def foo(x: {typ}) -> {typ}: + return x {op} {y} + """ + # left is literal + code_3_template = """ +@external +def foo(y: {typ}) -> {typ}: + return {x} {op} y + """ + # both literals + code_4_template = """ +@external +def foo() -> {typ}: + return {x} {op} {y} + """ + + fns = {"+": operator.add, "-": operator.sub, "*": operator.mul, "/": evm_div, "%": evm_mod} + fn = fns[op] + + c = get_contract(code_1) + + # TODO refactor to use fixtures + special_cases = [ + lo, + lo + 1, + lo // 2, + lo // 2 - 1, + lo // 2 + 1, + -3, + -2, + -1, + 0, + 1, + 2, + 3, + hi // 2 - 1, + hi // 2, + hi // 2 + 1, + hi - 1, + hi, + ] + xs = special_cases.copy() + ys = special_cases.copy() + + # note: (including special cases, roughly 8k cases total generated) + + NUM_CASES = 10 + # poor man's fuzzing - hypothesis doesn't make it easy + # with the parametrized strategy + xs += [random.randrange(lo, hi) for _ in range(NUM_CASES)] + ys += [random.randrange(lo, hi) for _ in range(NUM_CASES)] + + # edge cases that are tricky to reason about and MUST be tested + assert lo in xs and -1 in ys + + for (x, y) in itertools.product(xs, ys): + expected = fn(x, y) + in_bounds = SizeLimits.in_bounds(typ, expected) + + # safediv and safemod disallow divisor == 0 + div_by_zero = y == 0 and op in ("/", "%") + + ok = in_bounds and not div_by_zero + + code_2 = code_2_template.format(typ=typ, op=op, y=y) + code_3 = code_3_template.format(typ=typ, op=op, x=x) + code_4 = code_4_template.format(typ=typ, op=op, x=x, y=y) + + if ok: + assert c.foo(x, y) == expected + assert get_contract(code_2).foo(x) == expected + assert get_contract(code_3).foo(y) == expected + assert get_contract(code_4).foo() == expected + elif div_by_zero: + assert_tx_failed(lambda: c.foo(x, y)) + assert_compile_failed(lambda: get_contract(code_2), ZeroDivisionException) + assert_tx_failed(lambda: get_contract(code_3).foo(y)) + assert_compile_failed(lambda: get_contract(code_4), ZeroDivisionException) + else: + assert_tx_failed(lambda: c.foo(x, y)) + assert_tx_failed(lambda: get_contract(code_2).foo(x)) + assert_tx_failed(lambda: get_contract(code_3).foo(y)) + assert_compile_failed(lambda: get_contract(code_4), (InvalidType, OverflowException)) + + +COMPARISON_OPS = { + "==": operator.eq, + "!=": operator.ne, + ">": operator.gt, + ">=": operator.ge, + "<": operator.lt, + "<=": operator.le, +} + + +@pytest.mark.parametrize("op", sorted(COMPARISON_OPS.keys())) +@pytest.mark.parametrize("typ,lo,hi,bits", PARAMS) +@pytest.mark.fuzzing +def test_comparators(get_contract, op, typ, lo, hi, bits): + code_1 = f""" +@external +def foo(x: {typ}, y: {typ}) -> bool: + return x {op} y + """ + + fn = COMPARISON_OPS[op] + + c = get_contract(code_1) + + # note: constant folding is tested in tests/ast/folding + special_cases = [ + lo, + lo + 1, + lo // 2, + lo // 2 - 1, + lo // 2 + 1, + -3, + -2, + -1, + 0, + 1, + 2, + 3, + hi // 2 - 1, + hi // 2, + hi // 2 + 1, + hi - 1, + hi, + ] + + xs = special_cases.copy() + ys = special_cases.copy() + + for x, y in itertools.product(xs, ys): + expected = fn(x, y) + assert c.foo(x, y) is expected + + +@pytest.mark.parametrize("typ,lo,hi,bits", PARAMS) +def test_negation(get_contract, assert_tx_failed, typ, lo, hi, bits): + code = f""" +@external +def foo(a: {typ}) -> {typ}: + return -a + """ + + c = get_contract(code) + + assert c.foo(hi) == lo + 1 + assert c.foo(-1) == 1 + assert c.foo(1) == -1 + assert c.foo(0) == 0 + assert c.foo(2) == -2 + assert c.foo(-2) == 2 + + assert_tx_failed(lambda: c.foo(lo)) diff --git a/tests/parser/types/numbers/test_uint256.py b/tests/parser/types/numbers/test_uint256.py deleted file mode 100644 index 69f0b75a6d..0000000000 --- a/tests/parser/types/numbers/test_uint256.py +++ /dev/null @@ -1,194 +0,0 @@ -def test_exponent_base_zero(get_contract): - code = """ -@external -def foo(x: uint256) -> uint256: - return 0 ** x - """ - c = get_contract(code) - assert c.foo(0) == 1 - assert c.foo(1) == 0 - assert c.foo(42) == 0 - assert c.foo(2 ** 256 - 1) == 0 - - -def test_exponent_base_one(get_contract): - code = """ -@external -def foo(x: uint256) -> uint256: - return 1 ** x - """ - c = get_contract(code) - assert c.foo(0) == 1 - assert c.foo(1) == 1 - assert c.foo(42) == 1 - assert c.foo(2 ** 256 - 1) == 1 - - -def test_uint256_code(assert_tx_failed, get_contract_with_gas_estimation): - uint256_code = """ -@external -def _uint256_add(x: uint256, y: uint256) -> uint256: - return x + y - -@external -def _uint256_sub(x: uint256, y: uint256) -> uint256: - return x - y - -@external -def _uint256_mul(x: uint256, y: uint256) -> uint256: - return x * y - -@external -def _uint256_div(x: uint256, y: uint256) -> uint256: - return x / y - -@external -def _uint256_gt(x: uint256, y: uint256) -> bool: - return x > y - -@external -def _uint256_ge(x: uint256, y: uint256) -> bool: - return x >= y - -@external -def _uint256_lt(x: uint256, y: uint256) -> bool: - return x < y - -@external -def _uint256_le(x: uint256, y: uint256) -> bool: - return x <= y - """ - - c = get_contract_with_gas_estimation(uint256_code) - x = 126416208461208640982146408124 - y = 7128468721412412459 - - uint256_MAX = 2 ** 256 - 1 # Max possible uint256 value - assert c._uint256_add(x, y) == x + y - assert c._uint256_add(0, y) == y - assert c._uint256_add(y, 0) == y - assert_tx_failed(lambda: c._uint256_add(uint256_MAX, uint256_MAX)) - assert c._uint256_sub(x, y) == x - y - assert_tx_failed(lambda: c._uint256_sub(y, x)) - assert c._uint256_sub(0, 0) == 0 - assert c._uint256_sub(uint256_MAX, 0) == uint256_MAX - assert_tx_failed(lambda: c._uint256_sub(1, 2)) - assert c._uint256_sub(uint256_MAX, 1) == uint256_MAX - 1 - assert c._uint256_mul(x, y) == x * y - assert_tx_failed(lambda: c._uint256_mul(uint256_MAX, 2)) - assert c._uint256_mul(uint256_MAX, 0) == 0 - assert c._uint256_mul(0, uint256_MAX) == 0 - assert c._uint256_div(x, y) == x // y - assert_tx_failed(lambda: c._uint256_div(uint256_MAX, 0)) - assert c._uint256_div(y, x) == 0 - assert_tx_failed(lambda: c._uint256_div(x, 0)) - assert c._uint256_gt(x, y) is True - assert c._uint256_ge(x, y) is True - assert c._uint256_le(x, y) is False - assert c._uint256_lt(x, y) is False - assert c._uint256_gt(x, x) is False - assert c._uint256_ge(x, x) is True - assert c._uint256_le(x, x) is True - assert c._uint256_lt(x, x) is False - assert c._uint256_lt(y, x) is True - - print("Passed uint256 operation tests") - - -def test_uint256_mod(assert_tx_failed, get_contract_with_gas_estimation): - uint256_code = """ -@external -def _uint256_mod(x: uint256, y: uint256) -> uint256: - return x % y - -@external -def _uint256_addmod(x: uint256, y: uint256, z: uint256) -> uint256: - return uint256_addmod(x, y, z) - -@external -def _uint256_mulmod(x: uint256, y: uint256, z: uint256) -> uint256: - return uint256_mulmod(x, y, z) - """ - - c = get_contract_with_gas_estimation(uint256_code) - - assert c._uint256_mod(3, 2) == 1 - assert c._uint256_mod(34, 32) == 2 - assert_tx_failed(lambda: c._uint256_mod(3, 0)) - assert c._uint256_addmod(1, 2, 2) == 1 - assert c._uint256_addmod(32, 2, 32) == 2 - assert c._uint256_addmod((2 ** 256) - 1, 0, 2) == 1 - assert c._uint256_addmod(2 ** 255, 2 ** 255, 6) == 4 - assert_tx_failed(lambda: c._uint256_addmod(1, 2, 0)) - assert c._uint256_mulmod(3, 1, 2) == 1 - assert c._uint256_mulmod(200, 3, 601) == 600 - assert c._uint256_mulmod(2 ** 255, 1, 3) == 2 - assert c._uint256_mulmod(2 ** 255, 2, 6) == 4 - assert_tx_failed(lambda: c._uint256_mulmod(2, 2, 0)) - - -def test_modmul(get_contract_with_gas_estimation): - modexper = """ -@external -def exponential(base: uint256, exponent: uint256, modulus: uint256) -> uint256: - o: uint256 = 1 - for i in range(256): - o = uint256_mulmod(o, o, modulus) - if bitwise_and(exponent, shift(1, 255 - i)) != 0: - o = uint256_mulmod(o, base, modulus) - return o - """ - - c = get_contract_with_gas_estimation(modexper) - assert c.exponential(3, 5, 100) == 43 - assert c.exponential(2, 997, 997) == 2 - - -def test_uint256_literal(get_contract_with_gas_estimation): - modexper = """ -@external -def test() -> uint256: - o: uint256 = 340282366920938463463374607431768211459 - return o - """ - - c = get_contract_with_gas_estimation(modexper) - assert c.test() == 340282366920938463463374607431768211459 - - -def test_uint256_comparison(get_contract_with_gas_estimation): - code = """ -max_uint_256: public(uint256) - -@external -def __init__(): - self.max_uint_256 = 2*(2**255-1)+1 - -@external -def max_lt() -> (bool): - return 30 < self.max_uint_256 - -@external -def max_lte() -> (bool): - return 30 <= self.max_uint_256 - -@external -def max_gte() -> (bool): - return 30 >= self.max_uint_256 - -@external -def max_gt() -> (bool): - return 30 > self.max_uint_256 - -@external -def max_ne() -> (bool): - return 30 != self.max_uint_256 - """ - - c = get_contract_with_gas_estimation(code) - - assert c.max_lt() is True - assert c.max_lte() is True - assert c.max_gte() is False - assert c.max_gt() is False - assert c.max_ne() is True diff --git a/tests/parser/types/numbers/test_uint8.py b/tests/parser/types/numbers/test_uint8.py deleted file mode 100644 index e37c65a063..0000000000 --- a/tests/parser/types/numbers/test_uint8.py +++ /dev/null @@ -1,206 +0,0 @@ -import itertools as it - -import pytest - -from vyper.codegen.types import parse_integer_typeinfo - - -def test_exponent_base_zero(get_contract): - code = """ -@external -def foo(x: uint8) -> uint8: - return 0 ** x - """ - c = get_contract(code) - assert c.foo(0) == 1 - assert c.foo(1) == 0 - assert c.foo(42) == 0 - assert c.foo(2 ** 8 - 1) == 0 - - -def test_exponent_base_one(get_contract): - code = """ -@external -def foo(x: uint8) -> uint8: - return 1 ** x - """ - c = get_contract(code) - assert c.foo(0) == 1 - assert c.foo(1) == 1 - assert c.foo(42) == 1 - assert c.foo(2 ** 8 - 1) == 1 - - -@pytest.mark.parametrize("base,power", it.product(range(6), repeat=2)) -def test_safe_exponentiation(get_contract, assert_tx_failed, base, power): - code = f""" -@external -def _uint8_exponentiation_base(_power: uint8) -> uint8: - return {base} ** _power - -@external -def _uint8_exponentiation_power(_base: uint8) -> uint8: - return _base ** {power} - """ - - c = get_contract(code) - - if 0 <= base ** power < 2 ** 8 - 1: - # within bounds so ok - assert c._uint8_exponentiation_base(power) == base ** power - assert c._uint8_exponentiation_power(base) == base ** power - else: - # clamps on exponentiation - assert_tx_failed(lambda: c._uint8_exponentiation_base(power)) - assert_tx_failed(lambda: c._uint8_exponentiation_power(base)) - - -def test_uint8_code(assert_tx_failed, get_contract_with_gas_estimation): - uint8_code = """ -@external -def _uint8_add(x: uint8, y: uint8) -> uint8: - return x + y - -@external -def _uint8_sub(x: uint8, y: uint8) -> uint8: - return x - y - -@external -def _uint8_mul(x: uint8, y: uint8) -> uint8: - return x * y - -@external -def _uint8_div(x: uint8, y: uint8) -> uint8: - return x / y - -@external -def _uint8_gt(x: uint8, y: uint8) -> bool: - return x > y - -@external -def _uint8_ge(x: uint8, y: uint8) -> bool: - return x >= y - -@external -def _uint8_lt(x: uint8, y: uint8) -> bool: - return x < y - -@external -def _uint8_le(x: uint8, y: uint8) -> bool: - return x <= y - """ - - c = get_contract_with_gas_estimation(uint8_code) - x = 18 - y = 12 - - uint8_MAX = 2 ** 8 - 1 # Max possible uint8 value - assert c._uint8_add(x, y) == x + y - assert c._uint8_add(0, y) == y - assert c._uint8_add(y, 0) == y - assert_tx_failed(lambda: c._uint8_add(uint8_MAX, uint8_MAX)) - assert c._uint8_sub(x, y) == x - y - assert_tx_failed(lambda: c._uint8_sub(y, x)) - assert c._uint8_sub(0, 0) == 0 - assert c._uint8_sub(uint8_MAX, 0) == uint8_MAX - assert_tx_failed(lambda: c._uint8_sub(1, 2)) - assert c._uint8_sub(uint8_MAX, 1) == uint8_MAX - 1 - assert c._uint8_mul(x, y) == x * y - assert_tx_failed(lambda: c._uint8_mul(uint8_MAX, 2)) - assert c._uint8_mul(uint8_MAX, 0) == 0 - assert c._uint8_mul(0, uint8_MAX) == 0 - assert c._uint8_div(x, y) == x // y - assert_tx_failed(lambda: c._uint8_div(uint8_MAX, 0)) - assert c._uint8_div(y, x) == 0 - assert_tx_failed(lambda: c._uint8_div(x, 0)) - assert c._uint8_gt(x, y) is True - assert c._uint8_ge(x, y) is True - assert c._uint8_le(x, y) is False - assert c._uint8_lt(x, y) is False - assert c._uint8_gt(x, x) is False - assert c._uint8_ge(x, x) is True - assert c._uint8_le(x, x) is True - assert c._uint8_lt(x, x) is False - assert c._uint8_lt(y, x) is True - - print("Passed uint8 operation tests") - - -def test_uint8_literal(get_contract_with_gas_estimation): - modexper = """ -@external -def test() -> uint8: - o: uint8 = 64 - return o - """ - - c = get_contract_with_gas_estimation(modexper) - assert c.test() == 64 - - -def test_uint8_comparison(get_contract_with_gas_estimation): - code = """ -max_uint_8: public(uint8) - -@external -def __init__(): - self.max_uint_8 = 255 - -@external -def max_lt() -> (bool): - return 30 < self.max_uint_8 - -@external -def max_lte() -> (bool): - return 30 <= self.max_uint_8 - -@external -def max_gte() -> (bool): - return 30 >= self.max_uint_8 - -@external -def max_gt() -> (bool): - return 30 > self.max_uint_8 - -@external -def max_ne() -> (bool): - return 30 != self.max_uint_8 - """ - - c = get_contract_with_gas_estimation(code) - - assert c.max_lt() is True - assert c.max_lte() is True - assert c.max_gte() is False - assert c.max_gt() is False - assert c.max_ne() is True - - -# TODO: create a tests/parser/functions/test_convert_to_uint8.py file - - -@pytest.mark.parametrize("in_typ", ["int256", "uint256", "int128", "uint128"]) -def test_uint8_convert_clamps(get_contract, assert_tx_failed, in_typ): - code = f""" -@external -def conversion(_x: {in_typ}) -> uint8: - return convert(_x, uint8) - """ - - c = get_contract(code) - - int_info = parse_integer_typeinfo(in_typ) - - if int_info.is_signed: - # below bounds - for val in [int_info.bounds[0], -(2 ** 127), -3232, -256, -1]: - assert_tx_failed(lambda: c.conversion(val)) - - # above bounds - above_bounds = [256, 3000, 2 ** 126, int_info.bounds[1]] - for val in above_bounds: - assert_tx_failed(lambda: c.conversion(val)) - - # within bounds - for val in [0, 10, 25, 130, 255]: - assert c.conversion(val) == val diff --git a/tests/parser/types/numbers/test_unsigned_ints.py b/tests/parser/types/numbers/test_unsigned_ints.py new file mode 100644 index 0000000000..0f60731035 --- /dev/null +++ b/tests/parser/types/numbers/test_unsigned_ints.py @@ -0,0 +1,225 @@ +import itertools +import operator +import random + +import pytest + +from vyper.codegen.types.types import UNSIGNED_INTEGER_TYPES, parse_integer_typeinfo +from vyper.exceptions import InvalidType, OverflowException, ZeroDivisionException +from vyper.utils import SizeLimits, evm_div, evm_mod, int_bounds + +PARAMS = [] +for t in sorted(UNSIGNED_INTEGER_TYPES): + info = parse_integer_typeinfo(t) + lo, hi = int_bounds(bits=info.bits, signed=info.is_signed) + PARAMS.append((t, lo, hi, info.bits)) + + +@pytest.mark.parametrize("typ,lo,hi,bits", PARAMS) +def test_exponent_base_zero(get_contract, typ, lo, hi, bits): + code = f""" +@external +def foo(x: {typ}) -> {typ}: + return 0 ** x + """ + c = get_contract(code) + assert c.foo(0) == 1 + assert c.foo(1) == 0 + assert c.foo(42) == 0 + assert c.foo(hi) == 0 + + +@pytest.mark.parametrize("typ,lo,hi,bits", PARAMS) +def test_exponent_base_one(get_contract, typ, lo, hi, bits): + code = f""" +@external +def foo(x: {typ}) -> {typ}: + return 1 ** x + """ + c = get_contract(code) + assert c.foo(0) == 1 + assert c.foo(1) == 1 + assert c.foo(42) == 1 + assert c.foo(hi) == 1 + + +ARITHMETIC_OPS = { + "+": operator.add, + "-": operator.sub, + "*": operator.mul, + "/": evm_div, + "%": evm_mod, +} + + +@pytest.mark.parametrize("op", sorted(ARITHMETIC_OPS.keys())) +@pytest.mark.parametrize("typ,lo,hi,bits", PARAMS) +@pytest.mark.fuzzing +def test_arithmetic_thorough( + get_contract, assert_tx_failed, assert_compile_failed, op, typ, lo, hi, bits +): + # both variables + code_1 = f""" +@external +def foo(x: {typ}, y: {typ}) -> {typ}: + return x {op} y + """ + # right is literal + code_2_template = """ +@external +def foo(x: {typ}) -> {typ}: + return x {op} {y} + """ + # left is literal + code_3_template = """ +@external +def foo(y: {typ}) -> {typ}: + return {x} {op} y + """ + # both literals + code_4_template = """ +@external +def foo() -> {typ}: + return {x} {op} {y} + """ + + c = get_contract(code_1) + + fn = ARITHMETIC_OPS[op] + + special_cases = [0, 1, 2, 3, hi // 2 - 1, hi // 2, hi // 2 + 1, hi - 2, hi - 1, hi] + xs = special_cases.copy() + ys = special_cases.copy() + NUM_CASES = 10 + # poor man's fuzzing - hypothesis doesn't make it easy + # with the parametrized strategy + xs += [random.randrange(lo, hi) for _ in range(NUM_CASES)] + ys += [random.randrange(lo, hi) for _ in range(NUM_CASES)] + + # mirror signed integer tests + assert 2 ** (bits - 1) in xs and (2 ** bits) - 1 in ys + + for (x, y) in itertools.product(xs, ys): + expected = fn(x, y) + in_bounds = SizeLimits.in_bounds(typ, expected) + # safediv and safemod disallow divisor == 0 + div_by_zero = y == 0 and op in ("/", "%") + + ok = in_bounds and not div_by_zero + + code_2 = code_2_template.format(typ=typ, op=op, y=y) + code_3 = code_3_template.format(typ=typ, op=op, x=x) + code_4 = code_4_template.format(typ=typ, op=op, x=x, y=y) + + if ok: + assert c.foo(x, y) == expected + assert get_contract(code_2).foo(x) == expected + assert get_contract(code_3).foo(y) == expected + assert get_contract(code_4).foo() == expected + elif div_by_zero: + assert_tx_failed(lambda: c.foo(x, y)) + assert_compile_failed(lambda: get_contract(code_2), ZeroDivisionException) + assert_tx_failed(lambda: get_contract(code_3).foo(y)) + assert_compile_failed(lambda: get_contract(code_4), ZeroDivisionException) + else: + assert_tx_failed(lambda: c.foo(x, y)) + assert_tx_failed(lambda: get_contract(code_2).foo(x)) + assert_tx_failed(lambda: get_contract(code_3).foo(y)) + assert_compile_failed(lambda: get_contract(code_4), (InvalidType, OverflowException)) + + +COMPARISON_OPS = { + "==": operator.eq, + "!=": operator.ne, + ">": operator.gt, + ">=": operator.ge, + "<": operator.lt, + "<=": operator.le, +} + + +@pytest.mark.parametrize("op", sorted(COMPARISON_OPS.keys())) +@pytest.mark.parametrize("typ,lo,hi,bits", PARAMS) +@pytest.mark.fuzzing +def test_comparators(get_contract, op, typ, lo, hi, bits): + code_1 = f""" +@external +def foo(x: {typ}, y: {typ}) -> bool: + return x {op} y + """ + + fn = COMPARISON_OPS[op] + + c = get_contract(code_1) + + # note: constant folding is tested in tests/ast/folding + + special_cases = [0, 1, 2, 3, hi // 2 - 1, hi // 2, hi // 2 + 1, hi - 2, hi - 1, hi] + xs = special_cases.copy() + ys = special_cases.copy() + + for x, y in itertools.product(xs, ys): + expected = fn(x, y) + assert c.foo(x, y) is expected + + +# TODO move to tests/parser/functions/test_mulmod.py and test_addmod.py +def test_uint256_mod(assert_tx_failed, get_contract_with_gas_estimation): + uint256_code = """ +@external +def _uint256_addmod(x: uint256, y: uint256, z: uint256) -> uint256: + return uint256_addmod(x, y, z) + +@external +def _uint256_mulmod(x: uint256, y: uint256, z: uint256) -> uint256: + return uint256_mulmod(x, y, z) + """ + + c = get_contract_with_gas_estimation(uint256_code) + + assert c._uint256_addmod(1, 2, 2) == 1 + assert c._uint256_addmod(32, 2, 32) == 2 + assert c._uint256_addmod((2 ** 256) - 1, 0, 2) == 1 + assert c._uint256_addmod(2 ** 255, 2 ** 255, 6) == 4 + assert_tx_failed(lambda: c._uint256_addmod(1, 2, 0)) + assert c._uint256_mulmod(3, 1, 2) == 1 + assert c._uint256_mulmod(200, 3, 601) == 600 + assert c._uint256_mulmod(2 ** 255, 1, 3) == 2 + assert c._uint256_mulmod(2 ** 255, 2, 6) == 4 + assert_tx_failed(lambda: c._uint256_mulmod(2, 2, 0)) + + +def test_uint256_modmul(get_contract_with_gas_estimation): + modexper = """ +@external +def exponential(base: uint256, exponent: uint256, modulus: uint256) -> uint256: + o: uint256 = 1 + for i in range(256): + o = uint256_mulmod(o, o, modulus) + if bitwise_and(exponent, shift(1, 255 - i)) != 0: + o = uint256_mulmod(o, base, modulus) + return o + """ + + c = get_contract_with_gas_estimation(modexper) + assert c.exponential(3, 5, 100) == 43 + assert c.exponential(2, 997, 997) == 2 + + +@pytest.mark.parametrize("typ,lo,hi,bits", PARAMS) +def test_uint_literal(get_contract, assert_compile_failed, typ, lo, hi, bits): + good_cases = [0, 1, 2, 3, hi // 2 - 1, hi // 2, hi // 2 + 1, hi - 1, hi] + bad_cases = [-1, -2, -3, -hi // 2, -hi + 1, -hi] + code_template = """ +@external +def test() -> {typ}: + o: {typ} = {val} + return o + """ + + for val in good_cases: + c = get_contract(code_template.format(typ=typ, val=val)) + assert c.test() == val + + for val in bad_cases: + assert_compile_failed(lambda: get_contract(code_template.format(typ=typ, val=val))) diff --git a/vyper/builtin_functions/convert.py b/vyper/builtin_functions/convert.py index a217a995ba..ad0aafea21 100644 --- a/vyper/builtin_functions/convert.py +++ b/vyper/builtin_functions/convert.py @@ -142,13 +142,13 @@ def _fixed_to_int(arg, out_typ): # block inputs which are out of bounds before truncation. # e.g., convert(255.1, uint8) should revert or fail to compile. out_lo, out_hi = out_info.bounds - out_lo = int(out_lo * DIVISOR) - out_hi = int(out_hi * DIVISOR) + out_lo = out_lo * DIVISOR + out_hi = out_hi * DIVISOR clamped_arg = _clamp_numeric_convert(arg, arg_info.bounds, (out_lo, out_hi), arg_info.is_signed) assert arg_info.is_signed, "should use unsigned div" # stub in case we ever add ufixed - return IRnode.from_list(["sdiv", clamped_arg, int(DIVISOR)], typ=out_typ) + return IRnode.from_list(["sdiv", clamped_arg, DIVISOR], typ=out_typ) # promote from int to fixed point decimal @@ -160,12 +160,12 @@ def _int_to_fixed(arg, out_typ): # block inputs which are out of bounds before promotion out_lo, out_hi = out_info.bounds - out_lo = round_towards_zero(out_lo / DIVISOR) - out_hi = round_towards_zero(out_hi / DIVISOR) + out_lo = round_towards_zero(out_lo / decimal.Decimal(DIVISOR)) + out_hi = round_towards_zero(out_hi / decimal.Decimal(DIVISOR)) clamped_arg = _clamp_numeric_convert(arg, arg_info.bounds, (out_lo, out_hi), arg_info.is_signed) - return IRnode.from_list(["mul", clamped_arg, int(DIVISOR)], typ=out_typ) + return IRnode.from_list(["mul", clamped_arg, DIVISOR], typ=out_typ) # clamp for dealing with conversions between int types (from arg to dst) diff --git a/vyper/codegen/arithmetic.py b/vyper/codegen/arithmetic.py new file mode 100644 index 0000000000..55b5087e2c --- /dev/null +++ b/vyper/codegen/arithmetic.py @@ -0,0 +1,352 @@ +import decimal +import math + +from vyper.codegen.core import clamp, clamp_basetype +from vyper.codegen.ir_node import IRnode +from vyper.codegen.types import BaseType, is_decimal_type, is_integer_type +from vyper.evm.opcodes import version_check +from vyper.exceptions import CompilerPanic, TypeCheckFailure, UnimplementedException + + +def calculate_largest_power(a: int, num_bits: int, is_signed: bool) -> int: + """ + For a given base `a`, compute the maximum power `b` that will not + produce an overflow in the equation `a ** b` + + Arguments + --------- + a : int + Base value for the equation `a ** b` + num_bits : int + The maximum number of bits that the resulting value must fit in + is_signed : bool + Is the operation being performed on signed integers? + + Returns + ------- + int + Largest possible value for `b` where the result does not overflow + `num_bits` + """ + if num_bits % 8: + raise CompilerPanic("Type is not a modulo of 8") + + value_bits = num_bits - (1 if is_signed else 0) + if a >= 2 ** value_bits: + raise TypeCheckFailure("Value is too large and will always throw") + elif a < -(2 ** value_bits): + raise TypeCheckFailure("Value is too small and will always throw") + + a_is_negative = a < 0 + a = abs(a) # No longer need to know if it's signed or not + + if a in (0, 1): + raise CompilerPanic("Exponential operation is useless!") + + # NOTE: There is an edge case if `a` were left signed where the following + # operation would not work (`ln(a)` is undefined if `a <= 0`) + b = int(decimal.Decimal(value_bits) / (decimal.Decimal(a).ln() / decimal.Decimal(2).ln())) + if b <= 1: + return 1 # Value is assumed to be in range, therefore power of 1 is max + + # Do a bit of iteration to ensure we have the exact number + + # CMC 2022-05-06 (TODO we should be able to this with algebra + # instead of looping): + # a ** x == 2**value_bits + # x ln(a) = ln(2**value_bits) + # x = ln(2**value_bits) / ln(a) + + num_iterations = 0 + while a ** (b + 1) < 2 ** value_bits: + b += 1 + num_iterations += 1 + assert num_iterations < 10000 + while a ** b >= 2 ** value_bits: + b -= 1 + num_iterations += 1 + assert num_iterations < 10000 + + # Edge case: If a is negative and the values of a and b are such that: + # (a) ** (b + 1) == -(2 ** value_bits) + # we can actually squeak one more out of it because it's on the edge + if a_is_negative and (-a) ** (b + 1) == -(2 ** value_bits): # NOTE: a = abs(a) + return b + 1 + else: + return b # Exact + + +def calculate_largest_base(b: int, num_bits: int, is_signed: bool) -> int: + """ + For a given power `b`, compute the maximum base `a` that will not produce an + overflow in the equation `a ** b` + + Arguments + --------- + b : int + Power value for the equation `a ** b` + num_bits : int + The maximum number of bits that the resulting value must fit in + is_signed : bool + Is the operation being performed on signed integers? + + Returns + ------- + int + Largest possible value for `a` where the result does not overflow + `num_bits` + """ + if num_bits % 8: + raise CompilerPanic("Type is not a modulo of 8") + if b < 0: + raise TypeCheckFailure("Cannot calculate negative exponents") + + value_bits = num_bits - (1 if is_signed else 0) + if b > value_bits: + raise TypeCheckFailure("Value is too large and will always throw") + elif b < 2: + return 2 ** value_bits - 1 # Maximum value for type + + # CMC 2022-05-06 TODO we should be able to do this with algebra + # instead of looping): + # x ** b == 2**value_bits + # b ln(x) == ln(2**value_bits) + # ln(x) == ln(2**value_bits) / b + # x == exp( ln(2**value_bits) / b) + + # Estimate (up to ~39 digits precision required) + a = math.ceil(2 ** (decimal.Decimal(value_bits) / decimal.Decimal(b))) + # Do a bit of iteration to ensure we have the exact number + num_iterations = 0 + while (a + 1) ** b < 2 ** value_bits: + a += 1 + num_iterations += 1 + assert num_iterations < 10000 + while a ** b >= 2 ** value_bits: + a -= 1 + num_iterations += 1 + assert num_iterations < 10000 + return a + + +# def safe_add(x: IRnode, y: IRnode) -> IRnode: +def safe_add(x, y): + assert x.typ is not None and x.typ == y.typ and isinstance(x.typ, BaseType) + num_info = x.typ._num_info + + res = IRnode.from_list(["add", x, y], typ=x.typ.typ) + + if num_info.bits < 256: + return clamp_basetype(res) + + # bits == 256 + with res.cache_when_complex("ans") as (b1, res): + if num_info.is_signed: + # if r < 0: + # ans < l + # else: + # ans >= l # aka (iszero (ans < l)) + # aka: (r < 0) == (ans < l) + ok = ["eq", ["slt", y, 0], ["slt", res, x]] + else: + # note this is "equivalent" to the unsigned form + # of the above (because y < 0 == False) + # ["eq", ["lt", y, 0], ["lt", res, x]] + # TODO push down into optimizer rules. + ok = ["ge", res, x] + + ret = IRnode.from_list(["seq", ["assert", ok], res]) + return b1.resolve(ret) + + +# def safe_sub(x: IRnode, y: IRnode) -> IRnode: +def safe_sub(x, y): + num_info = x.typ._num_info + + res = IRnode.from_list(["sub", x, y], typ=x.typ.typ) + + if num_info.bits < 256: + return clamp_basetype(res) + + # bits == 256 + with res.cache_when_complex("ans") as (b1, res): + if num_info.is_signed: + # if r < 0: + # ans > l + # else: + # ans <= l # aka (iszero (ans > l)) + # aka: (r < 0) == (ans > l) + ok = ["eq", ["slt", y, 0], ["sgt", res, x]] + else: + # note this is "equivalent" to the unsigned form + # of the above (because y < 0 == False) + # ["eq", ["lt", y, 0], ["gt", res, x]] + # TODO push down into optimizer rules. + ok = ["le", res, x] + + ret = IRnode.from_list(["seq", ["assert", ok], res]) + return b1.resolve(ret) + + +# def safe_mul(x: IRnode, y: IRnode) -> IRnode: +def safe_mul(x, y): + # precondition: x.typ.typ == y.typ.typ + num_info = x.typ._num_info + + # optimizer rules work better for the safemul checks below + # if second operand is literal + if x.is_literal: + tmp = x + x = y + y = tmp + + res = IRnode.from_list(["mul", x, y], typ=x.typ.typ) + + DIV = "sdiv" if num_info.is_signed else "div" + + with res.cache_when_complex("ans") as (b1, res): + + ok = [1] # True + + if num_info.bits > 128: # check overflow mod 256 + # assert (res/y == x | y == 0) + ok = ["or", ["eq", [DIV, res, y], x], ["iszero", y]] + + # int256 + if num_info.is_signed and num_info.bits == 256: + # special case: + # in the above sdiv check, if (r==-1 and l==-2**255), + # -2**255 / -1 will return -2**255. + # need to check: not (r == -1 and l == -2**255) + if version_check(begin="constantinople"): + upper_bound = ["shl", 255, 1] + else: + upper_bound = -(2 ** 255) + + check_x = ["ne", x, upper_bound] + check_y = ["ne", ["not", y], 0] + + if not x.is_literal and not y.is_literal: + # TODO can simplify this condition? + ok = ["and", ok, ["or", check_x, check_y]] + + # TODO push some of this constant folding into optimizer + elif x.is_literal and x.value == -(2 ** 255): + ok = ["and", ok, check_y] + elif y.is_literal and y.value == -1: + ok = ["and", ok, check_x] + else: + # x or y is a literal, and we have determined it is + # not an evil value + pass + + if is_decimal_type(res.typ): + res = IRnode.from_list([DIV, res, num_info.divisor], typ=res.typ) + + # check overflow mod + # NOTE: if 128 < bits < 256, `x * y` could be between + # MAX_ and 2**256 OR it could overflow past 2**256. + # so, we check for overflow in mod 256 AS WELL AS mod + # (if bits == 256, clamp_basetype is a no-op) + res = clamp_basetype(res) + + res = IRnode.from_list(["seq", ["assert", ok], res], typ=res.typ) + + return b1.resolve(res) + + +# def safe_div(x: IRnode, y: IRnode) -> IRnode: +def safe_div(x, y): + num_info = x.typ._num_info + typ = x.typ + + ok = [1] # true + + if is_decimal_type(x.typ): + lo, hi = num_info.bounds + if max(abs(lo), abs(hi)) * num_info.divisor > 2 ** 256 - 1: + # stub to prevent us from adding fixed point numbers we don't know + # how to deal with + raise UnimplementedException("safe_mul for decimal{num_info.bits}x{num_info.decimals}") + x = ["mul", x, num_info.divisor] + + DIV = "sdiv" if num_info.is_signed else "div" + res = IRnode.from_list([DIV, x, clamp("gt", y, 0)], typ=typ) + with res.cache_when_complex("res") as (b1, res): + + # TODO: refactor this condition / push some things into the optimizer + if num_info.is_signed and num_info.bits == 256: + if version_check(begin="constantinople"): + upper_bound = ["shl", 255, 1] + else: + upper_bound = -(2 ** 255) + + if not x.is_literal and not y.typ.is_literal: + ok = ["or", ["ne", y, ["not", 0]], ["ne", x, upper_bound]] + # TODO push these rules into the optimizer + elif x.is_literal and x.value == -(2 ** 255): + ok = ["ne", y, ["not", 0]] + elif y.is_literal and y.value == -1: + ok = ["ne", x, upper_bound] + else: + # x or y is a literal, and not an evil value. + pass + + elif num_info.is_signed and is_integer_type(typ): + lo, hi = num_info.bounds + # we need to throw on min_value(typ) / -1, + # but we can skip if one of the operands is a literal and not + # the evil value + can_skip_clamp = (x.is_literal and x.value != lo) or (y.is_literal and y.value != -1) + if not can_skip_clamp: + # clamp_basetype has fewer ops than the int256 rule. + res = clamp_basetype(res) + + elif is_decimal_type(typ): + # always clamp decimals, since decimal division can actually + # result in something larger than either operand (e.g. 1.0 / 0.1) + # TODO maybe use safe_mul + res = clamp_basetype(res) + + check = ["assert", ok] + return IRnode.from_list(b1.resolve(["seq", check, res])) + + +# def safe_mod(x: IRnode, y: IRnode) -> IRnode: +def safe_mod(x, y): + num_info = x.typ._num_info + MOD = "smod" if num_info.is_signed else "mod" + return IRnode.from_list([MOD, x, clamp("gt", y, 0)]) + + +# def safe_pow(x: IRnode, y: IRnode) -> IRnode: +def safe_pow(x, y): + num_info = x.typ._num_info + if not is_integer_type(x.typ): + # type checker should have caught this + raise TypeCheckFailure("non-integer pow") + + if x.is_literal: + # cannot pass 1 or 0 to `calculate_largest_power` + if x.value == 1: + return IRnode.from_list([1]) + if x.value == 0: + return IRnode.from_list(["iszero", y]) + + upper_bound = calculate_largest_power(x.value, num_info.bits, num_info.is_signed) + 1 + # for signed integers, this also prevents negative values + ok = ["lt", y, upper_bound] + + elif y.is_literal: + upper_bound = calculate_largest_base(y.value, num_info.bits, num_info.is_signed) + 1 + if num_info.is_signed: + ok = ["and", ["slt", x, upper_bound], ["sgt", x, -upper_bound]] + else: + ok = ["lt", x, upper_bound] + else: + # `a ** b` where neither `a` or `b` are known + # TODO this is currently unreachable, once we implement a way to do it safely + # remove the check in `vyper/context/types/value/numeric.py` + return + + return IRnode.from_list(["seq", ["assert", ok], ["exp", x, y]]) diff --git a/vyper/codegen/core.py b/vyper/codegen/core.py index 06b10680f6..c22b4d4ee7 100644 --- a/vyper/codegen/core.py +++ b/vyper/codegen/core.py @@ -935,22 +935,24 @@ def clamp_basetype(ir_node): if is_integer_type(t) or is_decimal_type(t): if t._num_info.bits == 256: - return ir_node + ret = ir_node else: - return int_clamp(ir_node, t._num_info.bits, signed=t._num_info.is_signed) + ret = int_clamp(ir_node, t._num_info.bits, signed=t._num_info.is_signed) - if is_bytes_m_type(t): + elif is_bytes_m_type(t): if t._bytes_info.m == 32: - return ir_node # special case, no clamp. + ret = ir_node # special case, no clamp. else: - return bytes_clamp(ir_node, t._bytes_info.m) + ret = bytes_clamp(ir_node, t._bytes_info.m) - if t.typ in ("address",): - return int_clamp(ir_node, 160) - if t.typ in ("bool",): - return int_clamp(ir_node, 1) + elif t.typ in ("address",): + ret = int_clamp(ir_node, 160) + elif t.typ in ("bool",): + ret = int_clamp(ir_node, 1) + else: # pragma: nocover + raise CompilerPanic(f"{t} passed to clamp_basetype") - raise CompilerPanic(f"{t} passed to clamp_basetype") # pragma: notest + return IRnode.from_list(ret, typ=ir_node.typ) def int_clamp(ir_node, bits, signed=False): diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index 113ccd4c67..8e5ce03a64 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -1,12 +1,12 @@ import decimal import math +import vyper.codegen.arithmetic as arithmetic from vyper import ast as vy_ast from vyper.address_space import DATA, IMMUTABLES, MEMORY, STORAGE from vyper.codegen import external_call, self_call from vyper.codegen.core import ( clamp, - clamp_basetype, ensure_in_memory, get_dyn_array_count, get_element_ptr, @@ -42,7 +42,6 @@ StructureException, TypeCheckFailure, TypeMismatch, - UnimplementedException, ) from vyper.utils import ( DECIMAL_DIVISOR, @@ -55,113 +54,6 @@ ENVIRONMENT_VARIABLES = {"block", "msg", "tx", "chain"} -def calculate_largest_power(a: int, num_bits: int, is_signed: bool) -> int: - """ - For a given base `a`, compute the maximum power `b` that will not - produce an overflow in the equation `a ** b` - - Arguments - --------- - a : int - Base value for the equation `a ** b` - num_bits : int - The maximum number of bits that the resulting value must fit in - is_signed : bool - Is the operation being performed on signed integers? - - Returns - ------- - int - Largest possible value for `b` where the result does not overflow - `num_bits` - """ - if num_bits % 8: - raise CompilerPanic("Type is not a modulo of 8") - - value_bits = num_bits - (1 if is_signed else 0) - if a >= 2 ** value_bits: - raise TypeCheckFailure("Value is too large and will always throw") - elif a < -(2 ** value_bits): - raise TypeCheckFailure("Value is too small and will always throw") - - a_is_negative = a < 0 - a = abs(a) # No longer need to know if it's signed or not - if a in (0, 1): - raise CompilerPanic("Exponential operation is useless!") - - # NOTE: There is an edge case if `a` were left signed where the following - # operation would not work (`ln(a)` is undefined if `a <= 0`) - b = int(decimal.Decimal(value_bits) / (decimal.Decimal(a).ln() / decimal.Decimal(2).ln())) - if b <= 1: - return 1 # Value is assumed to be in range, therefore power of 1 is max - - # Do a bit of iteration to ensure we have the exact number - num_iterations = 0 - while a ** (b + 1) < 2 ** value_bits: - b += 1 - num_iterations += 1 - assert num_iterations < 10000 - while a ** b >= 2 ** value_bits: - b -= 1 - num_iterations += 1 - assert num_iterations < 10000 - - # Edge case: If a is negative and the values of a and b are such that: - # (a) ** (b + 1) == -(2 ** value_bits) - # we can actually squeak one more out of it because it's on the edge - if a_is_negative and (-a) ** (b + 1) == -(2 ** value_bits): # NOTE: a = abs(a) - return b + 1 - else: - return b # Exact - - -def calculate_largest_base(b: int, num_bits: int, is_signed: bool) -> int: - """ - For a given power `b`, compute the maximum base `a` that will not produce an - overflow in the equation `a ** b` - - Arguments - --------- - b : int - Power value for the equation `a ** b` - num_bits : int - The maximum number of bits that the resulting value must fit in - is_signed : bool - Is the operation being performed on signed integers? - - Returns - ------- - int - Largest possible value for `a` where the result does not overflow - `num_bits` - """ - if num_bits % 8: - raise CompilerPanic("Type is not a modulo of 8") - if b < 0: - raise TypeCheckFailure("Cannot calculate negative exponents") - - value_bits = num_bits - (1 if is_signed else 0) - if b > value_bits: - raise TypeCheckFailure("Value is too large and will always throw") - elif b < 2: - return 2 ** value_bits - 1 # Maximum value for type - - # Estimate (up to ~39 digits precision required) - a = math.ceil(2 ** (decimal.Decimal(value_bits) / decimal.Decimal(b))) - # Do a bit of iteration to ensure we have the exact number - num_iterations = 0 - while (a + 1) ** b < 2 ** value_bits: - a += 1 - num_iterations += 1 - assert num_iterations < 10000 - while a ** b >= 2 ** value_bits: - a -= 1 - num_iterations += 1 - assert num_iterations < 10000 - - return a - - class Expr: # TODO: Once other refactors are made reevaluate all inline imports @@ -458,251 +350,31 @@ def parse_BinOp(self): if not is_numeric_type(left.typ) or not is_numeric_type(right.typ): return - types = {left.typ.typ, right.typ.typ} - literals = {left.typ.is_literal, right.typ.is_literal} - - # If one value of the operation is a literal, we recast it to match the non-literal type. - # We know this is OK because types were already verified in the actual typechecking pass. - # This is a temporary solution to not break codegen while we work toward removing types - # altogether at this stage of complition. @iamdefinitelyahuman - if literals == {True, False} and len(types) > 1 and "decimal" not in types: - if left.typ.is_literal and SizeLimits.in_bounds(right.typ.typ, left.value): - left = IRnode.from_list(left.value, typ=BaseType(right.typ.typ, is_literal=True)) - elif right.typ.is_literal and SizeLimits.in_bounds(left.typ.typ, right.value): - right = IRnode.from_list(right.value, typ=BaseType(left.typ.typ, is_literal=True)) - ltyp, rtyp = left.typ.typ, right.typ.typ # Sanity check - ensure that we aren't dealing with different types # This should be unreachable due to the type check pass assert ltyp == rtyp, f"unreachable, {ltyp}!={rtyp}, {self.expr}" - arith = None - if isinstance(self.expr.op, (vy_ast.Add, vy_ast.Sub)): - new_typ = BaseType(ltyp) - - if ltyp == "uint256": - if isinstance(self.expr.op, vy_ast.Add): - # safeadd - arith = ["seq", ["assert", ["ge", ["add", "l", "r"], "l"]], ["add", "l", "r"]] - - elif isinstance(self.expr.op, vy_ast.Sub): - # safesub - arith = ["seq", ["assert", ["ge", "l", "r"]], ["sub", "l", "r"]] - - elif ltyp == "int256": - if isinstance(self.expr.op, vy_ast.Add): - op, comp1, comp2 = "add", "sge", "slt" - else: - op, comp1, comp2 = "sub", "sle", "sgt" - - if right.typ.is_literal: - if right.value >= 0: - arith = ["seq", ["assert", [comp1, [op, "l", "r"], "l"]], [op, "l", "r"]] - else: - arith = ["seq", ["assert", [comp2, [op, "l", "r"], "l"]], [op, "l", "r"]] - else: - arith = [ - "with", - "ans", - [op, "l", "r"], - [ - "seq", - [ - "assert", - [ - "or", - ["and", ["sge", "r", 0], [comp1, "ans", "l"]], - ["and", ["slt", "r", 0], [comp2, "ans", "l"]], - ], - ], - "ans", - ], - ] - - elif ltyp in ("decimal", "int128", "uint8"): - op = "add" if isinstance(self.expr.op, vy_ast.Add) else "sub" - arith = [op, "l", "r"] - - elif isinstance(self.expr.op, vy_ast.Mult): - new_typ = BaseType(ltyp) - if ltyp == "uint256": - arith = [ - "with", - "ans", - ["mul", "l", "r"], - [ - "seq", - ["assert", ["or", ["eq", ["div", "ans", "l"], "r"], ["iszero", "l"]]], - "ans", - ], - ] - - elif ltyp == "int256": - if version_check(begin="constantinople"): - upper_bound = ["shl", 255, 1] - else: - upper_bound = -(2 ** 255) - if not left.typ.is_literal and not right.typ.is_literal: - bounds_check = [ - "assert", - ["or", ["ne", "l", ["not", 0]], ["ne", "r", upper_bound]], - ] - elif left.typ.is_literal and left.value == -1: - bounds_check = ["assert", ["ne", "r", upper_bound]] - elif right.typ.is_literal and right.value == -(2 ** 255): - bounds_check = ["assert", ["ne", "l", ["not", 0]]] - else: - bounds_check = "pass" - arith = [ - "with", - "ans", - ["mul", "l", "r"], - [ - "seq", - bounds_check, - ["assert", ["or", ["eq", ["sdiv", "ans", "l"], "r"], ["iszero", "l"]]], - "ans", - ], - ] - - elif ltyp in ("int128", "uint8"): - arith = ["mul", "l", "r"] - - elif ltyp == "decimal": - arith = [ - "with", - "ans", - ["mul", "l", "r"], - [ - "seq", - ["assert", ["or", ["eq", ["sdiv", "ans", "l"], "r"], ["iszero", "l"]]], - ["sdiv", "ans", DECIMAL_DIVISOR], - ], - ] - - elif isinstance(self.expr.op, vy_ast.Div): - if right.typ.is_literal and right.value == 0: - return - - new_typ = BaseType(ltyp) - - if right.typ.is_literal: - divisor = "r" + out_typ = BaseType(ltyp) + + with left.cache_when_complex("x") as (b1, x), right.cache_when_complex("y") as (b2, y): + if isinstance(self.expr.op, vy_ast.Add): + ret = arithmetic.safe_add(x, y) + elif isinstance(self.expr.op, vy_ast.Sub): + ret = arithmetic.safe_sub(x, y) + elif isinstance(self.expr.op, vy_ast.Mult): + ret = arithmetic.safe_mul(x, y) + elif isinstance(self.expr.op, vy_ast.Div): + ret = arithmetic.safe_div(x, y) + elif isinstance(self.expr.op, vy_ast.Mod): + ret = arithmetic.safe_mod(x, y) + elif isinstance(self.expr.op, vy_ast.Pow): + ret = arithmetic.safe_pow(x, y) else: - # only apply the non-zero clamp when r is not a constant - divisor = clamp("gt", "r", 0) - - if ltyp in ("uint8", "uint256"): - arith = ["div", "l", divisor] - - elif ltyp == "int256": - if version_check(begin="constantinople"): - upper_bound = ["shl", 255, 1] - else: - upper_bound = -(2 ** 255) - if not left.typ.is_literal and not right.typ.is_literal: - bounds_check = [ - "assert", - ["or", ["ne", "r", ["not", 0]], ["ne", "l", upper_bound]], - ] - elif left.typ.is_literal and left.value == -(2 ** 255): - bounds_check = ["assert", ["ne", "r", ["not", 0]]] - elif right.typ.is_literal and right.value == -1: - bounds_check = ["assert", ["ne", "l", upper_bound]] - else: - bounds_check = "pass" - arith = ["seq", bounds_check, ["sdiv", "l", divisor]] - - elif ltyp == "int128": - arith = ["sdiv", "l", divisor] - - elif ltyp == "decimal": - arith = ["sdiv", ["mul", "l", DECIMAL_DIVISOR], divisor] - - elif isinstance(self.expr.op, vy_ast.Mod): - if right.typ.is_literal and right.value == 0: - return - - new_typ = BaseType(ltyp) - - if right.typ.is_literal: - divisor = "r" - else: - # only apply the non-zero clamp when r is not a constant - divisor = clamp("gt", "r", 0) - - if ltyp in ("uint8", "uint256"): - arith = ["mod", "l", divisor] - else: - arith = ["smod", "l", divisor] - - elif isinstance(self.expr.op, vy_ast.Pow): - new_typ = BaseType(ltyp) - - # TODO optimizer rule for special cases - if self.expr.left.get("value") == 1: - return IRnode.from_list([1], typ=new_typ) - if self.expr.left.get("value") == 0: - return IRnode.from_list(["iszero", right], typ=new_typ) - - if ltyp == "int128": - is_signed = True - num_bits = 128 - elif ltyp == "int256": - is_signed = True - num_bits = 256 - elif ltyp == "uint8": - is_signed = False - num_bits = 8 - else: - is_signed = False - num_bits = 256 - - if isinstance(self.expr.left, vy_ast.Int): - value = self.expr.left.value - upper_bound = calculate_largest_power(value, num_bits, is_signed) + 1 - # for signed integers, this also prevents negative values - clamp_cond = ["lt", right, upper_bound] - return IRnode.from_list( - ["seq", ["assert", clamp_cond], ["exp", left, right]], typ=new_typ - ) - elif isinstance(self.expr.right, vy_ast.Int): - value = self.expr.right.value - upper_bound = calculate_largest_base(value, num_bits, is_signed) + 1 - if is_signed: - clamp_cond = ["and", ["slt", left, upper_bound], ["sgt", left, -upper_bound]] - else: - clamp_cond = ["lt", left, upper_bound] - return IRnode.from_list( - ["seq", ["assert", clamp_cond], ["exp", left, right]], typ=new_typ - ) - else: - # `a ** b` where neither `a` or `b` are known - # TODO this is currently unreachable, once we implement a way to do it safely - # remove the check in `vyper/context/types/value/numeric.py` - return + return # raises - if arith is None: - op_str = self.expr.op._pretty - raise UnimplementedException(f"Not implemented: {ltyp} {op_str} {rtyp}", self.expr.op) - - arith = IRnode.from_list(arith, typ=new_typ) - - p = [ - "with", - "l", - left, - [ - "with", - "r", - right, - # note clamp_basetype is a noop on [u]int256 - # note: clamp_basetype throws on unclampable input - clamp_basetype(arith), - ], - ] - return IRnode.from_list(p, typ=new_typ) + return IRnode.from_list(b1.resolve(b2.resolve(ret)), typ=out_typ) def build_in_comparator(self): left = Expr(self.expr.left, self.context).ir_node @@ -882,8 +554,6 @@ def parse_UnaryOp(self): assert operand.typ._num_info.is_signed # Clamp on minimum signed integer value as we cannot negate that # value (all other integer values are fine) - # CMC 2022-04-06 maybe this could be branchless with: - # max(val, 0 - val) min_int_val, _ = operand.typ._num_info.bounds return IRnode.from_list(["sub", 0, clamp("sgt", operand, min_int_val)], typ=operand.typ) diff --git a/vyper/codegen/types/types.py b/vyper/codegen/types/types.py index 4994ee14ec..89cdbe3d07 100644 --- a/vyper/codegen/types/types.py +++ b/vyper/codegen/types/types.py @@ -104,18 +104,17 @@ class DecimalTypeInfo(NumericTypeInfo): decimals: int @property - def divisor(self) -> Decimal: - # TODO reconsider if this API should return int - return Decimal(10 ** self.decimals) + def divisor(self) -> int: + return 10 ** self.decimals @property def epsilon(self) -> Decimal: - return 1 / self.divisor + return 1 / Decimal(self.divisor) @property def decimal_bounds(self) -> Tuple[Decimal, Decimal]: lo, hi = self.bounds - DIVISOR = self.divisor + DIVISOR = Decimal(self.divisor) return lo / DIVISOR, hi / DIVISOR diff --git a/vyper/ir/optimizer.py b/vyper/ir/optimizer.py index d670be25ea..9ffa60eecd 100644 --- a/vyper/ir/optimizer.py +++ b/vyper/ir/optimizer.py @@ -53,6 +53,7 @@ def _deep_contains(node_or_list, node): "sdiv": (evm_div, "/", SIGNED), "mod": (evm_mod, "%", UNSIGNED), "smod": (evm_mod, "%", SIGNED), + "exp": (operator.pow, "**", UNSIGNED), "eq": (operator.eq, "==", UNSIGNED), "ne": (operator.ne, "!=", UNSIGNED), "lt": (operator.lt, "<", UNSIGNED), @@ -154,9 +155,13 @@ def _conservative_eq(x, y): new_val = args[0].value new_args = args[0].args - elif binop in {"sub", "xor"} and _conservative_eq(args[0], args[1]): - # x - x == x ^ x == 0 - new_val = 0 + elif binop in {"sub", "xor", "eq", "ne"} and _conservative_eq(args[0], args[1]): + if binop == "eq": + # (x == x) == 1 + new_val = 1 + else: + # x - x == x ^ x == x != x == 0 + new_val = 0 new_args = [] # TODO associativity rules @@ -178,6 +183,21 @@ def _conservative_eq(x, y): new_val = "sub" new_args = [0, args[0]] + elif binop == "exp": + # n ** 0 == 1 (forall n) + # 1 ** n == 1 + if _int(args[1]) == 0 or _int(args[0]) == 1: + new_val = 1 + new_args = [] + # 0 ** n == (1 if n == 0 else 0) + if _int(args[0]) == 0: + new_val = "iszero" + new_args = [args[1]] + # n ** 1 == n + if _int(args[1]) == 1: + new_val = args[0].value + new_args = args[0].args + # maybe OK: # elif binop == "div" and _int(args[1], UNSIGNED) == MAX_UINT256: # # (div x (2**256 - 1)) == (eq x (2**256 - 1)) @@ -249,7 +269,8 @@ def _conservative_eq(x, y): # x < 1 => x <= 0 new_rhs = rhs + 1 if op_is_gt else rhs - 1 - if _wrap(new_rhs) != new_rhs: + in_bounds = _wrap(new_rhs) == new_rhs + if not in_bounds: # always false. ex. (gt x MAX_UINT256) # note that the wrapped version (ge x 0) is always true. new_val = 0 @@ -286,8 +307,8 @@ def _conservative_eq(x, y): new_val = "iszero" new_args = [args[0]] - # gt x 0 => x != 0 - elif binop == "gt" and _int(args[1]) == 0: + # gt x 0 == x != 0 == (iszero (iszero x)) + elif binop in ("gt", "ne") and _int(args[1]) == 0: new_val = "iszero" new_args = [["iszero", args[0]]] diff --git a/vyper/utils.py b/vyper/utils.py index 10a09e4da1..c4dd170691 100644 --- a/vyper/utils.py +++ b/vyper/utils.py @@ -174,6 +174,7 @@ def calc_mem_gas(memsize): # A decimal value can store multiples of 1/DECIMAL_DIVISOR MAX_DECIMAL_PLACES = 10 DECIMAL_DIVISOR = 10 ** MAX_DECIMAL_PLACES +DECIMAL_EPSILON = decimal.Decimal(1) / DECIMAL_DIVISOR def int_bounds(signed, bits):