Skip to content

Commit

Permalink
fix: ensure all comparable types return NotImplemented when a compa…
Browse files Browse the repository at this point in the history
…rison is not possible
  • Loading branch information
daniel-makerx committed Aug 27, 2024
1 parent 3cbdbd2 commit b055fa6
Show file tree
Hide file tree
Showing 12 changed files with 1,749 additions and 80 deletions.
103 changes: 35 additions & 68 deletions src/_algopy_testing/arc4.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import dataclasses
import decimal
import functools
import types
import typing

Expand Down Expand Up @@ -169,7 +170,8 @@ def bytes(self) -> algopy.Bytes:
def __eq__(self, other: object) -> bool:
if isinstance(other, _ABIEncoded):
return self._type_info == other._type_info and self.bytes == other.bytes
return NotImplemented
else:
return NotImplemented

def __hash__(self) -> int:
return hash(self.bytes)
Expand Down Expand Up @@ -234,7 +236,11 @@ def __radd__(self, other: String | str) -> String:
return String(as_string(other) + self.native)

def __eq__(self, other: String | str) -> bool: # type: ignore[override]
return self.native == as_string(other)
try:
other_string = as_string(other)
except TypeError:
return NotImplemented
return self.native == other_string

def __bool__(self) -> bool:
"""Returns `True` if length is not zero."""
Expand Down Expand Up @@ -294,47 +300,12 @@ def __init__(
bytes_value = int_to_bytes(value, self._type_info.max_bytes_len)
self._value = as_bytes(bytes_value)

def __eq__( # type: ignore[override]
self,
other: UIntN[_TBitSize] | BigUIntN[_TBitSize] | algopy.UInt64 | algopy.BigUInt | int,
) -> bool:
raise NotImplementedError

def __ne__( # type: ignore[override]
self,
other: UIntN[_TBitSize] | BigUIntN[_TBitSize] | algopy.UInt64 | algopy.BigUInt | int,
) -> bool:
raise NotImplementedError

def __le__(
self,
other: UIntN[_TBitSize] | BigUIntN[_TBitSize] | algopy.UInt64 | algopy.BigUInt | int,
) -> bool:
raise NotImplementedError

def __lt__(
self,
other: UIntN[_TBitSize] | BigUIntN[_TBitSize] | algopy.UInt64 | algopy.BigUInt | int,
) -> bool:
raise NotImplementedError

def __ge__(
self,
other: UIntN[_TBitSize] | BigUIntN[_TBitSize] | algopy.UInt64 | algopy.BigUInt | int,
) -> bool:
raise NotImplementedError

def __gt__(
self,
other: UIntN[_TBitSize] | BigUIntN[_TBitSize] | algopy.UInt64 | algopy.BigUInt | int,
) -> bool:
raise NotImplementedError

def __bool__(self) -> bool:
"""Returns `True` if not equal to zero."""
raise NotImplementedError


@functools.total_ordering
class UIntN(_UIntN, typing.Generic[_TBitSize]): # type: ignore[type-arg]
"""An ARC4 UInt consisting of the number of bits specified.
Expand All @@ -349,22 +320,18 @@ def native(self) -> algopy.UInt64:
return algopy.UInt64(int.from_bytes(self._value))

def __eq__(self, other: object) -> bool:
return as_int64(self.native) == as_int(other, max=None)

def __ne__(self, other: object) -> bool:
return as_int64(self.native) != as_int(other, max=None)

def __le__(self, other: object) -> bool:
return as_int64(self.native) <= as_int(other, max=None)
try:
other_int = as_int64(other)
except (TypeError, ValueError):
return NotImplemented
return as_int64(self.native) == other_int

def __lt__(self, other: object) -> bool:
return as_int64(self.native) < as_int(other, max=None)

def __ge__(self, other: object) -> bool:
return as_int64(self.native) >= as_int(other, max=None)

def __gt__(self, other: object) -> bool:
return as_int64(self.native) > as_int(other, max=None)
try:
other_int = as_int64(other)
except (TypeError, ValueError):
return NotImplemented
return as_int64(self.native) < other_int

def __bool__(self) -> bool:
return bool(self.native)
Expand All @@ -376,6 +343,7 @@ def __repr__(self) -> str:
return _arc4_repr(self)


@functools.total_ordering
class BigUIntN(_UIntN, typing.Generic[_TBitSize]): # type: ignore[type-arg]
"""An ARC4 UInt consisting of the number of bits specified.
Expand All @@ -390,22 +358,18 @@ def native(self) -> algopy.BigUInt:
return algopy.BigUInt.from_bytes(self._value)

def __eq__(self, other: object) -> bool:
return as_int512(self.native) == as_int(other, max=None)

def __ne__(self, other: object) -> bool:
return as_int512(self.native) != as_int(other, max=None)

def __le__(self, other: object) -> bool:
return as_int512(self.native) <= as_int(other, max=None)
try:
other_int = as_int512(other)
except (TypeError, ValueError):
return NotImplemented
return as_int512(self.native) == other_int

def __lt__(self, other: object) -> bool:
return as_int512(self.native) < as_int(other, max=None)

def __ge__(self, other: object) -> bool:
return as_int512(self.native) >= as_int(other, max=None)

def __gt__(self, other: object) -> bool:
return as_int512(self.native) > as_int(other, max=None)
try:
other_int = as_int512(other)
except (TypeError, ValueError):
return NotImplemented
return as_int512(self.native) < other_int

def __bool__(self) -> bool:
return bool(self.native)
Expand Down Expand Up @@ -749,8 +713,11 @@ def __eq__(self, other: Address | Account | str) -> bool: # type: ignore[overri
`Account` or `str`"""
if isinstance(other, Address | Account):
return self.bytes == other.bytes
other_bytes: bytes = algosdk.encoding.decode_address(other)
return self.bytes == other_bytes
elif isinstance(other, str):
other_bytes: bytes = algosdk.encoding.decode_address(other)
return self.bytes == other_bytes
else:
return NotImplemented

def __str__(self) -> str:
return str(self.native)
Expand Down
13 changes: 6 additions & 7 deletions src/_algopy_testing/models/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,13 +131,12 @@ def __repr__(self) -> str:
return self.public_key

def __eq__(self, other: object) -> bool:
match other:
case Account() as other_acc:
return self._public_key == other_acc._public_key
case str() as other_str:
return self.public_key == other_str
case _:
return NotImplemented
if isinstance(other, Account):
return self._public_key == other._public_key
elif isinstance(other, str):
return self.public_key == other
else:
return NotImplemented

def __bool__(self) -> bool:
return bool(self._public_key) and self._public_key != algosdk.encoding.decode_address(
Expand Down
6 changes: 5 additions & 1 deletion src/_algopy_testing/models/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,11 @@ def __getattr__(self, name: str) -> typing.Any:
def __eq__(self, other: object) -> bool:
if isinstance(other, Application):
return self._id == other._id
return self._id == as_int64(other)
# can compare Applications to int types only (not uint64)
elif isinstance(other, int):
return self._id == other
else:
return NotImplemented

def __bool__(self) -> bool:
return self._id != 0
Expand Down
5 changes: 4 additions & 1 deletion src/_algopy_testing/models/asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,10 @@ def __getattr__(self, name: str) -> typing.Any:
def __eq__(self, other: object) -> bool:
if isinstance(other, Asset):
return self.id == other.id
return self.id == other
elif isinstance(other, int):
return self.id == other
else:
return NotImplemented

def __bool__(self) -> bool:
return self.id != 0
Expand Down
6 changes: 5 additions & 1 deletion src/_algopy_testing/primitives/biguint.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,11 @@ def __str__(self) -> str:
# however will raise a TypeError if compared to something that is not a numeric value
# as this would be a compile error when compiled to TEAL
def __eq__(self, other: object) -> bool:
return as_int512(self) == as_int512(other)
try:
other_uint = as_int512(other)
except TypeError:
return NotImplemented
return as_int512(self) == other_uint

def __lt__(self, other: BigUInt | UInt64 | int) -> bool:
return as_int512(self) < as_int512(other)
Expand Down
6 changes: 5 additions & 1 deletion src/_algopy_testing/primitives/bytes.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,11 @@ def __getitem__(
return Bytes(self.value[slice(int_index, int_index + 1)])

def __eq__(self, other: object) -> bool:
return self.value == as_bytes(other)
try:
other_bytes = as_bytes(other)
except TypeError:
return NotImplemented
return self.value == other_bytes

def __and__(self, other: bytes | Bytes) -> Bytes:
return self._operate_bitwise(other, "and_")
Expand Down
6 changes: 5 additions & 1 deletion src/_algopy_testing/primitives/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@ def __bool__(self) -> bool:
return bool(self.value)

def __eq__(self, other: object) -> bool:
return self.value == as_string(other)
try:
other_string = as_string(other)
except TypeError:
return NotImplemented
return self.value == other_string

def __contains__(self, item: object) -> bool:
return as_string(item) in self.value
Expand Down
Empty file.
Loading

0 comments on commit b055fa6

Please sign in to comment.