From 379fac9cf6ca0657c9b80926f53af4478f823fd2 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Fri, 6 May 2022 11:39:15 +0200 Subject: [PATCH] move pow to arithmetic.py --- tests/fuzzing/test_exponents.py | 2 +- vyper/codegen/arithmetic.py | 142 ++++++++++++++++++++++++++- vyper/codegen/expr.py | 169 ++------------------------------ 3 files changed, 149 insertions(+), 164 deletions(-) diff --git a/tests/fuzzing/test_exponents.py b/tests/fuzzing/test_exponents.py index 0ac1285b56a..e937f023e4a 100644 --- a/tests/fuzzing/test_exponents.py +++ b/tests/fuzzing/test_exponents.py @@ -2,7 +2,7 @@ from hypothesis import example, given, settings from hypothesis import strategies as st -from vyper.codegen.expr import calculate_largest_base, calculate_largest_power +from vyper.codegen.arithmetic import calculate_largest_base, calculate_largest_power @pytest.mark.fuzzing diff --git a/vyper/codegen/arithmetic.py b/vyper/codegen/arithmetic.py index fa8fb3d29f1..f406cbaf9fc 100644 --- a/vyper/codegen/arithmetic.py +++ b/vyper/codegen/arithmetic.py @@ -5,6 +5,127 @@ from vyper.exceptions import CompilerPanic +def calculate_largest_power(a: int, num_bits: int, is_signed: bool) -> int: + """ + For a given base `a`, compute the maximum power `b` that will not + produce an overflow in the equation `a ** b` + + Arguments + --------- + a : int + Base value for the equation `a ** b` + num_bits : int + The maximum number of bits that the resulting value must fit in + is_signed : bool + Is the operation being performed on signed integers? + + Returns + ------- + int + Largest possible value for `b` where the result does not overflow + `num_bits` + """ + if num_bits % 8: + raise CompilerPanic("Type is not a modulo of 8") + + value_bits = num_bits - (1 if is_signed else 0) + if a >= 2 ** value_bits: + raise TypeCheckFailure("Value is too large and will always throw") + elif a < -(2 ** value_bits): + raise TypeCheckFailure("Value is too small and will always throw") + + a_is_negative = a < 0 + a = abs(a) # No longer need to know if it's signed or not + + if a in (0, 1): + raise CompilerPanic("Exponential operation is useless!") + + # NOTE: There is an edge case if `a` were left signed where the following + # operation would not work (`ln(a)` is undefined if `a <= 0`) + b = int(decimal.Decimal(value_bits) / (decimal.Decimal(a).ln() / decimal.Decimal(2).ln())) + if b <= 1: + return 1 # Value is assumed to be in range, therefore power of 1 is max + + # Do a bit of iteration to ensure we have the exact number + + # CMC 2022-05-06 (TODO we should be able to this with algebra + # instead of looping): + # a ** x == 2**value_bits + # x ln(a) = ln(2**value_bits) + # x = ln(2**value_bits) / ln(a) + + num_iterations = 0 + while a ** (b + 1) < 2 ** value_bits: + b += 1 + num_iterations += 1 + assert num_iterations < 10000 + while a ** b >= 2 ** value_bits: + b -= 1 + num_iterations += 1 + assert num_iterations < 10000 + + # Edge case: If a is negative and the values of a and b are such that: + # (a) ** (b + 1) == -(2 ** value_bits) + # we can actually squeak one more out of it because it's on the edge + if a_is_negative and (-a) ** (b + 1) == -(2 ** value_bits): # NOTE: a = abs(a) + return b + 1 + else: + return b # Exact + + +def calculate_largest_base(b: int, num_bits: int, is_signed: bool) -> int: + """ + For a given power `b`, compute the maximum base `a` that will not produce an + overflow in the equation `a ** b` + + Arguments + --------- + b : int + Power value for the equation `a ** b` + num_bits : int + The maximum number of bits that the resulting value must fit in + is_signed : bool + Is the operation being performed on signed integers? + + Returns + ------- + int + Largest possible value for `a` where the result does not overflow + `num_bits` + """ + if num_bits % 8: + raise CompilerPanic("Type is not a modulo of 8") + if b < 0: + raise TypeCheckFailure("Cannot calculate negative exponents") + + value_bits = num_bits - (1 if is_signed else 0) + if b > value_bits: + raise TypeCheckFailure("Value is too large and will always throw") + elif b < 2: + return 2 ** value_bits - 1 # Maximum value for type + + # CMC 2022-05-06 TODO we should be able to do this with algebra + # instead of looping): + # x ** b == 2**value_bits + # b ln(x) == ln(2**value_bits) + # ln(x) == ln(2**value_bits) / b + # x == exp( ln(2**value_bits) / b) + + # Estimate (up to ~39 digits precision required) + a = math.ceil(2 ** (decimal.Decimal(value_bits) / decimal.Decimal(b))) + # Do a bit of iteration to ensure we have the exact number + num_iterations = 0 + while (a + 1) ** b < 2 ** value_bits: + a += 1 + num_iterations += 1 + assert num_iterations < 10000 + while a ** b >= 2 ** value_bits: + a -= 1 + num_iterations += 1 + assert num_iterations < 10000 + return a + + def safe_add(x: IRnode, y: IRnode): # precondition: x.typ.typ == t.typ.typ @@ -152,4 +273,23 @@ def safe_mod(x: IRnode, y: IRnode): def safe_pow(x: IRnode, y: IRnode): - pass + num_info = x.typ._num_info + + if x.is_literal: + upper_bound = calculate_largest_power(x.value, num_info.bits, num_info.is_signed) + 1 + # for signed integers, this also prevents negative values + ok = ["lt", right, upper_bound] + ret = ["seq", ["assert", clamp], ["exp", left, right]] + + elif y.is_literal: + upper_bound = calculate_largest_base(y.value, num_info.bits, num_info.is_signed) + 1 + if is_signed: + ok = ["and", ["slt", left, upper_bound], ["sgt", left, -upper_bound]] + else: + ok = ["lt", left, upper_bound] + ret = ["seq", ["assert", ok], ["exp", left, right]] + else: + # `a ** b` where neither `a` or `b` are known + # TODO this is currently unreachable, once we implement a way to do it safely + # remove the check in `vyper/context/types/value/numeric.py` + return diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index 9e857d680f1..09b2a6d25e1 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -59,113 +59,6 @@ } -def calculate_largest_power(a: int, num_bits: int, is_signed: bool) -> int: - """ - For a given base `a`, compute the maximum power `b` that will not - produce an overflow in the equation `a ** b` - - Arguments - --------- - a : int - Base value for the equation `a ** b` - num_bits : int - The maximum number of bits that the resulting value must fit in - is_signed : bool - Is the operation being performed on signed integers? - - Returns - ------- - int - Largest possible value for `b` where the result does not overflow - `num_bits` - """ - if num_bits % 8: - raise CompilerPanic("Type is not a modulo of 8") - - value_bits = num_bits - (1 if is_signed else 0) - if a >= 2 ** value_bits: - raise TypeCheckFailure("Value is too large and will always throw") - elif a < -(2 ** value_bits): - raise TypeCheckFailure("Value is too small and will always throw") - - a_is_negative = a < 0 - a = abs(a) # No longer need to know if it's signed or not - if a in (0, 1): - raise CompilerPanic("Exponential operation is useless!") - - # NOTE: There is an edge case if `a` were left signed where the following - # operation would not work (`ln(a)` is undefined if `a <= 0`) - b = int(decimal.Decimal(value_bits) / (decimal.Decimal(a).ln() / decimal.Decimal(2).ln())) - if b <= 1: - return 1 # Value is assumed to be in range, therefore power of 1 is max - - # Do a bit of iteration to ensure we have the exact number - num_iterations = 0 - while a ** (b + 1) < 2 ** value_bits: - b += 1 - num_iterations += 1 - assert num_iterations < 10000 - while a ** b >= 2 ** value_bits: - b -= 1 - num_iterations += 1 - assert num_iterations < 10000 - - # Edge case: If a is negative and the values of a and b are such that: - # (a) ** (b + 1) == -(2 ** value_bits) - # we can actually squeak one more out of it because it's on the edge - if a_is_negative and (-a) ** (b + 1) == -(2 ** value_bits): # NOTE: a = abs(a) - return b + 1 - else: - return b # Exact - - -def calculate_largest_base(b: int, num_bits: int, is_signed: bool) -> int: - """ - For a given power `b`, compute the maximum base `a` that will not produce an - overflow in the equation `a ** b` - - Arguments - --------- - b : int - Power value for the equation `a ** b` - num_bits : int - The maximum number of bits that the resulting value must fit in - is_signed : bool - Is the operation being performed on signed integers? - - Returns - ------- - int - Largest possible value for `a` where the result does not overflow - `num_bits` - """ - if num_bits % 8: - raise CompilerPanic("Type is not a modulo of 8") - if b < 0: - raise TypeCheckFailure("Cannot calculate negative exponents") - - value_bits = num_bits - (1 if is_signed else 0) - if b > value_bits: - raise TypeCheckFailure("Value is too large and will always throw") - elif b < 2: - return 2 ** value_bits - 1 # Maximum value for type - - # Estimate (up to ~39 digits precision required) - a = math.ceil(2 ** (decimal.Decimal(value_bits) / decimal.Decimal(b))) - # Do a bit of iteration to ensure we have the exact number - num_iterations = 0 - while (a + 1) ** b < 2 ** value_bits: - a += 1 - num_iterations += 1 - assert num_iterations < 10000 - while a ** b >= 2 ** value_bits: - a -= 1 - num_iterations += 1 - assert num_iterations < 10000 - - return a - - class Expr: # TODO: Once other refactors are made reevaluate all inline imports @@ -477,67 +370,19 @@ def parse_BinOp(self): out_typ = BaseType(ltyp) - ret = None - with left.cache_when_complex("x") as (b1, left), right.cache_when_complex("y") as (b2, y): + with left.cache_when_complex("x") as (b1, x), right.cache_when_complex("y") as (b2, y): if isinstance(self.expr.op, vy_ast.Add): - ret = arithmetic.safe_add(left, y) + ret = arithmetic.safe_add(x, y) elif isinstance(self.expr.op, vy_ast.Sub): - ret = arithmetic.safe_sub(left, y) + ret = arithmetic.safe_sub(x, y) elif isinstance(self.expr.op, vy_ast.Mult): - ret = arithmetic.safe_mul(left, y) + ret = arithmetic.safe_mul(x, y) elif isinstance(self.expr.op, vy_ast.Div): - ret = arithmetic.safe_div(left, y) + ret = arithmetic.safe_div(x, y) elif isinstance(self.expr.op, vy_ast.Mod): - ret = arithmetic.safe_mod(left, y) + ret = arithmetic.safe_mod(x, y) elif isinstance(self.expr.op, vy_ast.Pow): - # TODO: move this to arithmetic.py - - # TODO optimizer rule for special cases - if self.expr.left.get("value") == 1: - return IRnode.from_list([1], typ=out_typ) - if self.expr.left.get("value") == 0: - return IRnode.from_list(["iszero", right], typ=out_typ) - - if ltyp == "int128": - is_signed = True - num_bits = 128 - elif ltyp == "int256": - is_signed = True - num_bits = 256 - elif ltyp == "uint8": - is_signed = False - num_bits = 8 - else: - is_signed = False - num_bits = 256 - - if isinstance(self.expr.left, vy_ast.Int): - value = self.expr.left.value - upper_bound = calculate_largest_power(value, num_bits, is_signed) + 1 - # for signed integers, this also prevents negative values - clamp = ["lt", right, upper_bound] - ret = ["seq", ["assert", clamp], ["exp", left, right]] - - elif isinstance(self.expr.right, vy_ast.Int): - value = self.expr.right.value - upper_bound = calculate_largest_base(value, num_bits, is_signed) + 1 - if is_signed: - clamp = ["and", ["slt", left, upper_bound], ["sgt", left, -upper_bound]] - else: - clamp = ["lt", left, upper_bound] - ret = ["seq", ["assert", clamp], ["exp", left, right]] - else: - # `a ** b` where neither `a` or `b` are known - # TODO this is currently unreachable, once we implement a way to do it safely - # remove the check in `vyper/context/types/value/numeric.py` - return - - if ret is None: - op_str = self.expr.op._pretty - raise UnimplementedException( - f"Not implemented: {ltyp} {op_str} {rtyp}", self.expr.op - ) - + ret = arithmetic.safe_pow(x, y) return IRnode.from_list(b1.resolve(b2.resolve(ret)), typ=out_typ) def build_in_comparator(self):