diff --git a/ibis/common/bases.py b/ibis/common/bases.py
new file mode 100644
index 000000000000..20f219ed7da5
--- /dev/null
+++ b/ibis/common/bases.py
@@ -0,0 +1,189 @@
+from __future__ import annotations
+
+from abc import ABCMeta, abstractmethod
+from typing import TYPE_CHECKING, Any, Mapping
+from weakref import WeakValueDictionary
+
+from ibis.common.caching import WeakCache
+from ibis.common.collections import FrozenDict
+
+if TYPE_CHECKING:
+    from typing_extensions import Self
+
+
+class BaseMeta(ABCMeta):
+    """Base metaclass for many of the ibis core classes.
+
+    This metaclass defaults `__slots__` to an empty tuple for classes that do not
+    define it and routes instantiation through the `__create__` classmethod, which
+    subclasses can override to change the instantiation behavior.
+    """
+
+    __slots__ = ()
+
+    def __new__(metacls, clsname, bases, dct, **kwargs):
+        # fall back to empty slots when the class body does not declare any
+        dct.setdefault("__slots__", ())
+        return super().__new__(metacls, clsname, bases, dct, **kwargs)
+
+    def __call__(cls, *args, **kwargs):
+        """Create a new instance of the class.
+
+        The subclass may override the `__create__` classmethod to change the
+        instantiation behavior. This is similar to overriding the `__new__`
+        method, but without conditionally calling the `__init__` based on the
+        return type.
+
+        Parameters
+        ----------
+        args : tuple
+            Positional arguments eventually passed to the `__init__` method.
+        kwargs : dict
+            Keyword arguments eventually passed to the `__init__` method.
+
+        Returns
+        -------
+        The object returned by `__create__`; no extra initialization is done here.
+        """
+        return cls.__create__(*args, **kwargs)
+
+
+class Base(metaclass=BaseMeta):
+    """Base class for many of the ibis core classes.
+
+    This class enforces the subclasses to define a `__slots__` attribute and
+    provides a `__create__` classmethod that can be used to change the class
+    instantiation behavior. Also enables weak references for the subclasses.
+    """
+
+    __slots__ = ("__weakref__",)
+    __create__ = classmethod(type.__call__)  # type: ignore
+
+
+class Immutable(Base):
+    """Prohibit attribute assignment; copying returns the instance itself."""
+
+    def __copy__(self):
+        return self
+
+    def __deepcopy__(self, memo):
+        return self
+
+    def __setattr__(self, name: str, _: Any) -> None:
+        raise AttributeError(
+            f"Attribute {name!r} cannot be assigned to immutable instance of "
+            f"type {type(self)}"
+        )
+
+
+class Singleton(Base):
+    """Cache instances keyed by (class, args, kwargs) in a weak-value dictionary."""
+
+    __instances__: Mapping[Any, Self] = WeakValueDictionary()
+
+    @classmethod
+    def __create__(cls, *args, **kwargs):
+        key = (cls, args, FrozenDict(kwargs))
+        try:
+            return cls.__instances__[key]
+        except KeyError:
+            instance = super().__create__(*args, **kwargs)
+            cls.__instances__[key] = instance
+            return instance
+
+
+class Final(Base):
+    """Mark direct subclasses as final; deriving from them raises TypeError."""
+
+    def __init_subclass__(cls, **kwargs):
+        cls.__init_subclass__ = cls.__prohibit_inheritance__
+
+    @classmethod
+    def __prohibit_inheritance__(cls, **kwargs):
+        raise TypeError(f"Cannot inherit from final class {cls}")
+
+
+class Comparable(Base):
+    """Enable quick equality comparisons.
+
+    The subclasses must implement the `__equals__` method that returns a boolean
+    value indicating whether the two instances are equal. This method is called
+    only if the two instances are of the same type and the result is cached for
+    future comparisons.
+
+    Since the class holds a global cache of comparison results, it is important
+    to make sure that the instances are not kept alive longer than necessary.
+    This is done automatically by using weak references for the compared objects.
+    """
+
+    __cache__ = WeakCache()
+
+    def __eq__(self, other) -> bool:
+        try:
+            return self.__cached_equals__(other)
+        except TypeError:
+            return NotImplemented
+
+    @abstractmethod
+    def __equals__(self, other) -> bool:
+        ...
+
+    def __cached_equals__(self, other) -> bool:
+        if self is other:
+            return True
+
+        # type comparison should be cheap
+        if type(self) is not type(other):
+            return False
+
+        # order the pair by id so (a, b) and (b, a) share a single cache entry
+        if id(self) < id(other):
+            key = (self, other)
+        else:
+            key = (other, self)
+
+        try:
+            result = self.__cache__[key]
+        except KeyError:
+            result = self.__equals__(other)
+            self.__cache__[key] = result
+
+        return result
+
+
+class Slotted(Base):
+    """A lightweight alternative to `ibis.common.grounds.Concrete`.
+
+    This class is used to create immutable dataclasses with slots and a hash value
+    precomputed from the keyword values given to `__init__` for quicker lookups.
+    """
+
+    __slots__ = ("__precomputed_hash__",)
+
+    def __init__(self, **kwargs) -> None:
+        for name, value in kwargs.items():
+            object.__setattr__(self, name, value)
+        hashvalue = hash(tuple(kwargs.values()))
+        object.__setattr__(self, "__precomputed_hash__", hashvalue)
+
+    def __eq__(self, other) -> bool:
+        if self is other:
+            return True
+        if type(self) is not type(other):
+            return NotImplemented
+        return all(getattr(self, n) == getattr(other, n) for n in self.__slots__)
+
+    def __hash__(self) -> int:
+        return self.__precomputed_hash__
+
+    def __setattr__(self, name, value) -> None:
+        raise AttributeError("Can't set attributes on an immutable instance")
+
+    def __repr__(self):
+        fields = {k: getattr(self, k) for k in self.__slots__}
+        fieldstring = ", ".join(f"{k}={v!r}" for k, v in fields.items())
+        return f"{self.__class__.__name__}({fieldstring})"
+
+    def __rich_repr__(self):
+        for name in self.__slots__:
+            yield name, getattr(self, name)
diff --git a/ibis/common/grounds.py b/ibis/common/grounds.py
index e26912e11f32..6854b2e259c8 100644
--- a/ibis/common/grounds.py
+++ b/ibis/common/grounds.py
@@ -1,17 +1,14 @@
 from __future__ import annotations
 
 import contextlib
