Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New attr.cmp_using helper function #787

Merged
merged 12 commits into from
May 1, 2021
2 changes: 2 additions & 0 deletions src/attr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from functools import partial

from . import converters, exceptions, filters, setters, validators
from ._cmp import cmp_using
from ._config import get_run_validators, set_run_validators
from ._funcs import asdict, assoc, astuple, evolve, has, resolve_types
from ._make import (
Expand Down Expand Up @@ -52,6 +53,7 @@
"attrib",
"attributes",
"attrs",
"cmp_using",
"converters",
"evolve",
"exceptions",
Expand Down
14 changes: 9 additions & 5 deletions src/attr/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,11 @@ _OnSetAttrType = Callable[[Any, Attribute[Any], Any], Any]
_OnSetAttrArgType = Union[
_OnSetAttrType, List[_OnSetAttrType], setters._NoOpType
]
_FieldTransformer = Callable[[type, List[Attribute[Any]]], List[Attribute[Any]]]
_FieldTransformer = Callable[
[type, List[Attribute[Any]]], List[Attribute[Any]]
]
_CompareWithType = Callable[[Any, Any], bool]

# FIXME: in reality, if multiple validators are passed they must be in a list
# or tuple, but those are invariant and so would prevent subtypes of
# _ValidatorType from working when passed in a list or tuple.
Expand All @@ -64,7 +68,6 @@ NOTHING: object
# Work around mypy issue #4554 in the common case by using an overload.
if sys.version_info >= (3, 8):
from typing import Literal

@overload
def Factory(factory: Callable[[], _T]) -> _T: ...
@overload
Expand All @@ -77,6 +80,7 @@ if sys.version_info >= (3, 8):
factory: Callable[[], _T],
takes_self: Literal[False],
) -> _T: ...

else:
@overload
def Factory(factory: Callable[[], _T]) -> _T: ...
Expand All @@ -86,7 +90,6 @@ else:
takes_self: bool = ...,
) -> _T: ...


class Attribute(Generic[_T]):
name: str
default: Optional[_T]
Expand All @@ -102,7 +105,6 @@ class Attribute(Generic[_T]):
type: Optional[Type[_T]]
kw_only: bool
on_setattr: _OnSetAttrType

def evolve(self, **changes: Any) -> "Attribute[Any]": ...

# NOTE: We had several choices for the annotation to use for type arg:
Expand Down Expand Up @@ -429,7 +431,9 @@ def asdict(
filter: Optional[_FilterType[Any]] = ...,
dict_factory: Type[Mapping[Any, Any]] = ...,
retain_collection_types: bool = ...,
value_serializer: Optional[Callable[[type, Attribute[Any], Any], Any]] = ...,
value_serializer: Optional[
Callable[[type, Attribute[Any], Any], Any]
] = ...,
) -> Dict[str, Any]: ...

# TODO: add support for returning NamedTuple from the mypy plugin
Expand Down
150 changes: 150 additions & 0 deletions src/attr/_cmp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
from __future__ import absolute_import, division, print_function

import functools

from ._compat import new_class
from ._make import _make_ne


_operation_names = {"eq": "==", "lt": "<", "le": "<=", "gt": ">", "ge": ">="}


def cmp_using(
eq=None,
lt=None,
le=None,
gt=None,
ge=None,
require_same_type=True,
class_name="Comparable",
):
"""
Utility function that creates a class with customized equality and
ordering methods.

The resulting class will have a full set of ordering methods if
at least one of ``{lt, le, gt, ge}`` and ``eq`` are provided.

:param Optional[callable] eq: `callable` used to evaluate equality
of two objects.
:param Optional[callable] lt: `callable` used to evaluate whether
one object is less than another object.
:param Optional[callable] le: `callable` used to evaluate whether
one object is less than or equal to another object.
:param Optional[callable] gt: `callable` used to evaluate whether
one object is greater than another object.
:param Optional[callable] ge: `callable` used to evaluate whether
one object is greater than or equal to another object.

:param bool require_same_type: When `True`, equality and ordering methods
will return `NotImplemented` if objects are not of the same type.

:param Optional[str] class_name: Name of class. Defaults to 'Comparable'.

.. versionadded:: 21.1.0
"""

body = {
"__slots__": ["value"],
"__init__": _make_init(),
"_requirements": [],
"_is_comparable_to": _is_comparable_to,
}

# Add operations.
num_order_fucntions = 0
botant marked this conversation as resolved.
Show resolved Hide resolved
has_eq_function = False

if eq is not None:
has_eq_function = True
body["__eq__"] = _make_operator("eq", eq)
body["__ne__"] = _make_ne()

if lt is not None:
num_order_fucntions += 1
body["__lt__"] = _make_operator("lt", lt)

if le is not None:
num_order_fucntions += 1
body["__le__"] = _make_operator("le", le)

if gt is not None:
num_order_fucntions += 1
body["__gt__"] = _make_operator("gt", gt)

if ge is not None:
num_order_fucntions += 1
body["__ge__"] = _make_operator("ge", ge)

type_ = new_class(class_name, (object,), {}, lambda ns: ns.update(body))

# Add same type requirement.
if require_same_type:
type_._requirements.append(_check_same_type)

# Add total ordering if at least one operation was defined.
if 0 < num_order_fucntions < 4:
if not has_eq_function:
# functools.total_ordering requires __eq__ to be defined,
# so raise early error here to keep a nice stack.
raise ValueError(
"eq must be define is order to complete ordering from "
"lt, le, gt, ge."
)
type_ = functools.total_ordering(type_)

return type_


def _make_init():
"""
Create __init__ method.
"""

def __init__(self, value):
"""
Initialize object with *value*.
"""
self.value = value

return __init__


def _make_operator(name, func):
"""
Create operator method.
"""

def method(self, other):
if not self._is_comparable_to(other):
return NotImplemented

result = func(self.value, other.value)
if result is NotImplemented:
return NotImplemented

return result

method.__name__ = "__%s__" % (name,)
method.__doc__ = "Return a %s b. Computed by attrs." % (
_operation_names[name],
)

return method


def _is_comparable_to(self, other):
"""
Check whether `other` is comparable to `self`.
"""
for func in self._requirements:
if not func(self, other):
return False
return True


def _check_same_type(self, other):
"""
Return True if *self* and *other* are of the same type, False otherwise.
"""
return other.value.__class__ is self.value.__class__
14 changes: 14 additions & 0 deletions src/attr/_cmp.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from typing import Type

from . import _CompareWithType


def cmp_using(
eq: Optional[_CompareWithType],
lt: Optional[_CompareWithType],
le: Optional[_CompareWithType],
gt: Optional[_CompareWithType],
ge: Optional[_CompareWithType],
require_same_type: bool,
class_name: str,
) -> Type: ...
Loading