Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minmax signing #1790

Merged
merged 2 commits into from
Dec 28, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions tests/parser/functions/test_minmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,3 +185,55 @@ def foo() -> uint256:
lambda: get_contract_with_gas_estimation(code_2),
TypeMismatchException
)


def test_unsigned(get_contract_with_gas_estimation):
code = """
@public
def foo1() -> uint256:
return min(0, 2**255)

@public
def foo2() -> uint256:
return min(2**255, 0)

@public
def foo3() -> uint256:
return max(0, 2**255)

@public
def foo4() -> uint256:
return max(2**255, 0)
fubuloubu marked this conversation as resolved.
Show resolved Hide resolved
"""

c = get_contract_with_gas_estimation(code)
assert c.foo1() == 0
assert c.foo2() == 0
assert c.foo3() == 2**255
assert c.foo4() == 2**255


def test_signed(get_contract_with_gas_estimation):
code = """
@public
def foo1() -> int128:
return min(MIN_INT128, MAX_INT128)

@public
def foo2() -> int128:
return min(MAX_INT128, MIN_INT128)

@public
def foo3() -> int128:
return max(MIN_INT128, MAX_INT128)

@public
def foo4() -> int128:
return max(MAX_INT128, MIN_INT128)
"""

c = get_contract_with_gas_estimation(code)
assert c.foo1() == -2**127
assert c.foo2() == -2**127
assert c.foo3() == 2**127-1
assert c.foo4() == 2**127-1
13 changes: 6 additions & 7 deletions vyper/functions/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1162,15 +1162,15 @@ def create_forwarder_to(expr, args, kwargs, context):

@signature(('int128', 'decimal', 'uint256'), ('int128', 'decimal', 'uint256'))
def _min(expr, args, kwargs, context):
return minmax(expr, args, kwargs, context, True)
return minmax(expr, args, kwargs, context, 'gt')


@signature(('int128', 'decimal', 'uint256'), ('int128', 'decimal', 'uint256'))
def _max(expr, args, kwargs, context):
return minmax(expr, args, kwargs, context, False)
return minmax(expr, args, kwargs, context, 'lt')


def minmax(expr, args, kwargs, context, is_min):
def minmax(expr, args, kwargs, context, comparator):
def _can_compare_with_uint256(operand):
if operand.typ.typ == 'uint256':
return True
Expand All @@ -1181,11 +1181,10 @@ def _can_compare_with_uint256(operand):
left, right = args[0], args[1]
if not are_units_compatible(left.typ, right.typ) and not are_units_compatible(right.typ, left.typ): # noqa: E501
raise TypeMismatchException("Units must be compatible", expr)
if left.typ.typ == 'uint256':
comparator = 'gt' if is_min else 'lt'
else:
comparator = 'sgt' if is_min else 'slt'
if left.typ.typ == right.typ.typ:
if left.typ.typ != 'uint256':
# if comparing like types that are not uint256, use SLT or SGT
comparator = f's{comparator}'
o = ['if', [comparator, '_l', '_r'], '_r', '_l']
otyp = left.typ
otyp.is_literal = False
Expand Down