-from abc import ABCMeta, abstractmethod
 from copy import copy
 from typing import (
     Any,
     ClassVar,
-    Mapping,
     Tuple,
     Union,
     get_origin,
 )
-from
weakref import WeakValueDictionary from typing_extensions import Self, dataclass_transform @@ -23,29 +20,19 @@ Signature, attribute, ) -from ibis.common.caching import WeakCache -from ibis.common.collections import FrozenDict +from ibis.common.bases import ( # noqa: F401 + Base, + BaseMeta, + Comparable, + Final, + Immutable, + Singleton, +) +from ibis.common.collections import FrozenDict # noqa: TCH001 from ibis.common.patterns import Pattern from ibis.common.typing import evaluate_annotations -class BaseMeta(ABCMeta): - __slots__ = () - - def __new__(metacls, clsname, bases, dct, **kwargs): - # enforce slot definitions - dct.setdefault("__slots__", ()) - return super().__new__(metacls, clsname, bases, dct, **kwargs) - - def __call__(cls, *args, **kwargs): - return cls.__create__(*args, **kwargs) - - -class Base(metaclass=BaseMeta): - __slots__ = ("__weakref__",) - __create__ = classmethod(type.__call__) # type: ignore - - class AnnotableMeta(BaseMeta): """Metaclass to turn class annotations into a validatable function signature.""" @@ -187,79 +174,6 @@ def copy(self, **overrides: Any) -> Annotable: return this -class Immutable(Base): - def __copy__(self): - return self - - def __deepcopy__(self, memo): - return self - - def __setattr__(self, name: str, _: Any) -> None: - raise AttributeError( - f"Attribute {name!r} cannot be assigned to immutable instance of " - f"type {type(self)}" - ) - - -class Singleton(Base): - __instances__: Mapping[Any, Self] = WeakValueDictionary() - - @classmethod - def __create__(cls, *args, **kwargs): - key = (cls, args, FrozenDict(kwargs)) - try: - return cls.__instances__[key] - except KeyError: - instance = super().__create__(*args, **kwargs) - cls.__instances__[key] = instance - return instance - - -class Final(Base): - def __init_subclass__(cls, **kwargs): - cls.__init_subclass__ = cls.__prohibit_inheritance__ - - @classmethod - def __prohibit_inheritance__(cls, **kwargs): - raise TypeError(f"Cannot inherit from final class 
{cls}") - - -class Comparable(Base): - __cache__ = WeakCache() - - def __eq__(self, other) -> bool: - try: - return self.__cached_equals__(other) - except TypeError: - return NotImplemented - - @abstractmethod - def __equals__(self, other) -> bool: - ... - - def __cached_equals__(self, other) -> bool: - if self is other: - return True - - # type comparison should be cheap - if type(self) is not type(other): - return False - - # reduce space required for commutative operation - if id(self) < id(other): - key = (self, other) - else: - key = (other, self) - - try: - result = self.__cache__[key] - except KeyError: - result = self.__equals__(other) - self.__cache__[key] = result - - return result - - class Concrete(Immutable, Comparable, Annotable): """Opinionated base class for immutable data classes.""" diff --git a/ibis/common/patterns.py b/ibis/common/patterns.py index 55736d58a1fe..6d2217302369 100644 --- a/ibis/common/patterns.py +++ b/ibis/common/patterns.py @@ -8,7 +8,7 @@ from collections.abc import Callable, Hashable, Mapping, Sequence from enum import Enum from inspect import Parameter -from itertools import chain, zip_longest +from itertools import chain from typing import Any as AnyType from typing import ( ForwardRef, @@ -23,6 +23,7 @@ import toolz from typing_extensions import Annotated, GenericMeta, Self, get_args, get_origin +from ibis.common.bases import Singleton, Slotted from ibis.common.collections import RewindableIterator, frozendict from ibis.common.dispatch import lazy_singledispatch from ibis.common.typing import Sentinel, get_bound_typevars, get_type_params @@ -306,45 +307,6 @@ def __rmatmul__(self, name: str) -> Capture: return Capture(name, self) -class _Slotted: - """A lightweight alternative to `ibis.common.grounds.Concrete`. - - This class is used to create immutable dataclasses with slots and a precomputed - hash value for quicker dictionary lookups. 
- """ - - __slots__ = ("__precomputed_hash__",) - - def __init__(self, *args) -> Self: - for name, value in zip_longest(self.__slots__, args): - object.__setattr__(self, name, value) - object.__setattr__(self, "__precomputed_hash__", hash(args)) - - def __eq__(self, other) -> bool: - if self is other: - return True - if type(self) is not type(other): - return NotImplemented - return all( - getattr(self, name) == getattr(other, name) for name in self.__slots__ - ) - - def __hash__(self) -> int: - return self.__precomputed_hash__ - - def __setattr__(self, name, value) -> None: - raise AttributeError("Can't set attributes on an immutable instance") - - def __repr__(self): - fields = {k: getattr(self, k) for k in self.__slots__} - fieldstring = ", ".join(f"{k}={v!r}" for k, v in fields.items()) - return f"{self.__class__.__name__}({fieldstring})" - - def __rich_repr__(self): - for name in self.__slots__: - yield name, getattr(self, name) - - class Builder(Hashable): """A builder is a function that takes a context and returns a new object. @@ -410,7 +372,7 @@ def builder(obj): return Just(obj) -class Variable(_Slotted, Builder): +class Variable(Slotted, Builder): """Retrieve a value from the context. Parameters @@ -421,6 +383,9 @@ class Variable(_Slotted, Builder): __slots__ = ("name",) + def __init__(self, name): + super().__init__(name=name) + def make(self, context): return context[self] @@ -431,7 +396,7 @@ def __getitem__(self, name): return Call(operator.itemgetter(name), self) -class Just(_Slotted, Builder): +class Just(Slotted, Builder): """Construct exactly the given value. Parameters @@ -442,11 +407,15 @@ class Just(_Slotted, Builder): __slots__ = ("value",) + def __init__(self, value): + assert not isinstance(value, (Pattern, Builder)) + super().__init__(value=value) + def make(self, context): return self.value -class Factory(_Slotted, Builder): +class Factory(Slotted, Builder): """Construct a value by calling a function. 
The function is called with two positional arguments: @@ -463,12 +432,16 @@ class Factory(_Slotted, Builder): __slots__ = ("func",) + def __init__(self, func): + assert callable(func) + super().__init__(func=func) + def make(self, context): value = context[_] return self.func(value, context) -class Call(_Slotted, Builder): +class Call(Slotted, Builder): """Pattern that calls a function with the given arguments. Both positional and keyword arguments are coerced into patterns. @@ -486,9 +459,10 @@ class Call(_Slotted, Builder): __slots__ = ("func", "args", "kwargs") def __init__(self, func, *args, **kwargs): + assert callable(func) args = tuple(map(builder, args)) kwargs = frozendict({k: builder(v) for k, v in kwargs.items()}) - super().__init__(func, args, kwargs) + super().__init__(func=func, args=args, kwargs=kwargs) def make(self, context): args = tuple(arg.make(context) for arg in self.args) @@ -529,29 +503,21 @@ def namespace(cls, module) -> Namespace: _ = Variable("_") -class Matcher(_Slotted, Pattern): - __slots__ = () - - -class Always(Matcher): +class Always(Slotted, Singleton, Pattern): """Pattern that matches everything.""" - __slots__ = () - def match(self, value, context): return value -class Never(Matcher): +class Never(Slotted, Singleton, Pattern): """Pattern that matches nothing.""" - __slots__ = () - def match(self, value, context): return NoMatch -class Is(Matcher): +class Is(Slotted, Pattern): """Pattern that matches a value against a reference value. Parameters @@ -569,11 +535,9 @@ def match(self, value, context): return NoMatch -class Any(Matcher): +class Any(Slotted, Singleton, Pattern): """Pattern that accepts any value, basically a no-op.""" - __slots__ = () - def match(self, value, context): return value @@ -581,7 +545,7 @@ def match(self, value, context): _any = Any() -class Capture(Matcher): +class Capture(Slotted, Pattern): """Pattern that captures a value in the context. 
Parameters @@ -595,7 +559,7 @@ class Capture(Matcher): __slots__ = ("key", "pattern") def __init__(self, key, pat=_any): - super().__init__(key, pattern(pat)) + super().__init__(key=key, pattern=pattern(pat)) def match(self, value, context): value = self.pattern.match(value, context) @@ -605,7 +569,7 @@ def match(self, value, context): return value -class Replace(Matcher): +class Replace(Slotted, Pattern): """Pattern that replaces a value with the output of another pattern. Parameters @@ -616,13 +580,13 @@ class Replace(Matcher): The pattern to use as a replacement. """ - __slots__ = ("matcher", "builder") + __slots__ = ("pattern", "builder") def __init__(self, matcher, replacer): - super().__init__(pattern(matcher), builder(replacer)) + super().__init__(pattern=pattern(matcher), builder=builder(replacer)) def match(self, value, context): - value = self.matcher.match(value, context) + value = self.pattern.match(value, context) if value is NoMatch: return NoMatch # use the `_` reserved variable to record the value being replaced @@ -631,7 +595,7 @@ def match(self, value, context): return self.builder.make(context) -class Check(Matcher): +class Check(Slotted, Pattern): """Pattern that checks a value against a predicate. Parameters @@ -642,6 +606,10 @@ class Check(Matcher): __slots__ = ("predicate",) + def __init__(self, predicate): + assert callable(predicate) + super().__init__(predicate=predicate) + def match(self, value, context): if self.predicate(value): return value @@ -649,7 +617,7 @@ def match(self, value, context): return NoMatch -class Function(Matcher): +class Function(Slotted, Pattern): """Pattern that applies a function to the value. 
Parameters @@ -660,6 +628,10 @@ class Function(Matcher): __slots__ = ("func",) + def __init__(self, func): + assert callable(func) + super().__init__(func=func) + def match(self, value, context): return self.func(value, context) @@ -702,7 +674,7 @@ def __getattr__(self, name: str) -> Pattern: return self.pattern(getattr(self.module, name)) -class Apply(Matcher): +class Apply(Slotted, Pattern): """Pattern that applies a function to the value. The function must accept a single argument. @@ -722,6 +694,10 @@ class Apply(Matcher): __slots__ = ("func",) + def __init__(self, func): + assert callable(func) + super().__init__(func=func) + def match(self, value, context): return self.func(value) @@ -730,7 +706,7 @@ def __call__(self, *args, **kwargs): return Call(self.func, *args, **kwargs) -class EqualTo(Matcher): +class EqualTo(Slotted, Pattern): """Pattern that checks a value equals to the given value. Parameters @@ -741,6 +717,9 @@ class EqualTo(Matcher): __slots__ = ("value",) + def __init__(self, value): + super().__init__(value=value) + def match(self, value, context): if value == self.value: return value @@ -748,7 +727,7 @@ def match(self, value, context): return NoMatch -class Option(Matcher): +class Option(Slotted, Pattern): """Pattern that matches `None` or a value that passes the inner validator. 
Parameters @@ -759,8 +738,8 @@ class Option(Matcher): __slots__ = ("pattern", "default") - def __init__(self, pattern, default=None): - super().__init__(pattern, default) + def __init__(self, pat, default=None): + super().__init__(pattern=pattern(pat), default=default) def match(self, value, context): if value is None: @@ -772,11 +751,14 @@ def match(self, value, context): return self.pattern.match(value, context) -class TypeOf(Matcher): +class TypeOf(Slotted, Pattern): """Pattern that matches a value that is of a given type.""" __slots__ = ("type",) + def __init__(self, typ): + super().__init__(type=typ) + def match(self, value, context): if type(value) is self.type: return value @@ -784,7 +766,7 @@ def match(self, value, context): return NoMatch -class SubclassOf(Matcher): +class SubclassOf(Slotted, Pattern): """Pattern that matches a value that is a subclass of a given type. Parameters @@ -795,6 +777,9 @@ class SubclassOf(Matcher): __slots__ = ("type",) + def __init__(self, typ): + super().__init__(type=typ) + def match(self, value, context): if issubclass(value, self.type): return value @@ -802,7 +787,7 @@ def match(self, value, context): return NoMatch -class InstanceOf(Matcher): +class InstanceOf(Slotted, Singleton, Pattern): """Pattern that matches a value that is an instance of a given type. Parameters @@ -813,6 +798,9 @@ class InstanceOf(Matcher): __slots__ = ("type",) + def __init__(self, typ): + super().__init__(type=typ) + def match(self, value, context): if isinstance(value, self.type): return value @@ -823,7 +811,7 @@ def __call__(self, *args, **kwargs): return Object(self.type, *args, **kwargs) -class GenericInstanceOf(Matcher): +class GenericInstanceOf(Slotted, Pattern): """Pattern that matches a value that is an instance of a given generic type. 
Parameters @@ -851,35 +839,35 @@ class GenericInstanceOf(Matcher): >>> assert p.match(MyNumber(1), {}) is NoMatch """ - __slots__ = ("origin", "field_patterns") + __slots__ = ("origin", "fields") def __init__(self, typ): origin = get_origin(typ) typevars = get_bound_typevars(typ) - field_patterns = {} + fields = {} for var, (attr, type_) in typevars.items(): if not var.__covariant__: raise TypeError( f"Typevar {var} is not covariant, cannot use it in a GenericInstanceOf" ) - field_patterns[attr] = Pattern.from_typehint(type_, allow_coercion=False) + fields[attr] = Pattern.from_typehint(type_, allow_coercion=False) - super().__init__(origin, frozendict(field_patterns)) + super().__init__(origin=origin, fields=frozendict(fields)) def match(self, value, context): if not isinstance(value, self.origin): return NoMatch - for field, pattern in self.field_patterns.items(): - attr = getattr(value, field) + for name, pattern in self.fields.items(): + attr = getattr(value, name) if pattern.match(attr, context) is NoMatch: return NoMatch return value -class LazyInstanceOf(Matcher): +class LazyInstanceOf(Slotted, Pattern): """A version of `InstanceOf` that accepts qualnames instead of imported classes. Useful for delaying imports. @@ -896,7 +884,7 @@ def __init__(self, types): types = promote_tuple(types) check = lazy_singledispatch(lambda x: False) check.register(types, lambda x: True) - super().__init__(promote_tuple(types), check) + super().__init__(types=types, check=check) def match(self, value, context): if self.check(value): @@ -906,7 +894,7 @@ def match(self, value, context): # TODO(kszucs): to support As[int] or CoercedTo[int] syntax -class CoercedTo(Matcher): +class CoercedTo(Slotted, Pattern): """Force a value to have a particular Python type. 
If a Coercible subclass is passed, the `__coerce__` method will be used to @@ -927,6 +915,10 @@ def __new__(cls, target): else: return Apply(target) + def __init__(self, target): + assert isinstance(target, type) + super().__init__(target=target) + def match(self, value, context): try: value = self.target.__coerce__(value) @@ -945,7 +937,7 @@ def __repr__(self): As = CoercedTo -class GenericCoercedTo(Matcher): +class GenericCoercedTo(Slotted, Pattern): """Force a value to have a particular generic Python type. Parameters @@ -987,12 +979,10 @@ class GenericCoercedTo(Matcher): __slots__ = ("origin", "params", "checker") def __init__(self, target): - # TODO(kszucs): when constructing the checker we shouldn't allow - # coercions, only type checks origin = get_origin(target) checker = GenericInstanceOf(target) params = frozendict(get_type_params(target)) - super().__init__(origin, params, checker) + super().__init__(origin=origin, params=params, checker=checker) def match(self, value, context): try: @@ -1006,7 +996,7 @@ def match(self, value, context): return value -class Not(Matcher): +class Not(Slotted, Pattern): """Pattern that matches a value that does not match a given pattern. Parameters @@ -1018,7 +1008,7 @@ class Not(Matcher): __slots__ = ("pattern",) def __init__(self, inner): - super().__init__(pattern(inner)) + super().__init__(pattern=pattern(inner)) def match(self, value, context): if self.pattern.match(value, context) is NoMatch: @@ -1027,7 +1017,7 @@ def match(self, value, context): return NoMatch -class AnyOf(Matcher): +class AnyOf(Slotted, Pattern): """Pattern that if any of the given patterns match. 
Parameters @@ -1039,8 +1029,9 @@ class AnyOf(Matcher): __slots__ = ("patterns",) - def __init__(self, *patterns): - super().__init__(patterns) + def __init__(self, *pats): + patterns = tuple(map(pattern, pats)) + super().__init__(patterns=patterns) def match(self, value, context): for pattern in self.patterns: @@ -1050,7 +1041,7 @@ def match(self, value, context): return NoMatch -class AllOf(Matcher): +class AllOf(Slotted, Pattern): """Pattern that matches if all of the given patterns match. Parameters @@ -1063,8 +1054,9 @@ class AllOf(Matcher): __slots__ = ("patterns",) - def __init__(self, *patterns): - super().__init__(patterns) + def __init__(self, *pats): + patterns = tuple(map(pattern, pats)) + super().__init__(patterns=patterns) def match(self, value, context): for pattern in self.patterns: @@ -1074,7 +1066,7 @@ def match(self, value, context): return value -class Length(Matcher): +class Length(Slotted, Pattern): """Pattern that matches if the length of a value is within a given range. Parameters @@ -1101,7 +1093,7 @@ def __init__( raise ValueError("Can't specify both exactly and at_least/at_most") at_least = exactly at_most = exactly - super().__init__(at_least, at_most) + super().__init__(at_least=at_least, at_most=at_most) def match(self, value, context): length = len(value) @@ -1112,7 +1104,7 @@ def match(self, value, context): return value -class Contains(Matcher): +class Contains(Slotted, Pattern): """Pattern that matches if a value contains a given value. Parameters @@ -1123,6 +1115,9 @@ class Contains(Matcher): __slots__ = ("needle",) + def __init__(self, needle): + super().__init__(needle=needle) + def match(self, value, context): if self.needle in value: return value @@ -1130,7 +1125,7 @@ def match(self, value, context): return NoMatch -class IsIn(Matcher): +class IsIn(Slotted, Pattern): """Pattern that matches if a value is in a given set. 
Parameters @@ -1142,7 +1137,7 @@ class IsIn(Matcher): __slots__ = ("haystack",) def __init__(self, haystack): - super().__init__(frozenset(haystack)) + super().__init__(haystack=frozenset(haystack)) def match(self, value, context): if value in self.haystack: @@ -1154,7 +1149,7 @@ def match(self, value, context): In = IsIn -class SequenceOf(Matcher): +class SequenceOf(Slotted, Pattern): """Pattern that matches if all of the items in a sequence match a given pattern. Parameters @@ -1171,7 +1166,7 @@ class SequenceOf(Matcher): The maximum length of the sequence. """ - __slots__ = ("item_pattern", "type_pattern", "length_pattern") + __slots__ = ("item", "type", "length") def __init__( self, @@ -1181,10 +1176,10 @@ def __init__( at_least: Optional[int] = None, at_most: Optional[int] = None, ): - item_pattern = pattern(item) - type_pattern = CoercedTo(type) - length_pattern = Length(at_least=at_least, at_most=at_most) - super().__init__(item_pattern, type_pattern, length_pattern) + item = pattern(item) + type = CoercedTo(type) + length = Length(at_least=at_least, at_most=at_most) + super().__init__(item=item, type=type, length=length) def match(self, values, context): if not is_iterable(values): @@ -1192,19 +1187,19 @@ def match(self, values, context): result = [] for value in values: - value = self.item_pattern.match(value, context) + value = self.item.match(value, context) if value is NoMatch: return NoMatch result.append(value) - result = self.type_pattern.match(result, context) + result = self.type.match(result, context) if result is NoMatch: return NoMatch - return self.length_pattern.match(result, context) + return self.length.match(result, context) -class TupleOf(Matcher): +class TupleOf(Slotted, Pattern): """Pattern that matches if the respective items in a tuple match the given patterns. Parameters @@ -1213,7 +1208,7 @@ class TupleOf(Matcher): The patterns to match the respective items in the tuple. 
""" - __slots__ = ("field_patterns",) + __slots__ = ("fields",) def __new__(cls, fields): if isinstance(fields, tuple): @@ -1221,15 +1216,19 @@ def __new__(cls, fields): else: return SequenceOf(fields, tuple) + def __init__(self, fields): + fields = tuple(map(pattern, fields)) + super().__init__(fields=fields) + def match(self, values, context): if not is_iterable(values): return NoMatch - if len(values) != len(self.field_patterns): + if len(values) != len(self.fields): return NoMatch result = [] - for pattern, value in zip(self.field_patterns, values): + for pattern, value in zip(self.fields, values): value = pattern.match(value, context) if value is NoMatch: return NoMatch @@ -1238,7 +1237,7 @@ def match(self, values, context): return tuple(result) -class MappingOf(Matcher): +class MappingOf(Slotted, Pattern): """Pattern that matches if all of the keys and values match the given patterns. Parameters @@ -1251,10 +1250,10 @@ class MappingOf(Matcher): The type to coerce the mapping to. Defaults to dict. 
""" - __slots__ = ("key_pattern", "value_pattern", "type_pattern") + __slots__ = ("key", "value", "type") def __init__(self, key: Pattern, value: Pattern, type: type = dict): - super().__init__(pattern(key), pattern(value), CoercedTo(type)) + super().__init__(key=pattern(key), value=pattern(value), type=CoercedTo(type)) def match(self, value, context): if not isinstance(value, Mapping): @@ -1262,27 +1261,28 @@ def match(self, value, context): result = {} for k, v in value.items(): - if (k := self.key_pattern.match(k, context)) is NoMatch: + if (k := self.key.match(k, context)) is NoMatch: return NoMatch - if (v := self.value_pattern.match(v, context)) is NoMatch: + if (v := self.value.match(v, context)) is NoMatch: return NoMatch result[k] = v - result = self.type_pattern.match(result, context) + result = self.type.match(result, context) if result is NoMatch: return NoMatch return result -class Attrs(Matcher): - __slots__ = ("patterns",) +class Attrs(Slotted, Pattern): + __slots__ = ("fields",) - def __init__(self, **patterns): - super().__init__(frozendict(toolz.valmap(pattern, patterns))) + def __init__(self, **fields): + fields = frozendict(toolz.valmap(pattern, fields)) + super().__init__(fields=fields) def match(self, value, context): - for attr, pattern in self.patterns.items(): + for attr, pattern in self.fields.items(): if not hasattr(value, attr): return NoMatch @@ -1293,7 +1293,7 @@ def match(self, value, context): return value -class Object(Matcher): +class Object(Slotted, Pattern): """Pattern that matches if the object has the given attributes and they match the given patterns. The type must conform the structural pattern matching protocol, e.g. 
it must have a @@ -1321,7 +1321,7 @@ def __init__(self, type, *args, **kwargs): type = pattern(type) args = tuple(map(pattern, args)) kwargs = frozendict(toolz.valmap(pattern, kwargs)) - super().__init__(type, args, kwargs) + super().__init__(type=type, args=args, kwargs=kwargs) def match(self, value, context): if self.type.match(value, context) is NoMatch: @@ -1356,11 +1356,11 @@ def namespace(cls, module): return Namespace(InstanceOf, module) -class Node(Matcher): +class Node(Slotted, Pattern): __slots__ = ("type", "each_arg") def __init__(self, type, each_arg): - super().__init__(pattern(type), pattern(each_arg)) + super().__init__(type=pattern(type), each_arg=pattern(each_arg)) def match(self, value, context): if self.type.match(value, context) is NoMatch: @@ -1382,11 +1382,11 @@ def match(self, value, context): return value -class CallableWith(Matcher): - __slots__ = ("arg_patterns", "return_pattern") +class CallableWith(Slotted, Pattern): + __slots__ = ("args", "return_") - def __init__(self, args, return_=None): - super().__init__(tuple(args), return_ or _any) + def __init__(self, args, return_=_any): + super().__init__(args=tuple(args), return_=return_) def match(self, value, context): from ibis.common.annotations import annotated @@ -1394,7 +1394,7 @@ def match(self, value, context): if not callable(value): return NoMatch - fn = annotated(self.arg_patterns, self.return_pattern, value) + fn = annotated(self.args, self.return_, value) has_varargs = False positional, keyword_only = [], [] @@ -1410,17 +1410,17 @@ def match(self, value, context): raise MatchError( "Callable has mandatory keyword-only arguments which cannot be specified" ) - elif len(positional) > len(self.arg_patterns): + elif len(positional) > len(self.args): # Callable has more positional arguments than expected") return NoMatch - elif len(positional) < len(self.arg_patterns) and not has_varargs: + elif len(positional) < len(self.args) and not has_varargs: # Callable has less positional 
arguments than expected") return NoMatch else: return fn -class PatternSequence(Matcher): +class PatternSequence(Slotted, Pattern): # TODO(kszucs): add a length optimization to not even try to match if the # length of the sequence is lower than the length of the pattern sequence @@ -1432,11 +1432,7 @@ def __init__(self, patterns): ] following_patterns = chain(current_patterns[1:], [Not(_any)]) pattern_window = tuple(zip(current_patterns, following_patterns)) - super().__init__(pattern_window) - - @property - def first_pattern(self): - return self.pattern_window[0][0] + super().__init__(pattern_window=pattern_window) def match(self, value, context): it = RewindableIterator(value) @@ -1460,9 +1456,10 @@ def match(self, value, context): if isinstance(current, (SequenceOf, PatternSequence)): if isinstance(following, SequenceOf): - following = following.item_pattern + following = following.item elif isinstance(following, PatternSequence): - following = following.first_pattern + # first pattern to match from the pattern window + following = following.pattern_window[0][0] matches = [] while True: @@ -1499,30 +1496,30 @@ def match(self, value, context): return result -class PatternMapping(Matcher): - __slots__ = ("keys_pattern", "values_pattern") +class PatternMapping(Slotted, Pattern): + __slots__ = ("keys", "values") def __init__(self, patterns): - keys_pattern = PatternSequence(list(map(pattern, patterns.keys()))) - values_pattern = PatternSequence(list(map(pattern, patterns.values()))) - super().__init__(keys_pattern, values_pattern) + keys = PatternSequence(list(map(pattern, patterns.keys()))) + values = PatternSequence(list(map(pattern, patterns.values()))) + super().__init__(keys=keys, values=values) def match(self, value, context): if not isinstance(value, Mapping): return NoMatch keys = value.keys() - if (keys := self.keys_pattern.match(keys, context)) is NoMatch: + if (keys := self.keys.match(keys, context)) is NoMatch: return NoMatch values = value.values() - if 
(values := self.values_pattern.match(values, context)) is NoMatch: + if (values := self.values.match(values, context)) is NoMatch: return NoMatch return dict(zip(keys, values)) -class Between(Matcher): +class Between(Slotted, Pattern): """Match a value between two bounds. Parameters @@ -1536,7 +1533,7 @@ class Between(Matcher): __slots__ = ("lower", "upper") def __init__(self, lower: float = -math.inf, upper: float = math.inf): - super().__init__(lower, upper) + super().__init__(lower=lower, upper=upper) def match(self, value, context): if self.lower <= value <= self.upper: @@ -1644,16 +1641,16 @@ def match( return NoMatch if result is NoMatch else result -class Topmost(Matcher): +class Topmost(Slotted, Pattern): """Traverse the value tree topmost first and match the first value that matches.""" - __slots__ = ("searcher", "filter") + __slots__ = ("pattern", "filter") def __init__(self, searcher, filter=None): - super().__init__(pattern(searcher), filter) + super().__init__(pattern=pattern(searcher), filter=filter) def match(self, value, context): - result = self.searcher.match(value, context) + result = self.pattern.match(value, context) if result is not NoMatch: return result @@ -1665,14 +1662,14 @@ def match(self, value, context): return NoMatch -class Innermost(Matcher): +class Innermost(Slotted, Pattern): # matches items in the innermost layer first, but all matches belong to the same layer """Traverse the value tree innermost first and match the first value that matches.""" - __slots__ = ("searcher", "filter") + __slots__ = ("pattern", "filter") def __init__(self, searcher, filter=None): - super().__init__(pattern(searcher), filter) + super().__init__(pattern=pattern(searcher), filter=filter) def match(self, value, context): for child in value.__children__(self.filter): @@ -1680,7 +1677,7 @@ def match(self, value, context): if result is not NoMatch: return result - return self.searcher.match(value, context) + return self.pattern.match(value, context) IsTruish 
= Check(lambda x: bool(x)) diff --git a/ibis/common/tests/test_bases.py b/ibis/common/tests/test_bases.py new file mode 100644 index 000000000000..dc567b5f7d3c --- /dev/null +++ b/ibis/common/tests/test_bases.py @@ -0,0 +1,219 @@ +from __future__ import annotations + +import copy +import weakref + +import pytest + +from ibis.common.caching import WeakCache +from ibis.common.collections import frozendict +from ibis.common.grounds import Base, Comparable, Final, Immutable, Singleton + + +def test_bases_are_based_on_base(): + assert issubclass(Comparable, Base) + assert issubclass(Final, Base) + assert issubclass(Immutable, Base) + assert issubclass(Singleton, Base) + + +def test_immutable(): + class Foo(Immutable): + __slots__ = ("a", "b") + + def __init__(self, a, b): + object.__setattr__(self, "a", a) + object.__setattr__(self, "b", b) + + foo = Foo(1, 2) + assert foo.a == 1 + assert foo.b == 2 + with pytest.raises(AttributeError): + foo.a = 2 + with pytest.raises(AttributeError): + foo.b = 3 + + assert copy.copy(foo) is foo + assert copy.deepcopy(foo) is foo + + +class Node(Comparable): + # override the default cache object + __cache__ = WeakCache() + __slots__ = ("name",) + num_equal_calls = 0 + + def __init__(self, name): + self.name = name + + def __str__(self): + return self.name + + def __repr__(self): + return f"Node(name={self.name})" + + def __equals__(self, other): + Node.num_equal_calls += 1 + return self.name == other.name + + +@pytest.fixture +def cache(): + Node.num_equal_calls = 0 + cache = Node.__cache__ + yield cache + assert not cache + + +def pair(a, b): + # for same ordering with comparable + if id(a) < id(b): + return (a, b) + else: + return (b, a) + + +def test_comparable_basic(cache): + a = Node(name="a") + b = Node(name="a") + c = Node(name="a") + assert a == b + assert a == c + del a + del b + del c + + +def test_comparable_caching(cache): + a = Node(name="a") + b = Node(name="b") + c = Node(name="c") + d = Node(name="d") + e = 
Node(name="e") + + cache[pair(a, b)] = True + cache[pair(a, c)] = False + cache[pair(c, d)] = True + cache[pair(b, d)] = False + assert len(cache) == 4 + + assert a == b + assert a != c + assert c == d + assert b != d + assert Node.num_equal_calls == 0 + + # no cache hit + assert pair(a, e) not in cache + assert a != e + assert Node.num_equal_calls == 1 + assert len(cache) == 5 + + # run only once + assert e != a + assert Node.num_equal_calls == 1 + assert pair(a, e) in cache + + +def test_comparable_garbage_collection(cache): + a = Node(name="a") + b = Node(name="b") + c = Node(name="c") + d = Node(name="d") + + cache[pair(a, b)] = True + cache[pair(a, c)] = False + cache[pair(c, d)] = True + cache[pair(b, d)] = False + + assert weakref.getweakrefcount(a) == 2 + del c + assert weakref.getweakrefcount(a) == 1 + del b + assert weakref.getweakrefcount(a) == 0 + + +def test_comparable_cache_reuse(cache): + nodes = [ + Node(name="a"), + Node(name="b"), + Node(name="c"), + Node(name="d"), + Node(name="e"), + ] + + expected = 0 + for a, b in zip(nodes, nodes): + a == a # noqa: B015 + a == b # noqa: B015 + b == a # noqa: B015 + if a != b: + expected += 1 + assert Node.num_equal_calls == expected + + assert len(cache) == expected + + # check that cache is evicted once nodes get collected + del nodes + assert len(cache) == 0 + + a = Node(name="a") + b = Node(name="a") + assert a == b + + +class OneAndOnly(Singleton): + __instances__ = weakref.WeakValueDictionary() + + +class DataType(Singleton): + __slots__ = ("nullable",) + __instances__ = weakref.WeakValueDictionary() + + def __init__(self, nullable=True): + self.nullable = nullable + + +def test_singleton_basics(): + one = OneAndOnly() + only = OneAndOnly() + assert one is only + + assert len(OneAndOnly.__instances__) == 1 + key = (OneAndOnly, (), frozendict()) + assert OneAndOnly.__instances__[key] is one + + +def test_singleton_lifetime() -> None: + one = OneAndOnly() + assert len(OneAndOnly.__instances__) == 1 + + del 
one + assert len(OneAndOnly.__instances__) == 0 + + +def test_singleton_with_argument() -> None: + dt1 = DataType(nullable=True) + dt2 = DataType(nullable=False) + dt3 = DataType(nullable=True) + + assert dt1 is dt3 + assert dt1 is not dt2 + assert len(DataType.__instances__) == 2 + + del dt3 + assert len(DataType.__instances__) == 2 + del dt1 + assert len(DataType.__instances__) == 1 + del dt2 + assert len(DataType.__instances__) == 0 + + +def test_final(): + class A(Final): + pass + + with pytest.raises(TypeError, match="Cannot inherit from final class .*A.*"): + + class B(A): + pass diff --git a/ibis/common/tests/test_grounds.py b/ibis/common/tests/test_grounds.py index 613e02616bdf..4daf02281fe2 100644 --- a/ibis/common/tests/test_grounds.py +++ b/ibis/common/tests/test_grounds.py @@ -18,14 +18,11 @@ varargs, varkwargs, ) -from ibis.common.caching import WeakCache -from ibis.common.collections import frozendict from ibis.common.grounds import ( Annotable, Base, Comparable, Concrete, - Final, Immutable, Singleton, ) @@ -217,26 +214,6 @@ class MyValue(Annotable, Generic[J, F]): numeric: N -def test_immutable(): - class Foo(Immutable): - __slots__ = ("a", "b") - - def __init__(self, a, b): - object.__setattr__(self, "a", a) - object.__setattr__(self, "b", b) - - foo = Foo(1, 2) - assert foo.a == 1 - assert foo.b == 2 - with pytest.raises(AttributeError): - foo.a = 2 - with pytest.raises(AttributeError): - foo.b = 3 - - assert copy.copy(foo) is foo - assert copy.deepcopy(foo) is foo - - def test_annotable(): class Between(BetweenSimple): pass @@ -922,178 +899,6 @@ def shape(self): assert "shape" in v.__slots__ -class Node(Comparable): - # override the default cache object - __cache__ = WeakCache() - __slots__ = ("name",) - num_equal_calls = 0 - - def __init__(self, name): - self.name = name - - def __str__(self): - return self.name - - def __repr__(self): - return f"Node(name={self.name})" - - def __equals__(self, other): - Node.num_equal_calls += 1 - return 
self.name == other.name - - -@pytest.fixture -def cache(): - Node.num_equal_calls = 0 - cache = Node.__cache__ - yield cache - assert not cache - - -def pair(a, b): - # for same ordering with comparable - if id(a) < id(b): - return (a, b) - else: - return (b, a) - - -def test_comparable_basic(cache): - a = Node(name="a") - b = Node(name="a") - c = Node(name="a") - assert a == b - assert a == c - del a - del b - del c - - -def test_comparable_caching(cache): - a = Node(name="a") - b = Node(name="b") - c = Node(name="c") - d = Node(name="d") - e = Node(name="e") - - cache[pair(a, b)] = True - cache[pair(a, c)] = False - cache[pair(c, d)] = True - cache[pair(b, d)] = False - assert len(cache) == 4 - - assert a == b - assert a != c - assert c == d - assert b != d - assert Node.num_equal_calls == 0 - - # no cache hit - assert pair(a, e) not in cache - assert a != e - assert Node.num_equal_calls == 1 - assert len(cache) == 5 - - # run only once - assert e != a - assert Node.num_equal_calls == 1 - assert pair(a, e) in cache - - -def test_comparable_garbage_collection(cache): - a = Node(name="a") - b = Node(name="b") - c = Node(name="c") - d = Node(name="d") - - cache[pair(a, b)] = True - cache[pair(a, c)] = False - cache[pair(c, d)] = True - cache[pair(b, d)] = False - - assert weakref.getweakrefcount(a) == 2 - del c - assert weakref.getweakrefcount(a) == 1 - del b - assert weakref.getweakrefcount(a) == 0 - - -def test_comparable_cache_reuse(cache): - nodes = [ - Node(name="a"), - Node(name="b"), - Node(name="c"), - Node(name="d"), - Node(name="e"), - ] - - expected = 0 - for a, b in zip(nodes, nodes): - a == a # noqa: B015 - a == b # noqa: B015 - b == a # noqa: B015 - if a != b: - expected += 1 - assert Node.num_equal_calls == expected - - assert len(cache) == expected - - # check that cache is evicted once nodes get collected - del nodes - assert len(cache) == 0 - - a = Node(name="a") - b = Node(name="a") - assert a == b - - -class OneAndOnly(Singleton): - __instances__ 
= weakref.WeakValueDictionary() - - -class DataType(Singleton): - __slots__ = ("nullable",) - __instances__ = weakref.WeakValueDictionary() - - def __init__(self, nullable=True): - self.nullable = nullable - - -def test_singleton_basics(): - one = OneAndOnly() - only = OneAndOnly() - assert one is only - - assert len(OneAndOnly.__instances__) == 1 - key = (OneAndOnly, (), frozendict()) - assert OneAndOnly.__instances__[key] is one - - -def test_singleton_lifetime() -> None: - one = OneAndOnly() - assert len(OneAndOnly.__instances__) == 1 - - del one - assert len(OneAndOnly.__instances__) == 0 - - -def test_singleton_with_argument() -> None: - dt1 = DataType(nullable=True) - dt2 = DataType(nullable=False) - dt3 = DataType(nullable=True) - - assert dt1 is dt3 - assert dt1 is not dt2 - assert len(DataType.__instances__) == 2 - - del dt3 - assert len(DataType.__instances__) == 2 - del dt1 - assert len(DataType.__instances__) == 1 - del dt2 - assert len(DataType.__instances__) == 0 - - def test_composition_of_annotable_and_singleton() -> None: class AnnSing(Annotable, Singleton): value = CoercedTo(int) @@ -1225,13 +1030,3 @@ class Example(Annotable): assert Example(None).value is None assert Example(1).value == 1 assert isinstance(Example(1).value, MyInt) - - -def test_final(): - class A(Final): - pass - - with pytest.raises(TypeError, match="Cannot inherit from final class .*A.*"): - - class B(A): - pass diff --git a/ibis/common/tests/test_patterns.py b/ibis/common/tests/test_patterns.py index 54acd4cd9b09..3ff50e043455 100644 --- a/ibis/common/tests/test_patterns.py +++ b/ibis/common/tests/test_patterns.py @@ -109,6 +109,12 @@ def __eq__(self, other): z = Variable("z") +def test_immutability_of_patterns(): + p = InstanceOf(int) + with pytest.raises(AttributeError): + p.types = [str] + + def test_always(): p = Always() assert p.match(1, context={}) == 1 @@ -237,7 +243,7 @@ class Box(Generic[T]): p = Pattern.from_typehint(Box[MyString]) assert isinstance(p, 
GenericInstanceOf) assert p.origin == Box - assert p.field_patterns == {"value": InstanceOf(MyString)} + assert p.fields == {"value": InstanceOf(MyString)} def test_coerced_to(): diff --git a/ibis/expr/rules.py b/ibis/expr/rules.py index 3612852065c9..508e0af7f9ed 100644 --- a/ibis/expr/rules.py +++ b/ibis/expr/rules.py @@ -2,6 +2,7 @@ import operator from itertools import product, starmap +from typing import Optional from public import public @@ -9,7 +10,8 @@ import ibis.expr.operations as ops from ibis import util from ibis.common.annotations import attribute -from ibis.common.patterns import CoercionError, Matcher, NoMatch +from ibis.common.grounds import Concrete +from ibis.common.patterns import CoercionError, NoMatch, Pattern from ibis.common.temporal import IntervalUnit @@ -174,7 +176,7 @@ def _arg_type_error_format(op): return f"{op.name}:{op.dtype}" -class ValueOf(Matcher): +class ValueOf(Concrete, Pattern): """Match a value of a specific type **instance**. This is different from the Value[T] annotations which construct @@ -187,11 +189,7 @@ class ValueOf(Matcher): The datatype the constructed Value instance should conform to. """ - __slots__ = ("dtype",) - - def __init__(self, dtype=None): - dtype = None if dtype is None else dt.dtype(dtype) - super().__init__(dtype) + dtype: Optional[dt.DataType] = None def match(self, value, context): try: