Skip to content

Commit

Permalink
fix bugs, handle more cases
Browse files Browse the repository at this point in the history
  • Loading branch information
charles-cooper committed Mar 8, 2022
1 parent 591fd64 commit a9f8d99
Show file tree
Hide file tree
Showing 7 changed files with 99 additions and 60 deletions.
112 changes: 63 additions & 49 deletions vyper/builtin_functions/convert.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,47 @@
import math
import warnings
from decimal import Decimal
import decimal
import functools

from vyper import ast as vy_ast
from vyper.codegen.expr import Expr
from vyper.codegen.core import (
LLLnode,
add_ofst,
bytes_data_ptr,
clamp_basetype,
get_bytearray_length,
bytes_data_ptr,
getpos,
int_clamp,
load_op,
load_word,
promote_signed_int,
sar,
shl,
shr,
wordsize,
)
from vyper.codegen.expr import Expr
from vyper.codegen.types import (
DYNAMIC_ARRAY_OVERHEAD,
is_integer_type,
is_bytes_m_type,
parse_bytes_m_info,
is_decimal_type,
parse_decimal_info,
parse_integer_typeinfo,
BaseType,
ByteArrayType,
ByteArrayLike,
ByteArrayType,
StringType,
INTEGER_TYPES,
BYTES_M_TYPES,
is_base_type,
DECIMAL_TYPES,
is_bytes_m_type,
is_decimal_type,
is_integer_type,
parse_bytes_m_info,
parse_decimal_info,
parse_integer_typeinfo,
)
from vyper.evm.opcodes import version_check
from vyper.exceptions import InvalidLiteral, StructureException, TypeMismatch
from vyper.utils import DECIMAL_DIVISOR, MemoryPositions, SizeLimits
from vyper.semantics.types.abstract import NumericAbstractType, BytesAbstractType, BytesMAbstractType, UnsignedIntegerAbstractType
from vyper.semantics.types import AddressDefinition, BoolDefinition, BytesArrayDefinition, StringDefinition
from vyper.semantics.types import (
AddressDefinition,
BoolDefinition,
BytesArrayDefinition,
StringDefinition,
)
from vyper.semantics.types.abstract import (
BytesAbstractType,
NumericAbstractType,
UnsignedIntegerAbstractType,
)
from vyper.utils import DECIMAL_DIVISOR, SizeLimits, int_bounds


def _FAIL(ityp, otyp, pos=None):
Expand All @@ -56,7 +58,9 @@ def g(expr, arg, out_typ):
if not ok:
_FAIL(expr._metadata["type"], out_typ, expr)
return f(expr, arg, out_typ)

return g

return decorator


