Skip to content

Commit

Permalink
move pow to arithmetic.py
Browse files Browse the repository at this point in the history
  • Loading branch information
charles-cooper committed May 6, 2022
1 parent 2d49d2f commit 379fac9
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 164 deletions.
2 changes: 1 addition & 1 deletion tests/fuzzing/test_exponents.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from hypothesis import example, given, settings
from hypothesis import strategies as st

from vyper.codegen.expr import calculate_largest_base, calculate_largest_power
from vyper.codegen.arithmetic import calculate_largest_base, calculate_largest_power


@pytest.mark.fuzzing
Expand Down
142 changes: 141 additions & 1 deletion vyper/codegen/arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,127 @@
from vyper.exceptions import CompilerPanic


def calculate_largest_power(a: int, num_bits: int, is_signed: bool) -> int:
"""
For a given base `a`, compute the maximum power `b` that will not
produce an overflow in the equation `a ** b`
Arguments
---------
a : int
Base value for the equation `a ** b`
num_bits : int
The maximum number of bits that the resulting value must fit in
is_signed : bool
Is the operation being performed on signed integers?
Returns
-------
int
Largest possible value for `b` where the result does not overflow
`num_bits`
"""
if num_bits % 8:
raise CompilerPanic("Type is not a modulo of 8")

value_bits = num_bits - (1 if is_signed else 0)
if a >= 2 ** value_bits:
raise TypeCheckFailure("Value is too large and will always throw")
elif a < -(2 ** value_bits):
raise TypeCheckFailure("Value is too small and will always throw")

a_is_negative = a < 0
a = abs(a) # No longer need to know if it's signed or not

if a in (0, 1):
raise CompilerPanic("Exponential operation is useless!")

# NOTE: There is an edge case if `a` were left signed where the following
# operation would not work (`ln(a)` is undefined if `a <= 0`)
b = int(decimal.Decimal(value_bits) / (decimal.Decimal(a).ln() / decimal.Decimal(2).ln()))
if b <= 1:
return 1 # Value is assumed to be in range, therefore power of 1 is max

# Do a bit of iteration to ensure we have the exact number

# CMC 2022-05-06 (TODO we should be able to this with algebra
# instead of looping):
# a ** x == 2**value_bits
# x ln(a) = ln(2**value_bits)
# x = ln(2**value_bits) / ln(a)

num_iterations = 0
while a ** (b + 1) < 2 ** value_bits:
b += 1
num_iterations += 1
assert num_iterations < 10000
while a ** b >= 2 ** value_bits:
b -= 1
num_iterations += 1
assert num_iterations < 10000

# Edge case: If a is negative and the values of a and b are such that:
# (a) ** (b + 1) == -(2 ** value_bits)
# we can actually squeak one more out of it because it's on the edge
if a_is_negative and (-a) ** (b + 1) == -(2 ** value_bits): # NOTE: a = abs(a)
return b + 1
else:
return b # Exact


def calculate_largest_base(b: int, num_bits: int, is_signed: bool) -> int:
"""
For a given power `b`, compute the maximum base `a` that will not produce an
overflow in the equation `a ** b`
Arguments
---------
b : int
Power value for the equation `a ** b`
num_bits : int
The maximum number of bits that the resulting value must fit in
is_signed : bool
Is the operation being performed on signed integers?
Returns
-------
int
Largest possible value for `a` where the result does not overflow
`num_bits`
"""
if num_bits % 8:
raise CompilerPanic("Type is not a modulo of 8")
if b < 0:
raise TypeCheckFailure("Cannot calculate negative exponents")

value_bits = num_bits - (1 if is_signed else 0)
if b > value_bits:
raise TypeCheckFailure("Value is too large and will always throw")
elif b < 2:
return 2 ** value_bits - 1 # Maximum value for type

# CMC 2022-05-06 TODO we should be able to do this with algebra
# instead of looping):
# x ** b == 2**value_bits
# b ln(x) == ln(2**value_bits)
# ln(x) == ln(2**value_bits) / b
# x == exp( ln(2**value_bits) / b)

