Skip to content

Commit

Permalink
Merge pull request #1791 from iamdefinitelyahuman/unary-fixes
Browse files Browse the repository at this point in the history
Unary operations on literals
  • Loading branch information
fubuloubu authored Dec 30, 2019
2 parents d4b28c1 + 0c57a05 commit 47cd159
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 15 deletions.
41 changes: 38 additions & 3 deletions tests/parser/functions/test_unary.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,8 @@ def negate(a: int128) -> int128:
assert c.negate(val) == -val


decimal_divisor = Decimal('1e10')
min_decimal = (-2**127 + 1) / decimal_divisor
max_decimal = (2**127 - 1) / decimal_divisor
min_decimal = -2**127 + 1
max_decimal = 2**127 - 1
@pytest.mark.parametrize("val", [min_decimal, 0, max_decimal])
def test_unary_sub_decimal_pass(get_contract, val):
code = """@public
Expand All @@ -48,3 +47,39 @@ def negate(a: decimal) -> decimal:
"""
c = get_contract(code)
assert c.negate(val) == -val


def test_negation_decimal(get_contract):
code = """
a: constant(decimal) = 170141183460469231731687303715884105726.9999999999
b: constant(decimal) = -170141183460469231731687303715884105726.9999999999
@public
def foo() -> decimal:
return -a
@public
def bar() -> decimal:
return -b
"""

c = get_contract(code)
assert c.foo() == Decimal("-170141183460469231731687303715884105726.9999999999")
assert c.bar() == Decimal("170141183460469231731687303715884105726.9999999999")


def test_negation_int128(get_contract):
code = """
a: constant(int128) = -2**127
@public
def foo() -> int128:
return -2**127
@public
def bar() -> int128:
return -(a+1)
"""
c = get_contract(code)
assert c.foo() == -2**127
assert c.bar() == 2**127-1
21 changes: 10 additions & 11 deletions vyper/parser/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,7 @@ def build_in_comparator(self):
if left.typ != right.typ.subtype:
raise TypeMismatchException(
f"{left.typ} cannot be in a list of {right.typ.subtype}",
self.expr,
)

result_placeholder = self.context.new_placeholder(BaseType('bool'))
Expand Down Expand Up @@ -935,21 +936,19 @@ def unary_operations(self):
self.expr,
)
elif isinstance(self.expr.op, ast.USub):
# Must be a signed integer
if not is_numeric_type(operand.typ) or operand.typ.typ.lower().startswith('u'):
if not is_numeric_type(operand.typ):
raise TypeMismatchException(
f"Unsupported type for negation: {operand.typ}",
operand,
self.expr,
)

if operand.typ.is_literal and 'int' in operand.typ.typ:
num = ast.Num(n=0 - operand.value)
num.source_code = self.expr.source_code
num.lineno = self.expr.lineno
num.col_offset = self.expr.col_offset
num.end_lineno = self.expr.end_lineno
num.end_col_offset = self.expr.end_col_offset
return Expr.parse_value_expr(num, self.context)
if operand.typ.is_literal:
typ = "decimal" if operand.typ.typ == "decimal" else "int128"
return LLLnode.from_list(
0-operand.value,
typ=BaseType(typ, unit=operand.typ.unit, is_literal=True),
pos=getpos(self.expr),
)

# Clamp on minimum integer value as we cannot negate that value
# (all other integer values are fine)
Expand Down
2 changes: 1 addition & 1 deletion vyper/parser/parser_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def get_number_as_fraction(expr, context):

if exponent < -10:
raise InvalidLiteralException(
"`decimal` literal cannot have more than 10 decimal places: {literal}",
f"`decimal` literal cannot have more than 10 decimal places: {literal}",
expr
)

Expand Down

0 comments on commit 47cd159

Please sign in to comment.