Expand All @@ -65,7 +69,7 @@ def _byte_array_to_num(arg, out_type, clamp):
Generate LLL which takes a <32 byte array as input and returns a number
"""
len_ = get_bytearray_length(arg)
val = unwrap_location(bytes_data_ptr(arg))
val = load_word(bytes_data_ptr(arg))

# converting a bytestring to a number:
# bytestring is right-padded with zeroes, int is left-padded.
Expand Down Expand Up @@ -121,11 +125,12 @@ def to_bool(expr, arg, out_typ):


def _literal_int(expr, out_typ):
int_info = parse_integer_typeinfo(out_typ.typ)
val = int(expr.value) # should work for Int, Decimal, Hex
(lo, hi) = int_bounds(int_info.is_signed, int_info.bits)
if not (lo <= val <= hi):
raise InvalidLiteral(f"Number out of range", expr)
return LLLnode.from_list(arg, typ=out_typ, is_literal=True)
raise InvalidLiteral("Number out of range", expr)
return LLLnode.from_list(val, typ=out_typ, is_literal=True)


@_input_types(NumericAbstractType, BytesAbstractType, BoolDefinition)
Expand All @@ -143,7 +148,7 @@ def to_int(expr, arg, out_typ):
arg = _byte_array_to_num(arg, out_typ, clamp=False)

if is_decimal_type(arg.typ):
info = parse_decimal_typeinfo(arg.typ.typ)
info = parse_decimal_info(arg.typ.typ)
arg = _fixed_to_int(arg, out_typ, decimals=info.decimals)

if is_bytes_m_type(arg.typ):
Expand All @@ -152,13 +157,13 @@ def to_int(expr, arg, out_typ):

# NOTE bytesM to intN is like casting to bytesJ then intN
# (where J = N/8)
if m_bits < 256: # TODO optimizer rule for this
arg = shr(256 - m_bits, arg)
arg = shr(256 - m_bits, arg)
is_downcast = m_bits > int_info.bits # do we need to clamp?
if is_downcast:
arg = LLLnode.from_list(arg, typ=out_typ)
arg = clamp_basetype(arg)
# TODO we need signextend for signed ints.
if int_info.is_signed:
arg = promote_signed_int(arg, int_info.bits)

if is_integer_type(arg.typ):
arg_info = parse_integer_typeinfo(arg.typ.typ)
Expand All @@ -177,7 +182,7 @@ def to_int(expr, arg, out_typ):
# NOTE: sar works for both ways, including uint256 <-> int256
tmp.append(["assert", ["iszero", sar(int_info.bits, arg)]])
tmp.append(arg)
arg = b.resolve(tmp)
arg = tmp

elif arg_info.bits > int_info.bits:
# cast to out_type so clamp_basetype works
Expand All @@ -189,15 +194,15 @@ def to_int(expr, arg, out_typ):


@_input_types(NumericAbstractType, BoolDefinition)
def to_decimal(expr, arg, _out_typ):
def to_decimal(expr, arg, out_typ):
if isinstance(expr, vy_ast.Constant):
val = Decimal(expr.value) # should work for Int, Decimal, Hex
(lo, hi) = (MIN_DECIMAL, MAX_DECIMAL)
if not (lo <= self.expr.val <= hi):
raise InvalidLiteral(f"Number out of range", expr)
val = decimal.Decimal(expr.value) # should work for Int, Decimal, Hex
(lo, hi) = (SizeLimits.MIN_DECIMAL, SizeLimits.MAX_DECIMAL)
if not (lo <= expr.val <= hi):
raise InvalidLiteral("Number out of range", expr)

return LLLnode.from_list(
val * DECIMAL_DIVISOR,
int(val * DECIMAL_DIVISOR),
typ=BaseType(out_typ, is_literal=True),
)

Expand All @@ -217,20 +222,30 @@ def to_bytes_m(expr, arg, out_typ):
_check_bytes(expr, arg, out_typ, max_bytes_allowed=m)

if isinstance(arg.typ, ByteArrayType):
load = load_op(arg.location)
bytes_val = [load, bytes_data_ptr(arg)]
bytes_val = load_word(bytes_data_ptr(arg))

# zero out any dirty bytes (which can happen in the last
# word of a bytearray)
len_ = get_bytearray_length(arg)
num_zero_bits = LLLnode.from_list(["mul", ["sub", 32, len_], 8])
with num_zero_bits.cache_when_complex("bits") as (b2, num_zero_bits):
with num_zero_bits.cache_when_complex("bits") as (b, num_zero_bits):
ret = shl(num_zero_bits, shr(num_zero_bits, bytes_val))
ret = b1.resolve(b2.resolve(ret))
ret = b.resolve(ret)

else:
# TODO shl for int types.
ret = arg
elif is_integer_type(arg.typ) or is_base_type(arg.typ, "address"):
m_bits = m * 8
if is_integer_type(arg.typ):
int_bits = parse_integer_typeinfo(arg.typ.typ).bits
else: # address
int_bits = 160

if m_bits > int_bits:
raise _FAIL(expr, arg.typ, out_typ)

# no special handling for signed ints needed
# (downcasting is disallowed so we don't need to deal with
# upper `1` bits)
ret = shl(m_bits - int_bits, arg)

return LLLnode.from_list(ret, typ=out_typ)

Expand All @@ -245,15 +260,14 @@ def to_address(expr, arg, out_typ):
arg = _byte_array_to_num(arg, out_typ, clamp=False)
should_clamp = False

if is_bytes_m_type(arg.typ):
elif is_bytes_m_type(arg.typ):
m = parse_bytes_m_info(arg.typ.typ)
m_bits = m * 8
if m_bits < 256:
arg = shr(256 - m_bits, arg)
arg = shr(256 - m_bits, arg)

should_clamp = m_bits > 160

if is_integer_type(arg.typ):
elif is_integer_type(arg.typ):
int_info = parse_integer_typeinfo(arg.typ.typ)
should_clamp = int_info.bits > 160 or int_info.is_signed

Expand Down Expand Up @@ -306,7 +320,7 @@ def convert(expr, context):
elif isinstance(out_typ, StringType):
ret = to_string(expr, arg, out_typ)
else:
raise StructureException(f"Conversion to {output_type} is invalid.", expr)
raise StructureException(f"Conversion to {out_typ} is invalid.", expr)

ret = b.resolve(ret)

Expand Down
3 changes: 2 additions & 1 deletion vyper/builtin_functions/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
getpos,
lll_tuple_from_args,
load_op,
promote_signed_int,
unwrap_location,
)
from vyper.codegen.expr import Expr
Expand Down Expand Up @@ -1725,7 +1726,7 @@ def build_LLL(self, expr, args, kwargs, context):
# wrap for ops which could under/overflow
if int_info.is_signed:
# e.g. int128 -> (signextend 15 (add x y))
ret = ["signextend", int_info.bits // 8 - 1, ret]
ret = promote_signed_int(ret, int_info.bits)
else:
# e.g. uint8 -> (mod (add x y) 256)
# TODO mod_bound could be a really large literal
Expand Down
22 changes: 19 additions & 3 deletions vyper/codegen/core.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from decimal import Context, Decimal, setcontext
from decimal import Context, setcontext

from vyper import ast as vy_ast
from vyper.codegen.lll_node import Encoding, LLLnode
Expand All @@ -20,7 +20,6 @@
from vyper.exceptions import (
CompilerPanic,
DecimalOverrideException,
InvalidLiteral,
StructureException,
TypeCheckFailure,
TypeMismatch,
Expand Down Expand Up @@ -555,6 +554,7 @@ def get_element_ptr(parent, key, pos, array_bounds_check=True):
return b.resolve(ret)


# TODO phase this out - make private and use load_word instead
def load_op(location):
if location == "memory":
return "mload"
Expand All @@ -572,6 +572,7 @@ def load_op(location):
raise CompilerPanic(f"unreachable {location}") # pragma: notest


# TODO phase this out - make private and use store_word instead
def store_op(location):
if location == "memory":
return "mstore"
Expand All @@ -582,10 +583,18 @@ def store_op(location):
raise CompilerPanic(f"unreachable {location}") # pragma: notest


def load_word(ptr: LLLnode) -> LLLnode:
return LLLnode.from_list([load_op(ptr.location), ptr])


def store_word(ptr: LLLnode, val: LLLnode) -> LLLnode:
return LLLnode.from_list([store_op(ptr.location), ptr, val])


# Unwrap location
def unwrap_location(orig):
if orig.location in ("memory", "storage", "calldata", "data", "immutables"):
return LLLnode.from_list([load_op(orig.location), orig], typ=orig.typ)
return LLLnode.from_list(load_word(orig), typ=orig.typ)
else:
# CMC 20210909 TODO double check if this branch can be removed
if orig.value == "~empty":
Expand Down Expand Up @@ -1017,3 +1026,10 @@ def int_clamp(lll_node, bits, signed=False):
ret = ["with", "val", lll_node, ["seq", assertion, "val"]]

return LLLnode.from_list(ret, annotation=f"int_clamp {lll_node.typ}")


# e.g. for int8, promote 255 to -1
def promote_signed_int(x, bits):
assert bits % 8 == 0
ret = ["signextend", bits // 8 - 1, x]
return LLLnode.from_list(ret, annotation=f"promote int{bits}")
8 changes: 4 additions & 4 deletions vyper/codegen/expr.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import decimal
import math
from decimal import Decimal

from vyper import ast as vy_ast
from vyper.codegen import external_call, self_call
Expand Down Expand Up @@ -99,7 +99,7 @@ def calculate_largest_power(a: int, num_bits: int, is_signed: bool) -> int:

# NOTE: There is an edge case if `a` were left signed where the following
# operation would not work (`ln(a)` is undefined if `a <= 0`)
b = int(Decimal(value_bits) / (Decimal(a).ln() / Decimal(2).ln()))
b = int(decimal.Decimal(value_bits) / (decimal.Decimal(a).ln() / decimal.Decimal(2).ln()))
if b <= 1:
return 1 # Value is assumed to be in range, therefore power of 1 is max

Expand Down Expand Up @@ -155,7 +155,7 @@ def calculate_largest_base(b: int, num_bits: int, is_signed: bool) -> int:
return 2 ** value_bits - 1 # Maximum value for type

# Estimate (up to ~39 digits precision required)
a = math.ceil(2 ** (Decimal(value_bits) / Decimal(b)))
a = math.ceil(2 ** (decimal.Decimal(value_bits) / decimal.Decimal(b)))
# Do a bit of iteration to ensure we have the exact number
num_iterations = 0
while (a + 1) ** b < 2 ** value_bits:
Expand Down Expand Up @@ -223,7 +223,7 @@ def parse_Decimal(self):
assert SizeLimits.MIN_DECIMAL <= self.value <= SizeLimits.MAX_DECIMAL

return LLLnode.from_list(
num * DECIMAL_DIVISOR,
int(self.value * DECIMAL_DIVISOR),
typ=BaseType("decimal", is_literal=True),
pos=getpos(self.expr),
)
Expand Down
5 changes: 3 additions & 2 deletions vyper/codegen/types/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from vyper.exceptions import ArgumentException, CompilerPanic, InvalidType
from vyper.utils import ceil32


# Available base types
UNSIGNED_INTEGER_TYPES = {"uint8", "uint256"}
SIGNED_INTEGER_TYPES = {"int128", "int256"}
Expand All @@ -34,6 +33,7 @@

BASE_TYPES = INTEGER_TYPES | BYTES_M_TYPES | DECIMAL_TYPES | {"bool", "address"}


# Data structure for a type
class NodeType(abc.ABC):
def __eq__(self, other: Any) -> bool:
Expand Down Expand Up @@ -77,6 +77,7 @@ def storage_size_in_words(self) -> int:
# helper functions for handling old base types which are just strings
# in the future these can be reified with new type system


@dataclass
class IntegerTypeInfo:
is_signed: bool
Expand Down Expand Up @@ -112,7 +113,7 @@ def is_bytes_m_type(t: "NodeType") -> bool:


def parse_bytes_m_info(typename: str) -> int:
return int(typename[len("bytes"):])
return int(typename[len("bytes") :])


def is_decimal_type(t: "NodeType") -> bool:
Expand Down
6 changes: 6 additions & 0 deletions vyper/lll/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,12 @@ def apply_general_optimizations(node: LLLnode) -> LLLnode:
argz = []
value = ceil32(t.value)

# x >> 0 == x << 0 == x
elif node.value in ("shl", "shr", "sar") and get_int_at(argz, 0) == 0:
value = argz[1].value
annotation = argz[1].annotation
argz = []

elif node.value == "add" and get_int_at(argz, 0) == 0:
value = argz[1].value
annotation = argz[1].annotation
Expand Down
Loading

0 comments on commit a9f8d99

Please sign in to comment.