From 0770e92485e7cb5449c0359b186421efb73550be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Sat, 19 Aug 2023 10:58:47 +0200 Subject: [PATCH] refactor(common): turn annotations into slotted classes --- ibis/common/annotations.py | 169 ++++++++----------- ibis/common/bases.py | 37 ++-- ibis/common/grounds.py | 6 +- ibis/common/patterns.py | 3 +- ibis/common/tests/test_annotations.py | 118 +++++++++---- ibis/common/tests/test_grounds.py | 51 ++++-- ibis/common/tests/test_grounds_benchmarks.py | 6 +- ibis/expr/datatypes/core.py | 4 +- ibis/expr/operations/analytic.py | 4 +- ibis/expr/operations/arrays.py | 20 +-- ibis/expr/operations/generic.py | 6 +- ibis/expr/operations/histograms.py | 2 +- ibis/expr/operations/logical.py | 4 +- ibis/expr/operations/maps.py | 8 +- ibis/expr/operations/numeric.py | 4 +- ibis/expr/operations/reductions.py | 10 +- ibis/expr/operations/relations.py | 10 +- ibis/expr/operations/strings.py | 2 +- ibis/expr/operations/structs.py | 4 +- ibis/expr/operations/temporal.py | 4 +- ibis/expr/operations/udf.py | 6 +- ibis/expr/rules.py | 6 +- ibis/expr/schema.py | 6 +- ibis/selectors.py | 2 +- 24 files changed, 265 insertions(+), 227 deletions(-) diff --git a/ibis/common/annotations.py b/ibis/common/annotations.py index 0d75fce913d0..6f711f9a9092 100644 --- a/ibis/common/annotations.py +++ b/ibis/common/annotations.py @@ -2,18 +2,19 @@ import functools import inspect +import types from typing import Any as AnyType -from typing import Callable +from ibis.common.bases import Immutable, Slotted from ibis.common.patterns import ( Any, FrozenDictOf, - Function, NoMatch, Option, Pattern, TupleOf, ) +from ibis.common.patterns import pattern as ensure_pattern from ibis.common.typing import get_type_hints EMPTY = inspect.Parameter.empty # marker for missing argument @@ -28,58 +29,21 @@ class ValidationError(Exception): ... -class Annotation: +class Annotation(Slotted, Immutable): """Base class for all annotations. Annotations are used to mark fields in a class and to validate them. - - Parameters - ---------- - pattern : Pattern, default noop - Pattern to validate the field. - default : Any, default EMPTY - Default value of the field. - typehint : type, default EMPTY - Type of the field, not used for validation. """ - __slots__ = ("_pattern", "_default", "_typehint") - _pattern: Pattern | Callable | None - _default: AnyType - _typehint: AnyType - - def __init__(self, pattern=None, default=EMPTY, typehint=EMPTY): - if pattern is None or isinstance(pattern, Pattern): - pass - elif callable(pattern): - pattern = Function(pattern) - else: - raise TypeError(f"Unsupported pattern {pattern!r}") - self._pattern = pattern - self._default = default - self._typehint = typehint - - def __eq__(self, other): - return ( - type(self) is type(other) - and self._pattern == other._pattern - and self._default == other._default - and self._typehint == other._typehint - ) - - def __repr__(self): - return ( - f"{self.__class__.__name__}(pattern={self._pattern!r}, " - f"default={self._default!r}, typehint={self._typehint!r})" - ) + __slots__ = () def validate(self, arg, context=None): - if self._pattern is None: + if self.pattern is None: return arg - result = self._pattern.match(arg, context) + result = self.pattern.match(arg, context) if result is NoMatch: - raise ValidationError(f"{arg!r} doesn't match {self._pattern!r}") + raise ValidationError(f"{arg!r} doesn't match {self.pattern!r}") return result @@ -98,10 +62,14 @@ class Attribute(Annotation): Callable to compute the default value of the field. """ - @classmethod - def default(self, fn): - """Annotation to mark a field with a default value computed by a callable.""" - return Attribute(default=fn) + __slots__ = ("pattern", "default") + pattern: Pattern + default: AnyType + + def __init__(self, pattern: Pattern | None = None, default: AnyType = EMPTY): + if pattern is not None: + pattern = ensure_pattern(pattern) + super().__init__(pattern=pattern, default=default) def initialize(self, this: AnyType) -> AnyType: """Compute the default value of the field. @@ -115,14 +83,18 @@ def initialize(self, this: AnyType) -> AnyType: ------- The default value for the field. """ - if self._default is EMPTY: + if self.default is EMPTY: return EMPTY - elif callable(self._default): - value = self._default(this) + elif callable(self.default): + value = self.default(this) else: - value = self._default + value = self.default return self.validate(value, this) + def __call__(self, default): + """Needed to support the decorator syntax.""" + return self.__class__(self.pattern, default) + class Argument(Annotation): """Annotation type for all fields which should be passed as arguments. @@ -140,8 +112,11 @@ class Argument(Annotation): Defaults to positional or keyword. """ - __slots__ = ("_kind",) - _kind: int + __slots__ = ("pattern", "default", "typehint", "kind") + pattern: Pattern + default: AnyType + typehint: AnyType + kind: int def __init__( self, @@ -150,38 +125,43 @@ def __init__( typehint: type | None = None, kind: int = POSITIONAL_OR_KEYWORD, ): - super().__init__(pattern, default, typehint) - self._kind = kind + if pattern is not None: + pattern = ensure_pattern(pattern) + super().__init__(pattern=pattern, default=default, typehint=typehint, kind=kind) - @classmethod - def required(cls, pattern=None, **kwargs): - """Annotation to mark a mandatory argument.""" - return cls(pattern, **kwargs) - @classmethod - def default(cls, default, pattern=None, **kwargs): - """Annotation to allow missing arguments with a default value.""" - return cls(pattern, default, **kwargs) +def attribute(pattern=None, default=EMPTY): + """Annotation to mark a field in a class.""" + if default is EMPTY and isinstance(pattern, (types.FunctionType, types.MethodType)): + return Attribute(default=pattern) + else: + return Attribute(pattern, default=default) - @classmethod - def optional(cls, pattern=None, default=None, **kwargs): - """Annotation to allow and treat `None` values as missing arguments.""" - if pattern is None: - pattern = Option(Any(), default=default) - else: - pattern = Option(pattern, default=default) - return cls(pattern, default=None, **kwargs) - @classmethod - def varargs(cls, pattern=None, **kwargs): - """Annotation to mark a variable length positional argument.""" - pattern = None if pattern is None else TupleOf(pattern) - return cls(pattern, kind=VAR_POSITIONAL, **kwargs) +def argument(pattern=None, default=EMPTY, typehint=None): + """Annotation type for all fields which should be passed as arguments.""" + return Argument(pattern, default=default, typehint=typehint) + + +def optional(pattern=None, default=None, typehint=None): + """Annotation to allow and treat `None` values as missing arguments.""" + if pattern is None: + pattern = Option(Any(), default=default) + else: + pattern = Option(pattern, default=default) + return Argument(pattern, default=None, typehint=typehint) - @classmethod - def varkwargs(cls, pattern=None, **kwargs): - pattern = None if pattern is None else FrozenDictOf(Any(), pattern) - return cls(pattern, kind=VAR_KEYWORD, **kwargs) + +def varargs(pattern=None, typehint=None): + """Annotation to mark a variable length positional arguments.""" + pattern = None if pattern is None else TupleOf(pattern) + return Argument(pattern, kind=VAR_POSITIONAL, typehint=typehint) + + +def varkwargs(pattern=None, typehint=None): + """Annotation to mark a variable length keyword arguments.""" + pattern = None if pattern is None else FrozenDictOf(Any(), pattern) + return Argument(pattern, kind=VAR_KEYWORD, typehint=typehint) class Parameter(inspect.Parameter): @@ -196,8 +176,8 @@ def __init__(self, name, annotation): ) super().__init__( name, - kind=annotation._kind, - default=annotation._default, + kind=annotation.kind, + default=annotation.default, annotation=annotation, ) @@ -309,15 +289,11 @@ def from_callable(cls, fn, patterns=None, return_pattern=None): pattern = None if kind is VAR_POSITIONAL: - annot = Argument.varargs(pattern, typehint=typehint) + annot = varargs(pattern, typehint=typehint) elif kind is VAR_KEYWORD: - annot = Argument.varkwargs(pattern, typehint=typehint) - elif default is EMPTY: - annot = Argument.required(pattern, kind=kind, typehint=typehint) + annot = varkwargs(pattern, typehint=typehint) else: - annot = Argument.default( - default, pattern, kind=param.kind, typehint=typehint - ) + annot = Argument(pattern, kind=kind, default=default, typehint=typehint) parameters.append(Parameter(param.name, annot)) @@ -428,19 +404,6 @@ def validate_return(self, value, context): return result -# aliases for convenience -argument = Argument -attribute = Attribute -default = Argument.default -optional = Argument.optional -required = Argument.required -varargs = Argument.varargs -varkwargs = Argument.varkwargs - - -# TODO(kszucs): try to cache pattern objects - - def annotated(_1=None, _2=None, _3=None, **kwargs): """Create functions with arguments validated at runtime. diff --git a/ibis/common/bases.py b/ibis/common/bases.py index 6b6d723452fa..89ef17471728 100644 --- a/ibis/common/bases.py +++ b/ibis/common/bases.py @@ -1,6 +1,7 @@ from __future__ import annotations from abc import ABCMeta, abstractmethod +from collections.abc import Hashable from typing import TYPE_CHECKING, Any from weakref import WeakValueDictionary @@ -154,20 +155,16 @@ def __cached_equals__(self, other) -> bool: class Slotted(Base): - """A lightweight alternative to `ibis.common.grounds.Concrete`. + """A lightweight alternative to `ibis.common.grounds.Annotable`. - This class is used to create immutable dataclasses with slots and a precomputed - hash value for quicker dictionary lookups. + The class is mostly used to reduce boilerplate code. """ - __slots__ = ("__precomputed_hash__",) - __precomputed_hash__: int + __slots__ = () def __init__(self, **kwargs) -> None: for name, value in kwargs.items(): object.__setattr__(self, name, value) - hashvalue = hash(tuple(kwargs.values())) - object.__setattr__(self, "__precomputed_hash__", hashvalue) def __eq__(self, other) -> bool: if self is other: @@ -176,12 +173,6 @@ def __eq__(self, other) -> bool: return NotImplemented return all(getattr(self, n) == getattr(other, n) for n in self.__slots__) - def __hash__(self) -> int: - return self.__precomputed_hash__ - - def __setattr__(self, name, value) -> None: - raise AttributeError("Can't set attributes on an immutable instance") - def __repr__(self): fields = {k: getattr(self, k) for k in self.__slots__} fieldstring = ", ".join(f"{k}={v!r}" for k, v in fields.items()) @@ -190,3 +181,23 @@ def __repr__(self): def __rich_repr__(self): for name in self.__slots__: yield name, getattr(self, name) + + +class FrozenSlotted(Slotted, Immutable, Hashable): + """A lightweight alternative to `ibis.common.grounds.Concrete`. + + This class is used to create immutable dataclasses with slots and a precomputed + hash value for quicker dictionary lookups. + """ + + __slots__ = ("__precomputed_hash__",) + __precomputed_hash__: int + + def __init__(self, **kwargs) -> None: + for name, value in kwargs.items(): + object.__setattr__(self, name, value) + hashvalue = hash(tuple(kwargs.values())) + object.__setattr__(self, "__precomputed_hash__", hashvalue) + + def __hash__(self) -> int: + return self.__precomputed_hash__ diff --git a/ibis/common/grounds.py b/ibis/common/grounds.py index 35505d5d28d6..8f7fb6dbacfd 100644 --- a/ibis/common/grounds.py +++ b/ibis/common/grounds.py @@ -57,16 +57,16 @@ def __new__(metacls, clsname, bases, dct, **kwargs): continue pattern = Pattern.from_typehint(typehint) if name in dct: - dct[name] = Argument.default(dct[name], pattern, typehint=typehint) + dct[name] = Argument(pattern, default=dct[name], typehint=typehint) else: - dct[name] = Argument.required(pattern, typehint=typehint) + dct[name] = Argument(pattern, typehint=typehint) # collect the newly defined annotations slots = list(dct.pop("__slots__", [])) namespace, arguments = {}, {} for name, attrib in dct.items(): if isinstance(attrib, Pattern): - arguments[name] = Argument.required(attrib) + arguments[name] = Argument(attrib) slots.append(name) elif isinstance(attrib, Argument): arguments[name] = attrib diff --git a/ibis/common/patterns.py b/ibis/common/patterns.py index 900ef7880558..fc7ac789773d 100644 --- a/ibis/common/patterns.py +++ b/ibis/common/patterns.py @@ -23,7 +23,8 @@ import toolz from typing_extensions import GenericMeta, Self, get_args, get_origin -from ibis.common.bases import Singleton, Slotted +from ibis.common.bases import FrozenSlotted as Slotted +from ibis.common.bases import Singleton from ibis.common.collections import FrozenDict, RewindableIterator, frozendict from ibis.common.dispatch import lazy_singledispatch from ibis.common.typing import ( diff --git a/ibis/common/tests/test_annotations.py b/ibis/common/tests/test_annotations.py index db9d9d5bb41b..3a12c710f875 100644 --- a/ibis/common/tests/test_annotations.py +++ b/ibis/common/tests/test_annotations.py @@ -12,6 +12,9 @@ Signature, ValidationError, annotated, + argument, + attribute, + optional, ) from ibis.common.patterns import ( Any, @@ -26,16 +29,66 @@ is_int = InstanceOf(int) +def test_argument_factory(): + a = argument(is_int, default=1, typehint=int) + assert a == Argument(is_int, default=1, typehint=int) + + a = argument(is_int, default=1) + assert a == Argument(is_int, default=1) + + a = argument(is_int) + assert a == Argument(is_int) + + +def test_attribute_factory(): + a = attribute(is_int, default=1) + assert a == Attribute(is_int, default=1) + + a = attribute(is_int) + assert a == Attribute(is_int) + + a = attribute(default=2) + assert a == Attribute(default=2) + + a = attribute(int, default=2) + assert a == Attribute(int, default=2) + + +def test_annotations_are_immutable(): + a = argument(is_int, default=1) + with pytest.raises(AttributeError): + a.pattern = Any() + with pytest.raises(AttributeError): + a.default = 2 + + a = attribute(is_int, default=1) + with pytest.raises(AttributeError): + a.pattern = Any() + with pytest.raises(AttributeError): + a.default = 2 + + +def test_annotations_are_not_hashable(): + # in order to use the with mutable defaults + a = argument(is_int, default=1) + with pytest.raises(TypeError, match="unhashable type: 'Argument'"): + hash(a) + + a = attribute(is_int, default=1) + with pytest.raises(TypeError, match="unhashable type: 'Attribute'"): + hash(a) + + def test_argument_repr(): argument = Argument(is_int, typehint=int, default=None) assert repr(argument) == ( "Argument(pattern=InstanceOf(type=), default=None, " - "typehint=)" + "typehint=, kind=<_ParameterKind.POSITIONAL_OR_KEYWORD: 1>)" ) def test_default_argument(): - annotation = Argument.default(pattern=lambda x, context: int(x), default=3) + annotation = Argument(pattern=lambda x, context: int(x), default=3) assert annotation.validate(1) == 1 with pytest.raises(TypeError): annotation.validate(None) @@ -46,43 +99,36 @@ def test_default_argument(): [(None, None), (0, 0), ("default", "default")], ) def test_optional_argument(default, expected): - annotation = Argument.optional(default=default) + annotation = optional(default=default) assert annotation.validate(None) == expected @pytest.mark.parametrize( ("argument", "value", "expected"), [ - (Argument.optional(Any(), default=None), None, None), - (Argument.optional(Any(), default=None), "three", "three"), - (Argument.optional(Any(), default=1), None, 1), - (Argument.optional(CoercedTo(int), default=11), None, 11), - (Argument.optional(CoercedTo(int), default=None), None, None), - (Argument.optional(CoercedTo(int), default=None), 18, 18), - (Argument.optional(CoercedTo(str), default=None), "caracal", "caracal"), + (optional(Any(), default=None), None, None), + (optional(Any(), default=None), "three", "three"), + (optional(Any(), default=1), None, 1), + (optional(CoercedTo(int), default=11), None, 11), + (optional(CoercedTo(int), default=None), None, None), + (optional(CoercedTo(int), default=None), 18, 18), + (optional(CoercedTo(str), default=None), "caracal", "caracal"), ], ) def test_valid_optional(argument, value, expected): assert argument.validate(value) == expected -@pytest.mark.parametrize( - ("arg", "value", "expected"), - [ - (Argument.optional(is_int, default=""), None, TypeError), - (Argument.optional(is_int), "lynx", TypeError), - ], -) -def test_invalid_optional_argument(arg, value, expected): - with pytest.raises(expected): - arg(value) +def test_invalid_optional_argument(): + with pytest.raises(ValidationError): + optional(is_int).validate("lynx") def test_initialized(): class Foo: a = 10 - field = Attribute.default(lambda self: self.a + 10) + field = Attribute(default=lambda self: self.a + 10) assert field == field assert field.initialize(Foo) == 20 @@ -96,7 +142,7 @@ def test_parameter(): def fn(x, this): return int(x) + this["other"] - annot = Argument.required(fn) + annot = argument(fn) p = Parameter("test", annotation=annot) assert p.annotation is annot @@ -106,9 +152,9 @@ def fn(x, this): with pytest.raises(TypeError): p.annotation.validate({}, valid=inspect.Parameter.empty) - ofn = Argument.optional(fn) + ofn = optional(fn) op = Parameter("test", annotation=ofn) - assert op.annotation._pattern == Option(fn, default=None) + assert op.annotation.pattern == Option(fn, default=None) assert op.default is None assert op.annotation.validate(None, {"other": 1}) is None @@ -123,8 +169,8 @@ def to_int(x, this): def add_other(x, this): return int(x) + this["other"] - other = Parameter("other", annotation=Argument.required(to_int)) - this = Parameter("this", annotation=Argument.required(add_other)) + other = Parameter("other", annotation=Argument(to_int)) + this = Parameter("this", annotation=Argument(add_other)) sig = Signature(parameters=[other, this]) assert sig.validate(1, 2) == {"other": 1, "this": 3} @@ -155,9 +201,9 @@ def test(a: int, b: int, *args: int): assert sig.validate(2, 3) == {"a": 2, "b": 3, "args": ()} assert sig.validate(2, 3, 4) == {"a": 2, "b": 3, "args": (4,)} assert sig.validate(2, 3, 4, 5) == {"a": 2, "b": 3, "args": (4, 5)} - assert sig.parameters["a"].annotation._typehint is int - assert sig.parameters["b"].annotation._typehint is int - assert sig.parameters["args"].annotation._typehint is int + assert sig.parameters["a"].annotation.typehint is int + assert sig.parameters["b"].annotation.typehint is int + assert sig.parameters["args"].annotation.typehint is int with pytest.raises(ValidationError): sig.validate(2, 3, 4, "5") @@ -210,8 +256,8 @@ def to_int(x, this): def add_other(x, this): return int(x) + this["other"] - other = Parameter("other", annotation=Argument.required(to_int)) - this = Parameter("this", annotation=Argument.required(add_other)) + other = Parameter("other", annotation=Argument(to_int)) + this = Parameter("this", annotation=Argument(add_other)) sig = Signature(parameters=[other, this]) params = sig.validate(1, this=2) @@ -221,14 +267,14 @@ def add_other(x, this): assert kwargs == {} -a = Parameter("a", annotation=Argument.required(CoercedTo(float))) -b = Parameter("b", annotation=Argument.required(CoercedTo(float))) -c = Parameter("c", annotation=Argument.default(default=0, pattern=CoercedTo(float))) +a = Parameter("a", annotation=Argument(CoercedTo(float))) +b = Parameter("b", annotation=Argument(CoercedTo(float))) +c = Parameter("c", annotation=Argument(CoercedTo(float), default=0)) d = Parameter( "d", - annotation=Argument.default(default=tuple(), pattern=TupleOf(CoercedTo(float))), + annotation=Argument(TupleOf(CoercedTo(float)), default=()), ) -e = Parameter("e", annotation=Argument.optional(pattern=CoercedTo(float))) +e = Parameter("e", annotation=Argument(Option(CoercedTo(float)), default=None)) sig = Signature(parameters=[a, b, c, d, e]) diff --git a/ibis/common/tests/test_grounds.py b/ibis/common/tests/test_grounds.py index a56fe9057eaf..7e1f6aaba87c 100644 --- a/ibis/common/tests/test_grounds.py +++ b/ibis/common/tests/test_grounds.py @@ -9,13 +9,13 @@ import pytest from ibis.common.annotations import ( + Argument, Parameter, Signature, ValidationError, argument, attribute, optional, - required, varargs, varkwargs, ) @@ -76,7 +76,7 @@ class BetweenWithCalculated(Concrete): lower = optional(is_int, default=0) upper = optional(is_int, default=None) - @attribute.default + @attribute def calculated(self): return self.value + self.lower @@ -418,22 +418,22 @@ class IntAddClip(FloatAddClip, IntBinop): assert IntBinop.__signature__ == Signature( [ - Parameter("left", annotation=required(is_int)), - Parameter("right", annotation=required(is_int)), + Parameter("left", annotation=Argument(is_int)), + Parameter("right", annotation=Argument(is_int)), ] ) assert FloatAddRhs.__signature__ == Signature( [ - Parameter("left", annotation=required(is_int)), - Parameter("right", annotation=required(is_float)), + Parameter("left", annotation=Argument(is_int)), + Parameter("right", annotation=Argument(is_float)), ] ) assert FloatAddClip.__signature__ == Signature( [ - Parameter("left", annotation=required(is_float)), - Parameter("right", annotation=required(is_float)), + Parameter("left", annotation=Argument(is_float)), + Parameter("right", annotation=Argument(is_float)), Parameter("clip_lower", annotation=optional(is_int, default=0)), Parameter("clip_upper", annotation=optional(is_int, default=10)), ] @@ -441,8 +441,8 @@ class IntAddClip(FloatAddClip, IntBinop): assert IntAddClip.__signature__ == Signature( [ - Parameter("left", annotation=required(is_int)), - Parameter("right", annotation=required(is_int)), + Parameter("left", annotation=Argument(is_int)), + Parameter("right", annotation=Argument(is_int)), Parameter("clip_lower", annotation=optional(is_int, default=0)), Parameter("clip_upper", annotation=optional(is_int, default=10)), ] @@ -629,9 +629,9 @@ class Op(Annotable): def test_copy_mutable_with_default_attribute(): class Test(Annotable): a = attribute(InstanceOf(dict), default={}) - b = argument(InstanceOf(str)) + b = argument(InstanceOf(str)) # required argument - @attribute.default + @attribute def c(self): return self.b.upper() @@ -777,7 +777,7 @@ class BaseValue(Annotable): class Value2(BaseValue): - @attribute.default + @attribute def k(self): return 3 @@ -858,7 +858,7 @@ def test_initialized_attribute_basics(): class Value(Annotable): a = is_int - @attribute.default + @attribute def double_a(self): return 2 * self.a @@ -869,6 +869,27 @@ def double_a(self): assert "double_a" in Value.__slots__ +def test_initialized_attribute_with_validation(): + class Value(Annotable): + a = is_int + + @attribute(int) + def double_a(self): + return 2 * self.a + + op = Value(1) + assert op.a == 1 + assert op.double_a == 2 + assert len(Value.__attributes__) == 1 + assert "double_a" in Value.__slots__ + + op.double_a = 3 + assert op.double_a == 3 + + with pytest.raises(ValidationError): + op.double_a = "foo" + + def test_initialized_attribute_mixed_with_classvar(): class Value(Annotable): arg = is_int @@ -880,7 +901,7 @@ class Reduction(Value): shape = "scalar" class Variadic(Value): - @attribute.default + @attribute def shape(self): if self.arg > 10: return "columnar" diff --git a/ibis/common/tests/test_grounds_benchmarks.py b/ibis/common/tests/test_grounds_benchmarks.py index 9422d07efad2..e48a979184d8 100644 --- a/ibis/common/tests/test_grounds_benchmarks.py +++ b/ibis/common/tests/test_grounds_benchmarks.py @@ -15,15 +15,15 @@ class MyObject(Concrete): c: tuple[int, ...] d: frozendict[str, int] - @attribute.default + @attribute def e(self): return self.a * 2 - @attribute.default + @attribute def f(self): return self.b * 2 - @attribute.default + @attribute def g(self): return self.c * 2 diff --git a/ibis/expr/datatypes/core.py b/ibis/expr/datatypes/core.py index eab4ce681369..fa942fdc70f2 100644 --- a/ibis/expr/datatypes/core.py +++ b/ibis/expr/datatypes/core.py @@ -857,12 +857,12 @@ def from_tuples( """ return cls(dict(pairs), nullable=nullable) - @attribute.default + @attribute def names(self) -> tuple[str, ...]: """Return the names of the struct's fields.""" return tuple(self.keys()) - @attribute.default + @attribute def types(self) -> tuple[DataType, ...]: """Return the types of the struct's fields.""" return tuple(self.values()) diff --git a/ibis/expr/operations/analytic.py b/ibis/expr/operations/analytic.py index 702007e0fa63..d56d5812ca9f 100644 --- a/ibis/expr/operations/analytic.py +++ b/ibis/expr/operations/analytic.py @@ -89,7 +89,7 @@ class CumulativeSum(Cumulative): arg: Column[dt.Numeric] - @attribute.default + @attribute def dtype(self): return dt.higher_precedence(self.arg.dtype.largest, dt.int64) @@ -103,7 +103,7 @@ class CumulativeMean(Cumulative): arg: Column[dt.Numeric] - @attribute.default + @attribute def dtype(self): return dt.higher_precedence(self.arg.dtype.largest, dt.float64) diff --git a/ibis/expr/operations/arrays.py b/ibis/expr/operations/arrays.py index 2e216a01293b..724ca2e37ee3 100644 --- a/ibis/expr/operations/arrays.py +++ b/ibis/expr/operations/arrays.py @@ -18,7 +18,7 @@ class ArrayColumn(Value): shape = ds.columnar - @attribute.default + @attribute def dtype(self): return dt.Array(rlz.highest_precedence_dtype(self.cols)) @@ -48,7 +48,7 @@ class ArrayIndex(Value): shape = rlz.shape_like("args") - @attribute.default + @attribute def dtype(self): return self.arg.dtype.value_type @@ -57,11 +57,11 @@ def dtype(self): class ArrayConcat(Value): arg: VarTuple[Value[dt.Array]] - @attribute.default + @attribute def dtype(self): return dt.Array(dt.highest_precedence(arg.dtype.value_type for arg in self.arg)) - @attribute.default + @attribute def shape(self): return rlz.highest_precedence_shape(self.arg) @@ -78,12 +78,12 @@ class ArrayRepeat(Value): class ArrayApply(Value): arg: Value[dt.Array] - @attribute.default + @attribute def parameter(self): (name,) = self.func.__signature__.parameters.keys() return name - @attribute.default + @attribute def result(self): arg = Argument( name=self.parameter, @@ -92,7 +92,7 @@ def result(self): ) return self.func(arg) - @attribute.default + @attribute def shape(self): return self.arg.shape @@ -101,7 +101,7 @@ def shape(self): class ArrayMap(ArrayApply): func: Callable[[Value], Value] - @attribute.default + @attribute def dtype(self) -> dt.DataType: return dt.Array(self.result.dtype) @@ -119,7 +119,7 @@ class Unnest(Value): shape = ds.columnar - @attribute.default + @attribute def dtype(self): return self.arg.dtype.value_type @@ -191,7 +191,7 @@ class ArrayZip(Value): shape = rlz.shape_like("arg") - @attribute.default + @attribute def dtype(self): return dt.Array( dt.Struct( diff --git a/ibis/expr/operations/generic.py b/ibis/expr/operations/generic.py index 0d902364ae54..bfd273fdb948 100644 --- a/ibis/expr/operations/generic.py +++ b/ibis/expr/operations/generic.py @@ -299,7 +299,7 @@ def __init__(self, cases, results, **kwargs): assert len(cases) == len(results) super().__init__(cases=cases, results=results, **kwargs) - @attribute.default + @attribute def dtype(self): values = [*self.results, self.default] return rlz.highest_precedence_dtype(values) @@ -315,12 +315,12 @@ def __init__(self, cases, results, default): assert len(cases) == len(results) super().__init__(cases=cases, results=results, default=default) - @attribute.default + @attribute def shape(self): # TODO(kszucs): can be removed after making Sequence iterable return rlz.highest_precedence_shape(self.cases) - @attribute.default + @attribute def dtype(self): exprs = [*self.results, self.default] return rlz.highest_precedence_dtype(exprs) diff --git a/ibis/expr/operations/histograms.py b/ibis/expr/operations/histograms.py index 43e71077ae9c..dd415172b16e 100644 --- a/ibis/expr/operations/histograms.py +++ b/ibis/expr/operations/histograms.py @@ -23,7 +23,7 @@ class Bucket(Value): shape = ds.columnar - @attribute.default + @attribute def dtype(self): return dt.infer(self.nbuckets) diff --git a/ibis/expr/operations/logical.py b/ibis/expr/operations/logical.py index bf350c8c2e1c..38a6e2f491fc 100644 --- a/ibis/expr/operations/logical.py +++ b/ibis/expr/operations/logical.py @@ -134,7 +134,7 @@ class InValues(Value): dtype = dt.boolean - @attribute.default + @attribute def shape(self): args = [self.value, *self.options] return rlz.highest_precedence_shape(args) @@ -164,7 +164,7 @@ class Where(Value): shape = rlz.shape_like("args") - @attribute.default + @attribute def dtype(self): return rlz.highest_precedence_dtype([self.true_expr, self.false_null_expr]) diff --git a/ibis/expr/operations/maps.py b/ibis/expr/operations/maps.py index ddf033623d72..8323d880e6dc 100644 --- a/ibis/expr/operations/maps.py +++ b/ibis/expr/operations/maps.py @@ -15,7 +15,7 @@ class Map(Value): shape = rlz.shape_like("args") - @attribute.default + @attribute def dtype(self): return dt.Map( self.keys.dtype.value_type, @@ -37,7 +37,7 @@ class MapGet(Value): shape = rlz.shape_like("args") - @attribute.default + @attribute def dtype(self): return dt.higher_precedence(self.default.dtype, self.arg.dtype.value_type) @@ -55,7 +55,7 @@ class MapContains(Value): class MapKeys(Unary): arg: Value[dt.Map] - @attribute.default + @attribute def dtype(self): return dt.Array(self.arg.dtype.key_type) @@ -64,7 +64,7 @@ def dtype(self): class MapValues(Unary): arg: Value[dt.Map] - @attribute.default + @attribute def dtype(self): return dt.Array(self.arg.dtype.value_type) diff --git a/ibis/expr/operations/numeric.py b/ibis/expr/operations/numeric.py index 7a077d142b72..2303fcf6b51f 100644 --- a/ibis/expr/operations/numeric.py +++ b/ibis/expr/operations/numeric.py @@ -202,14 +202,14 @@ class BaseConvert(Value): class MathUnary(Unary): arg: SoftNumeric - @attribute.default + @attribute def dtype(self): return dt.higher_precedence(self.arg.dtype, dt.double) @public class ExpandingMathUnary(MathUnary): - @attribute.default + @attribute def dtype(self): return dt.higher_precedence(self.arg.dtype.largest, dt.double) diff --git a/ibis/expr/operations/reductions.py b/ibis/expr/operations/reductions.py index 94280d081d9c..d57f82fcd856 100644 --- a/ibis/expr/operations/reductions.py +++ b/ibis/expr/operations/reductions.py @@ -153,7 +153,7 @@ class BitXor(Filterable, Reduction): class Sum(Filterable, Reduction): arg: Column[dt.Numeric | dt.Boolean] - @attribute.default + @attribute def dtype(self): if self.arg.dtype.is_boolean(): return dt.int64 @@ -165,7 +165,7 @@ def dtype(self): class Mean(Filterable, Reduction): arg: Column[dt.Numeric | dt.Boolean] - @attribute.default + @attribute def dtype(self): if (dtype := self.arg.dtype).is_boolean(): return dt.float64 @@ -177,7 +177,7 @@ def dtype(self): class Median(Filterable, Reduction): arg: Column[dt.Numeric | dt.Boolean] - @attribute.default + @attribute def dtype(self): return dt.higher_precedence(self.arg.dtype, dt.float64) @@ -209,7 +209,7 @@ class VarianceBase(Filterable, Reduction): arg: Column[dt.Numeric | dt.Boolean] how: Literal["sample", "pop"] - @attribute.default + @attribute def dtype(self): if (dtype := self.arg.dtype).is_decimal(): return dtype.largest @@ -327,7 +327,7 @@ class CountDistinct(Filterable, Reduction): class ArrayCollect(Filterable, Reduction): arg: Column - @attribute.default + @attribute def dtype(self): return dt.Array(self.arg.dtype) diff --git a/ibis/expr/operations/relations.py b/ibis/expr/operations/relations.py index 308611a28d56..52be5f74cb7e 100644 --- a/ibis/expr/operations/relations.py +++ b/ibis/expr/operations/relations.py @@ -389,7 +389,7 @@ def schema(self): class SelfReference(Relation): table: Relation - @attribute.default + @attribute def name(self) -> str: if (name := getattr(self.table, "name", None)) is not None: return f"{name}_ref" @@ -404,7 +404,7 @@ class Projection(Relation): table: Relation selections: VarTuple[Relation | Value] - @attribute.default + @attribute def schema(self): # Resolve schema and initialize if not self.selections: @@ -479,7 +479,7 @@ def order_by(self, keys: VarTuple[SortKey]): return Selection(self, [], sort_keys=keys) - @attribute.default + @attribute def _projection(self): return Projection(self.table, self.selections) @@ -529,7 +529,7 @@ def __init__(self, table, metrics, by, having, predicates, sort_keys): sort_keys=sort_keys, ) - @attribute.default + @attribute def schema(self): names, types = [], [] for value in self.by + self.metrics: @@ -625,7 +625,7 @@ class SQLStringView(PhysicalTable): name: str query: str - @attribute.default + @attribute def schema(self): # TODO(kszucs): avoid converting to expression backend = self.child.to_expr()._find_backend() diff --git a/ibis/expr/operations/strings.py b/ibis/expr/operations/strings.py index 24de9215e71e..871dac7f277a 100644 --- a/ibis/expr/operations/strings.py +++ b/ibis/expr/operations/strings.py @@ -139,7 +139,7 @@ class StringJoin(Value): dtype = dt.string - @attribute.default + @attribute def shape(self): return rlz.highest_precedence_shape(self.arg) diff --git a/ibis/expr/operations/structs.py b/ibis/expr/operations/structs.py index c46e0ad6cf4f..2fadade05da1 100644 --- a/ibis/expr/operations/structs.py +++ b/ibis/expr/operations/structs.py @@ -16,7 +16,7 @@ class StructField(Value): shape = rlz.shape_like("arg") - @attribute.default + @attribute def dtype(self) -> dt.DataType: struct_dtype = self.arg.dtype value_dtype = struct_dtype[self.field] @@ -34,7 +34,7 @@ class StructColumn(Value): shape = rlz.shape_like("values") - @attribute.default + @attribute def dtype(self) -> dt.DataType: dtypes = (value.dtype for value in self.values) return dt.Struct.from_tuples(zip(self.names, dtypes)) diff --git a/ibis/expr/operations/temporal.py b/ibis/expr/operations/temporal.py index bf61667d7f67..f79b776e0712 100644 --- a/ibis/expr/operations/temporal.py +++ b/ibis/expr/operations/temporal.py @@ -288,7 +288,7 @@ class TimestampDiff(Binary): @public class IntervalBinary(Binary): - @attribute.default + @attribute def dtype(self): interval_unit_args = [ arg.dtype.unit for arg in (self.left, self.right) if arg.dtype.is_interval() @@ -333,7 +333,7 @@ class IntervalFromInteger(Value): shape = rlz.shape_like("arg") - @attribute.default + @attribute def dtype(self): return dt.Interval(self.unit) diff --git a/ibis/expr/operations/udf.py b/ibis/expr/operations/udf.py index 304663dbe691..ad520ffc4282 100644 --- a/ibis/expr/operations/udf.py +++ b/ibis/expr/operations/udf.py @@ -127,13 +127,9 @@ def make_node( raise exc.MissingParameterAnnotationError(fn, name) arg = rlz.ValueOf(dt.dtype(raw_dtype)) - if (default := param.default) is EMPTY: - fields[name] = Argument.required(pattern=arg) - else: - fields[name] = Argument.default(pattern=arg, default=default) + fields[name] = Argument(pattern=arg, default=param.default) fields["dtype"] = dt.dtype(return_annotation) - fields["__input_type__"] = input_type # can't be just `fn` otherwise `fn` is assumed to be a method fields["__func__"] = property(fget=lambda _, fn=fn: fn) diff --git a/ibis/expr/rules.py b/ibis/expr/rules.py index 508e0af7f9ed..30a8210eb9c9 100644 --- a/ibis/expr/rules.py +++ b/ibis/expr/rules.py @@ -62,7 +62,7 @@ def comparable(left, right): @public def dtype_like(name): - @attribute.default + @attribute def dtype(self): args = getattr(self, name) args = args if util.is_iterable(args) else [args] @@ -73,7 +73,7 @@ def dtype(self): @public def shape_like(name): - @attribute.default + @attribute def shape(self): args = getattr(self, name) args = args if util.is_iterable(args) else [args] @@ -143,7 +143,7 @@ def _promote_decimal_binop(args, op): @public def numeric_like(name, op): - @attribute.default + @attribute def dtype(self): args = getattr(self, name) dtypes = [arg.dtype for arg in args] diff --git a/ibis/expr/schema.py b/ibis/expr/schema.py index 205449fdbb34..e401cb24aa13 100644 --- a/ibis/expr/schema.py +++ b/ibis/expr/schema.py @@ -53,15 +53,15 @@ def __coerce__(cls, value) -> Schema: return value return schema(value) - @attribute.default + @attribute def names(self): return tuple(self.keys()) - @attribute.default + @attribute def types(self): return tuple(self.values()) - @attribute.default + @attribute def _name_locs(self) -> dict[str, int]: return {v: i for i, v in enumerate(self.names)} diff --git a/ibis/selectors.py b/ibis/selectors.py index 956b135e5274..9d6410165685 100644 --- a/ibis/selectors.py +++ b/ibis/selectors.py @@ -600,7 +600,7 @@ def stop(self): def step(self): return self.slice.step - @attribute.default + @attribute def __precomputed_hash__(self) -> int: return hash((self.__class__, (self.start, self.stop, self.step)))