Skip to content

Commit

Permalink
refactor(common): turn annotations into slotted classes
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs authored and cpcloud committed Aug 20, 2023
1 parent 569aa12 commit 0770e92
Show file tree
Hide file tree
Showing 24 changed files with 265 additions and 227 deletions.
169 changes: 66 additions & 103 deletions ibis/common/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,19 @@

import functools
import inspect
import types
from typing import Any as AnyType
from typing import Callable

from ibis.common.bases import Immutable, Slotted
from ibis.common.patterns import (
Any,
FrozenDictOf,
Function,
NoMatch,
Option,
Pattern,
TupleOf,
)
from ibis.common.patterns import pattern as ensure_pattern
from ibis.common.typing import get_type_hints

EMPTY = inspect.Parameter.empty # marker for missing argument
Expand All @@ -28,58 +29,21 @@ class ValidationError(Exception):
...


class Annotation:
class Annotation(Slotted, Immutable):
"""Base class for all annotations.
Annotations are used to mark fields in a class and to validate them.
Parameters
----------
pattern : Pattern, default noop
Pattern to validate the field.
default : Any, default EMPTY
Default value of the field.
typehint : type, default EMPTY
Type of the field, not used for validation.
"""

__slots__ = ("_pattern", "_default", "_typehint")
_pattern: Pattern | Callable | None
_default: AnyType
_typehint: AnyType

def __init__(self, pattern=None, default=EMPTY, typehint=EMPTY):
if pattern is None or isinstance(pattern, Pattern):
pass
elif callable(pattern):
pattern = Function(pattern)
else:
raise TypeError(f"Unsupported pattern {pattern!r}")
self._pattern = pattern
self._default = default
self._typehint = typehint

def __eq__(self, other):
return (
type(self) is type(other)
and self._pattern == other._pattern
and self._default == other._default
and self._typehint == other._typehint
)

def __repr__(self):
return (
f"{self.__class__.__name__}(pattern={self._pattern!r}, "
f"default={self._default!r}, typehint={self._typehint!r})"
)
__slots__ = ()

def validate(self, arg, context=None):
if self._pattern is None:
if self.pattern is None:
return arg

result = self._pattern.match(arg, context)
result = self.pattern.match(arg, context)
if result is NoMatch:
raise ValidationError(f"{arg!r} doesn't match {self._pattern!r}")
raise ValidationError(f"{arg!r} doesn't match {self.pattern!r}")

return result