# Estimate (up to ~39 digits precision required)
a = math.ceil(2 ** (decimal.Decimal(value_bits) / decimal.Decimal(b)))
# Do a bit of iteration to ensure we have the exact number
num_iterations = 0
while (a + 1) ** b < 2 ** value_bits:
a += 1
num_iterations += 1
assert num_iterations < 10000
while a ** b >= 2 ** value_bits:
a -= 1
num_iterations += 1
assert num_iterations < 10000
return a


def safe_add(x: IRnode, y: IRnode):
# precondition: x.typ.typ == t.typ.typ

Expand Down Expand Up @@ -152,4 +273,23 @@ def safe_mod(x: IRnode, y: IRnode):


def safe_pow(x: IRnode, y: IRnode):
pass
num_info = x.typ._num_info

if x.is_literal:
upper_bound = calculate_largest_power(x.value, num_info.bits, num_info.is_signed) + 1
# for signed integers, this also prevents negative values
ok = ["lt", right, upper_bound]
ret = ["seq", ["assert", clamp], ["exp", left, right]]

elif y.is_literal:
upper_bound = calculate_largest_base(y.value, num_info.bits, num_info.is_signed) + 1
if is_signed:
ok = ["and", ["slt", left, upper_bound], ["sgt", left, -upper_bound]]
else:
ok = ["lt", left, upper_bound]
ret = ["seq", ["assert", ok], ["exp", left, right]]
else:
# `a ** b` where neither `a` or `b` are known
# TODO this is currently unreachable, once we implement a way to do it safely
# remove the check in `vyper/context/types/value/numeric.py`
return
169 changes: 7 additions & 162 deletions vyper/codegen/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,113 +59,6 @@
}


def calculate_largest_power(a: int, num_bits: int, is_signed: bool) -> int:
"""
For a given base `a`, compute the maximum power `b` that will not
produce an overflow in the equation `a ** b`
Arguments
---------
a : int
Base value for the equation `a ** b`
num_bits : int
The maximum number of bits that the resulting value must fit in
is_signed : bool
Is the operation being performed on signed integers?
Returns
-------
int
Largest possible value for `b` where the result does not overflow
`num_bits`
"""
if num_bits % 8:
raise CompilerPanic("Type is not a modulo of 8")

value_bits = num_bits - (1 if is_signed else 0)
if a >= 2 ** value_bits:
raise TypeCheckFailure("Value is too large and will always throw")
elif a < -(2 ** value_bits):
raise TypeCheckFailure("Value is too small and will always throw")

a_is_negative = a < 0
a = abs(a) # No longer need to know if it's signed or not
if a in (0, 1):
raise CompilerPanic("Exponential operation is useless!")

# NOTE: There is an edge case if `a` were left signed where the following
# operation would not work (`ln(a)` is undefined if `a <= 0`)
b = int(decimal.Decimal(value_bits) / (decimal.Decimal(a).ln() / decimal.Decimal(2).ln()))
if b <= 1:
return 1 # Value is assumed to be in range, therefore power of 1 is max

# Do a bit of iteration to ensure we have the exact number
num_iterations = 0
while a ** (b + 1) < 2 ** value_bits:
b += 1
num_iterations += 1
assert num_iterations < 10000
while a ** b >= 2 ** value_bits:
b -= 1
num_iterations += 1
assert num_iterations < 10000

# Edge case: If a is negative and the values of a and b are such that:
# (a) ** (b + 1) == -(2 ** value_bits)
# we can actually squeak one more out of it because it's on the edge
if a_is_negative and (-a) ** (b + 1) == -(2 ** value_bits): # NOTE: a = abs(a)
return b + 1
else:
return b # Exact


def calculate_largest_base(b: int, num_bits: int, is_signed: bool) -> int:
"""
For a given power `b`, compute the maximum base `a` that will not produce an
overflow in the equation `a ** b`
Arguments
---------
b : int
Power value for the equation `a ** b`
num_bits : int
The maximum number of bits that the resulting value must fit in
is_signed : bool
Is the operation being performed on signed integers?
Returns
-------
int
Largest possible value for `a` where the result does not overflow
`num_bits`
"""
if num_bits % 8:
raise CompilerPanic("Type is not a modulo of 8")
if b < 0:
raise TypeCheckFailure("Cannot calculate negative exponents")

value_bits = num_bits - (1 if is_signed else 0)
if b > value_bits:
raise TypeCheckFailure("Value is too large and will always throw")
elif b < 2:
return 2 ** value_bits - 1 # Maximum value for type

# Estimate (up to ~39 digits precision required)
a = math.ceil(2 ** (decimal.Decimal(value_bits) / decimal.Decimal(b)))
# Do a bit of iteration to ensure we have the exact number
num_iterations = 0
while (a + 1) ** b < 2 ** value_bits:
a += 1
num_iterations += 1
assert num_iterations < 10000
while a ** b >= 2 ** value_bits:
a -= 1
num_iterations += 1
assert num_iterations < 10000

return a


class Expr:
# TODO: Once other refactors are made reevaluate all inline imports

Expand Down Expand Up @@ -477,67 +370,19 @@ def parse_BinOp(self):

out_typ = BaseType(ltyp)

ret = None
with left.cache_when_complex("x") as (b1, left), right.cache_when_complex("y") as (b2, y):
with left.cache_when_complex("x") as (b1, x), right.cache_when_complex("y") as (b2, y):
if isinstance(self.expr.op, vy_ast.Add):
ret = arithmetic.safe_add(left, y)
ret = arithmetic.safe_add(x, y)
elif isinstance(self.expr.op, vy_ast.Sub):
ret = arithmetic.safe_sub(left, y)
ret = arithmetic.safe_sub(x, y)
elif isinstance(self.expr.op, vy_ast.Mult):
ret = arithmetic.safe_mul(left, y)
ret = arithmetic.safe_mul(x, y)
elif isinstance(self.expr.op, vy_ast.Div):
ret = arithmetic.safe_div(left, y)
ret = arithmetic.safe_div(x, y)
elif isinstance(self.expr.op, vy_ast.Mod):
ret = arithmetic.safe_mod(left, y)
ret = arithmetic.safe_mod(x, y)
elif isinstance(self.expr.op, vy_ast.Pow):
# TODO: move this to arithmetic.py

# TODO optimizer rule for special cases
if self.expr.left.get("value") == 1:
return IRnode.from_list([1], typ=out_typ)
if self.expr.left.get("value") == 0:
return IRnode.from_list(["iszero", right], typ=out_typ)

if ltyp == "int128":
is_signed = True
num_bits = 128
elif ltyp == "int256":
is_signed = True
num_bits = 256
elif ltyp == "uint8":
is_signed = False
num_bits = 8
else:
is_signed = False
num_bits = 256

if isinstance(self.expr.left, vy_ast.Int):
value = self.expr.left.value
upper_bound = calculate_largest_power(value, num_bits, is_signed) + 1
# for signed integers, this also prevents negative values
clamp = ["lt", right, upper_bound]
ret = ["seq", ["assert", clamp], ["exp", left, right]]

elif isinstance(self.expr.right, vy_ast.Int):
value = self.expr.right.value
upper_bound = calculate_largest_base(value, num_bits, is_signed) + 1
if is_signed:
clamp = ["and", ["slt", left, upper_bound], ["sgt", left, -upper_bound]]
else:
clamp = ["lt", left, upper_bound]
ret = ["seq", ["assert", clamp], ["exp", left, right]]
else:
# `a ** b` where neither `a` or `b` are known
# TODO this is currently unreachable, once we implement a way to do it safely
# remove the check in `vyper/context/types/value/numeric.py`
return

if ret is None:
op_str = self.expr.op._pretty
raise UnimplementedException(
f"Not implemented: {ltyp} {op_str} {rtyp}", self.expr.op
)

ret = arithmetic.safe_pow(x, y)
return IRnode.from_list(b1.resolve(b2.resolve(ret)), typ=out_typ)

def build_in_comparator(self):
Expand Down

0 comments on commit 379fac9

Please sign in to comment.