Expand All @@ -98,10 +62,14 @@ class Attribute(Annotation):
Callable to compute the default value of the field.
"""

@classmethod
def default(self, fn):
"""Annotation to mark a field with a default value computed by a callable."""
return Attribute(default=fn)
__slots__ = ("pattern", "default")
pattern: Pattern
default: AnyType

def __init__(self, pattern: Pattern | None = None, default: AnyType = EMPTY):
if pattern is not None:
pattern = ensure_pattern(pattern)
super().__init__(pattern=pattern, default=default)

def initialize(self, this: AnyType) -> AnyType:
"""Compute the default value of the field.
Expand All @@ -115,14 +83,18 @@ def initialize(self, this: AnyType) -> AnyType:
-------
The default value for the field.
"""
if self._default is EMPTY:
if self.default is EMPTY:
return EMPTY
elif callable(self._default):
value = self._default(this)
elif callable(self.default):
value = self.default(this)
else:
value = self._default
value = self.default
return self.validate(value, this)

def __call__(self, default):
"""Needed to support the decorator syntax."""
return self.__class__(self.pattern, default)


class Argument(Annotation):
"""Annotation type for all fields which should be passed as arguments.
Expand All @@ -140,8 +112,11 @@ class Argument(Annotation):
Defaults to positional or keyword.
"""

__slots__ = ("_kind",)
_kind: int
__slots__ = ("pattern", "default", "typehint", "kind")
pattern: Pattern
default: AnyType
typehint: AnyType
kind: int

def __init__(
self,
Expand All @@ -150,38 +125,43 @@ def __init__(
typehint: type | None = None,
kind: int = POSITIONAL_OR_KEYWORD,
):
super().__init__(pattern, default, typehint)
self._kind = kind
if pattern is not None:
pattern = ensure_pattern(pattern)
super().__init__(pattern=pattern, default=default, typehint=typehint, kind=kind)

@classmethod
def required(cls, pattern=None, **kwargs):
"""Annotation to mark a mandatory argument."""
return cls(pattern, **kwargs)

@classmethod
def default(cls, default, pattern=None, **kwargs):
"""Annotation to allow missing arguments with a default value."""
return cls(pattern, default, **kwargs)
def attribute(pattern=None, default=EMPTY):
"""Annotation to mark a field in a class."""
if default is EMPTY and isinstance(pattern, (types.FunctionType, types.MethodType)):
return Attribute(default=pattern)
else:
return Attribute(pattern, default=default)

@classmethod
def optional(cls, pattern=None, default=None, **kwargs):
"""Annotation to allow and treat `None` values as missing arguments."""
if pattern is None:
pattern = Option(Any(), default=default)
else:
pattern = Option(pattern, default=default)
return cls(pattern, default=None, **kwargs)

@classmethod
def varargs(cls, pattern=None, **kwargs):
"""Annotation to mark a variable length positional argument."""
pattern = None if pattern is None else TupleOf(pattern)
return cls(pattern, kind=VAR_POSITIONAL, **kwargs)
def argument(pattern=None, default=EMPTY, typehint=None):
"""Annotation type for all fields which should be passed as arguments."""
return Argument(pattern, default=default, typehint=typehint)


def optional(pattern=None, default=None, typehint=None):
"""Annotation to allow and treat `None` values as missing arguments."""
if pattern is None:
pattern = Option(Any(), default=default)
else:
pattern = Option(pattern, default=default)
return Argument(pattern, default=None, typehint=typehint)

@classmethod
def varkwargs(cls, pattern=None, **kwargs):
pattern = None if pattern is None else FrozenDictOf(Any(), pattern)
return cls(pattern, kind=VAR_KEYWORD, **kwargs)

def varargs(pattern=None, typehint=None):
"""Annotation to mark a variable length positional arguments."""
pattern = None if pattern is None else TupleOf(pattern)
return Argument(pattern, kind=VAR_POSITIONAL, typehint=typehint)


def varkwargs(pattern=None, typehint=None):
"""Annotation to mark a variable length keyword arguments."""
pattern = None if pattern is None else FrozenDictOf(Any(), pattern)
return Argument(pattern, kind=VAR_KEYWORD, typehint=typehint)


class Parameter(inspect.Parameter):
Expand All @@ -196,8 +176,8 @@ def __init__(self, name, annotation):
)
super().__init__(
name,
kind=annotation._kind,
default=annotation._default,
kind=annotation.kind,
default=annotation.default,
annotation=annotation,
)

Expand Down Expand Up @@ -309,15 +289,11 @@ def from_callable(cls, fn, patterns=None, return_pattern=None):
pattern = None

if kind is VAR_POSITIONAL:
annot = Argument.varargs(pattern, typehint=typehint)
annot = varargs(pattern, typehint=typehint)
elif kind is VAR_KEYWORD:
annot = Argument.varkwargs(pattern, typehint=typehint)
elif default is EMPTY:
annot = Argument.required(pattern, kind=kind, typehint=typehint)
annot = varkwargs(pattern, typehint=typehint)
else:
annot = Argument.default(
default, pattern, kind=param.kind, typehint=typehint
)
annot = Argument(pattern, kind=kind, default=default, typehint=typehint)

parameters.append(Parameter(param.name, annot))

Expand Down Expand Up @@ -428,19 +404,6 @@ def validate_return(self, value, context):
return result


# aliases for convenience
argument = Argument
attribute = Attribute
default = Argument.default
optional = Argument.optional
required = Argument.required
varargs = Argument.varargs
varkwargs = Argument.varkwargs


# TODO(kszucs): try to cache pattern objects


def annotated(_1=None, _2=None, _3=None, **kwargs):
"""Create functions with arguments validated at runtime.
Expand Down
37 changes: 24 additions & 13 deletions ibis/common/bases.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from abc import ABCMeta, abstractmethod
from collections.abc import Hashable
from typing import TYPE_CHECKING, Any
from weakref import WeakValueDictionary

Expand Down Expand Up @@ -154,20 +155,16 @@ def __cached_equals__(self, other) -> bool:


class Slotted(Base):
"""A lightweight alternative to `ibis.common.grounds.Concrete`.
"""A lightweight alternative to `ibis.common.grounds.Annotable`.
This class is used to create immutable dataclasses with slots and a precomputed
hash value for quicker dictionary lookups.
The class is mostly used to reduce boilerplate code.
"""

__slots__ = ("__precomputed_hash__",)
__precomputed_hash__: int
__slots__ = ()

def __init__(self, **kwargs) -> None:
for name, value in kwargs.items():
object.__setattr__(self, name, value)
hashvalue = hash(tuple(kwargs.values()))
object.__setattr__(self, "__precomputed_hash__", hashvalue)

def __eq__(self, other) -> bool:
if self is other:
Expand All @@ -176,12 +173,6 @@ def __eq__(self, other) -> bool:
return NotImplemented
return all(getattr(self, n) == getattr(other, n) for n in self.__slots__)

def __hash__(self) -> int:
return self.__precomputed_hash__

def __setattr__(self, name, value) -> None:
raise AttributeError("Can't set attributes on an immutable instance")

def __repr__(self):
fields = {k: getattr(self, k) for k in self.__slots__}
fieldstring = ", ".join(f"{k}={v!r}" for k, v in fields.items())
Expand All @@ -190,3 +181,23 @@ def __repr__(self):
def __rich_repr__(self):
for name in self.__slots__:
yield name, getattr(self, name)


class FrozenSlotted(Slotted, Immutable, Hashable):
"""A lightweight alternative to `ibis.common.grounds.Concrete`.
This class is used to create immutable dataclasses with slots and a precomputed
hash value for quicker dictionary lookups.
"""

__slots__ = ("__precomputed_hash__",)
__precomputed_hash__: int

def __init__(self, **kwargs) -> None:
for name, value in kwargs.items():
object.__setattr__(self, name, value)
hashvalue = hash(tuple(kwargs.values()))
object.__setattr__(self, "__precomputed_hash__", hashvalue)

def __hash__(self) -> int:
return self.__precomputed_hash__
6 changes: 3 additions & 3 deletions ibis/common/grounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,16 +57,16 @@ def __new__(metacls, clsname, bases, dct, **kwargs):
continue
pattern = Pattern.from_typehint(typehint)
if name in dct:
dct[name] = Argument.default(dct[name], pattern, typehint=typehint)
dct[name] = Argument(pattern, default=dct[name], typehint=typehint)
else:
dct[name] = Argument.required(pattern, typehint=typehint)
dct[name] = Argument(pattern, typehint=typehint)

# collect the newly defined annotations
slots = list(dct.pop("__slots__", []))
namespace, arguments = {}, {}
for name, attrib in dct.items():
if isinstance(attrib, Pattern):
arguments[name] = Argument.required(attrib)
arguments[name] = Argument(attrib)
slots.append(name)
elif isinstance(attrib, Argument):
arguments[name] = attrib
Expand Down
3 changes: 2 additions & 1 deletion ibis/common/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
import toolz
from typing_extensions import GenericMeta, Self, get_args, get_origin

from ibis.common.bases import Singleton, Slotted
from ibis.common.bases import FrozenSlotted as Slotted
from ibis.common.bases import Singleton
from ibis.common.collections import FrozenDict, RewindableIterator, frozendict
from ibis.common.dispatch import lazy_singledispatch
from ibis.common.typing import (
Expand Down
Loading

0 comments on commit 0770e92

Please sign in to comment.