From baea1fade57c194baae2a514be58195e9f606016 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Fri, 17 Feb 2023 16:31:12 +0100 Subject: [PATCH] feat(common): add support for variadic positional and variadic keyword annotations --- ibis/common/annotations.py | 102 +++++++++++++---- ibis/common/grounds.py | 15 ++- ibis/common/tests/test_annotations.py | 69 +++++++---- ibis/common/tests/test_grounds.py | 157 ++++++++++++++++++++++++-- ibis/common/tests/test_validators.py | 8 +- 5 files changed, 288 insertions(+), 63 deletions(-) diff --git a/ibis/common/annotations.py b/ibis/common/annotations.py index f5f1a7e947e6..99324007db43 100644 --- a/ibis/common/annotations.py +++ b/ibis/common/annotations.py @@ -4,7 +4,7 @@ import inspect from typing import Any -from ibis.common.validators import Validator, any_, option +from ibis.common.validators import Validator, any_, frozendict_of, option, tuple_of from ibis.util import DotDict EMPTY = inspect.Parameter.empty # marker for missing argument @@ -70,8 +70,15 @@ def initialize(self, this): class Argument(Annotation): """Base class for all fields which should be passed as arguments.""" + __slots__ = ('_kind',) + + def __init__(self, validator=None, default=EMPTY, kind=POSITIONAL_OR_KEYWORD): + self._kind = kind + self._default = default + self._validator = validator + @classmethod - def mandatory(cls, validator=None): + def required(cls, validator=None): """Annotation to mark a mandatory argument.""" return cls(validator) @@ -89,6 +96,17 @@ def optional(cls, validator=None, default=None): validator = option(validator, default=default) return cls(validator, default=None) + @classmethod + def varargs(cls, validator=None): + """Annotation to mark a variable length positional argument.""" + validator = None if validator is None else tuple_of(validator) + return cls(validator, kind=VAR_POSITIONAL) + + @classmethod + def varkwds(cls, validator=None): + validator = None if validator is None else frozendict_of(any_, validator) + return cls(validator, kind=VAR_KEYWORD) + class Parameter(inspect.Parameter): """Augmented Parameter class to additionally hold a validator object.""" @@ -102,7 +120,7 @@ def __init__(self, name, annotation): ) super().__init__( name, - kind=POSITIONAL_OR_KEYWORD, + kind=annotation._kind, default=annotation._default, annotation=annotation._validator, ) @@ -150,22 +168,34 @@ def merge(cls, *signatures, **annotations): # mandatory fields without default values must preceed the optional # ones in the function signature, the partial ordering will be kept + var_args, var_kwargs = [], [] new_args, new_kwargs = [], [] - inherited_args, inherited_kwargs = [], [] + old_args, old_kwargs = [], [] for name, param in params.items(): - if name in inherited: + if param.kind == VAR_POSITIONAL: + var_args.append(param) + elif param.kind == VAR_KEYWORD: + var_kwargs.append(param) + elif name in inherited: if param.default is EMPTY: - inherited_args.append(param) + old_args.append(param) else: - inherited_kwargs.append(param) + old_kwargs.append(param) else: if param.default is EMPTY: new_args.append(param) else: new_kwargs.append(param) - return cls(inherited_args + new_args + new_kwargs + inherited_kwargs) + if len(var_args) > 1: + raise TypeError('only one variadic positional *args parameter is allowed') + if len(var_kwargs) > 1: + raise TypeError('only one variadic keywords **kwargs parameter is allowed') + + return cls( + old_args + new_args + var_args + new_kwargs + old_kwargs + var_kwargs + ) @classmethod def from_callable(cls, fn, validators=None, return_validator=None): @@ -199,25 +229,24 @@ def from_callable(cls, fn, validators=None, return_validator=None): parameters = [] for param in sig.parameters.values(): - if param.kind in { - VAR_POSITIONAL, - VAR_KEYWORD, - POSITIONAL_ONLY, - KEYWORD_ONLY, - }: + if param.kind in {POSITIONAL_ONLY, KEYWORD_ONLY}: raise TypeError(f"unsupported parameter kind {param.kind} in {fn}") if param.name in validators: validator = validators[param.name] - elif param.annotation is EMPTY: - validator = any_ - else: + elif param.annotation is not EMPTY: validator = Validator.from_annotation( param.annotation, module=fn.__module__ ) - - if param.default is EMPTY: - annot = Argument.mandatory(validator) + else: + validator = None + + if param.kind is VAR_POSITIONAL: + annot = Argument.varargs(validator) + elif param.kind is VAR_KEYWORD: + annot = Argument.varkwds(validator) + elif param.default is EMPTY: + annot = Argument.required(validator) else: annot = Argument.default(param.default, validator) @@ -250,7 +279,18 @@ def unbind(self, this: Any): Tuple of positional and keyword arguments. """ # does the reverse of bind, but doesn't apply defaults - return {name: getattr(this, name) for name in self.parameters} + args, kwargs = [], {} + for name, param in self.parameters.items(): + value = getattr(this, name) + if param.kind is POSITIONAL_OR_KEYWORD: + args.append(value) + elif param.kind is VAR_POSITIONAL: + args.extend(value) + elif param.kind is VAR_KEYWORD: + kwargs.update(value) + else: + raise TypeError(f"unsupported parameter kind {param.kind}") + return tuple(args), kwargs def validate(self, *args, **kwargs): """Validate the arguments against the signature. @@ -278,7 +318,16 @@ def validate(self, *args, **kwargs): param = self.parameters[name] # TODO(kszucs): provide more error context on failure this[name] = param.validate(value, this=this) + return this + def validate_nobind(self, **kwargs): + """Validate the arguments against the signature without binding.""" + this = DotDict() + for name, param in self.parameters.items(): + value = kwargs.get(name, param.default) + if value is EMPTY: + raise TypeError(f"missing required argument `{name!r}`") + this[name] = param.validate(value, this=kwargs) return this def validate_return(self, value): @@ -303,8 +352,10 @@ def validate_return(self, value): # aliases for convenience attribute = Attribute argument = Argument -mandatory = Argument.mandatory +required = Argument.required optional = Argument.optional +varargs = Argument.varargs +varkwds = Argument.varkwds default = Argument.default @@ -384,9 +435,10 @@ def annotated(_1=None, _2=None, _3=None, **kwargs): @functools.wraps(func) def wrapped(*args, **kwargs): - kwargs = sig.validate(*args, **kwargs) - result = sig.validate_return(func(**kwargs)) - return result + values = sig.validate(*args, **kwargs) + args, kwargs = sig.unbind(values) + result = func(*args, **kwargs) + return sig.validate_return(result) wrapped.__signature__ = sig diff --git a/ibis/common/grounds.py b/ibis/common/grounds.py index fc2ab72740cf..0bac85e6ab49 100644 --- a/ibis/common/grounds.py +++ b/ibis/common/grounds.py @@ -54,14 +54,14 @@ def __new__(metacls, clsname, bases, dct, **kwargs): if name in dct: dct[name] = Argument.default(dct[name], validator) else: - dct[name] = Argument.mandatory(validator) + dct[name] = Argument.required(validator) # collect the newly defined annotations slots = list(dct.pop('__slots__', [])) namespace, arguments = {}, {} for name, attrib in dct.items(): if isinstance(attrib, Validator): - attrib = Argument.mandatory(attrib) + attrib = Argument.required(attrib) if isinstance(attrib, Argument): arguments[name] = attrib @@ -96,6 +96,12 @@ def __create__(cls, *args, **kwargs) -> Annotable: kwargs = cls.__signature__.validate(*args, **kwargs) return super().__create__(**kwargs) + @classmethod + def __recreate__(cls, kwargs) -> Annotable: + # bypass signature binding by requiring keyword arguments only + kwargs = cls.__signature__.validate_nobind(**kwargs) + return super().__create__(**kwargs) + def __init__(self, **kwargs) -> None: # set the already validated arguments for name, value in kwargs.items(): @@ -221,7 +227,8 @@ def __precomputed_hash__(self): def __reduce__(self): # assuming immutability and idempotency of the __init__ method, we can # reconstruct the instance from the arguments without additional attributes - return (self.__class__, self.__args__) + state = dict(zip(self.__argnames__, self.__args__)) + return (self.__recreate__, (state,)) def __hash__(self): return self.__precomputed_hash__ @@ -240,4 +247,4 @@ def argnames(self): def copy(self, **overrides): kwargs = dict(zip(self.__argnames__, self.__args__)) kwargs.update(overrides) - return self.__class__(**kwargs) + return self.__recreate__(kwargs) diff --git a/ibis/common/tests/test_annotations.py b/ibis/common/tests/test_annotations.py index 9db79c0baf37..85c5892bf7ed 100644 --- a/ibis/common/tests/test_annotations.py +++ b/ibis/common/tests/test_annotations.py @@ -77,7 +77,7 @@ def test_parameter(): def fn(x, this): return int(x) + this['other'] - annot = Argument.mandatory(fn) + annot = Argument.required(fn) p = Parameter('test', annotation=annot) assert p.annotation is fn @@ -104,8 +104,8 @@ def to_int(x, this): def add_other(x, this): return int(x) + this['other'] - other = Parameter('other', annotation=Argument.mandatory(to_int)) - this = Parameter('this', annotation=Argument.mandatory(add_other)) + other = Parameter('other', annotation=Argument.required(to_int)) + this = Parameter('this', annotation=Argument.required(add_other)) sig = Signature(parameters=[other, this]) assert sig.validate(1, 2) == {'other': 1, 'this': 3} @@ -124,19 +124,20 @@ def test(a: int, b: int, c: int = 1): sig.validate(2, 3, "4") -def test_signature_from_callable_unsupported_argument_kinds(): - def test(a: int, b: int, *args): - pass +def test_signature_from_callable_with_varargs(): + def test(a: int, b: int, *args: int): + return a + b + sum(args) - with pytest.raises(TypeError, match="unsupported parameter kind VAR_POSITIONAL"): - Signature.from_callable(test) + sig = Signature.from_callable(test) + assert sig.validate(2, 3) == {'a': 2, 'b': 3, 'args': ()} + assert sig.validate(2, 3, 4) == {'a': 2, 'b': 3, 'args': (4,)} + assert sig.validate(2, 3, 4, 5) == {'a': 2, 'b': 3, 'args': (4, 5)} - def test(a: int, b: int, **kwargs): - pass + with pytest.raises(TypeError): + sig.validate(2, 3, 4, "5") - with pytest.raises(TypeError, match="unsupported parameter kind VAR_KEYWORD"): - Signature.from_callable(test) +def test_signature_from_callable_unsupported_argument_kinds(): def test(a: int, b: int, *, c: int): pass @@ -157,14 +158,15 @@ def to_int(x, this): def add_other(x, this): return int(x) + this['other'] - other = Parameter('other', annotation=Argument.mandatory(to_int)) - this = Parameter('this', annotation=Argument.mandatory(add_other)) + other = Parameter('other', annotation=Argument.required(to_int)) + this = Parameter('this', annotation=Argument.required(add_other)) sig = Signature(parameters=[other, this]) params = sig.validate(1, this=2) - kwargs = sig.unbind(params) - assert kwargs == {"other": 1, "this": 3} + args, kwargs = sig.unbind(params) + assert args == (1, 3) + assert kwargs == {} def as_float(x, this): @@ -175,8 +177,8 @@ def as_tuple_of_floats(x, this): return tuple(float(i) for i in x) -a = Parameter('a', annotation=Argument.mandatory(validator=as_float)) -b = Parameter('b', annotation=Argument.mandatory(validator=as_float)) +a = Parameter('a', annotation=Argument.required(validator=as_float)) +b = Parameter('b', annotation=Argument.required(validator=as_float)) c = Parameter('c', annotation=Argument.default(default=0, validator=as_float)) d = Parameter( 'd', annotation=Argument.default(default=tuple(), validator=as_tuple_of_floats) @@ -190,10 +192,11 @@ def test_signature_unbind_with_empty_variadic(d): params = sig.validate(1, 2, 3, d, e=4) assert params == {'a': 1.0, 'b': 2.0, 'c': 3.0, 'd': d, 'e': 4.0} - kwargs = sig.unbind(params) - assert kwargs == {'a': 1.0, 'b': 2.0, 'c': 3.0, 'd': d, 'e': 4.0} + args, kwargs = sig.unbind(params) + assert args == (1.0, 2.0, 3.0, tuple(map(float, d)), 4.0) + assert kwargs == {} - params_again = sig.validate(**kwargs) + params_again = sig.validate(*args, **kwargs) assert params_again == params @@ -333,3 +336,27 @@ def test(a, b, c): func(1, 2) assert func(1, 2, c=3) == 6 + + +def test_annotated_function_with_varargs(): + @annotated + def test(a: float, b: float, *args: int): + return sum((a, b) + args) + + assert test(1.0, 2.0, 3, 4) == 10.0 + assert test(1.0, 2.0, 3, 4, 5) == 15.0 + + with pytest.raises(TypeError): + test(1.0, 2.0, 3, 4, 5, 6.0) + + +def test_annotated_function_with_varkwds(): + @annotated + def test(a: float, b: float, **kwargs: int): + return sum((a, b) + tuple(kwargs.values())) + + assert test(1.0, 2.0, c=3, d=4) == 10.0 + assert test(1.0, 2.0, c=3, d=4, e=5) == 15.0 + + with pytest.raises(TypeError): + test(1.0, 2.0, c=3, d=4, e=5, f=6.0) diff --git a/ibis/common/tests/test_grounds.py b/ibis/common/tests/test_grounds.py index 58286b61bab1..84d0da85a10c 100644 --- a/ibis/common/tests/test_grounds.py +++ b/ibis/common/tests/test_grounds.py @@ -9,8 +9,10 @@ Signature, argument, attribute, - mandatory, optional, + required, + varargs, + varkwds, ) from ibis.common.caching import WeakCache from ibis.common.grounds import ( @@ -68,6 +70,19 @@ def calculated(self): return self.value + self.lower +class VariadicArgs(Concrete): + args = varargs(is_int) + + +class VariadicKeywords(Concrete): + kwargs = varkwds(is_int) + + +class VariadicArgsAndKeywords(Concrete): + args = varargs(is_int) + kwargs = varkwds(is_int) + + def test_annotable(): class InBetween(BetweenSimple): pass @@ -223,22 +238,22 @@ class IntAddClip(FloatAddClip, IntBinop): assert IntBinop.__signature__ == Signature( [ - Parameter('left', annotation=mandatory(is_int)), - Parameter('right', annotation=mandatory(is_int)), + Parameter('left', annotation=required(is_int)), + Parameter('right', annotation=required(is_int)), ] ) assert FloatAddRhs.__signature__ == Signature( [ - Parameter('left', annotation=mandatory(is_int)), - Parameter('right', annotation=mandatory(is_float)), + Parameter('left', annotation=required(is_int)), + Parameter('right', annotation=required(is_float)), ] ) assert FloatAddClip.__signature__ == Signature( [ - Parameter('left', annotation=mandatory(is_float)), - Parameter('right', annotation=mandatory(is_float)), + Parameter('left', annotation=required(is_float)), + Parameter('right', annotation=required(is_float)), Parameter('clip_lower', annotation=optional(is_int, default=0)), Parameter('clip_upper', annotation=optional(is_int, default=10)), ] @@ -246,8 +261,8 @@ class IntAddClip(FloatAddClip, IntBinop): assert IntAddClip.__signature__ == Signature( [ - Parameter('left', annotation=mandatory(is_int)), - Parameter('right', annotation=mandatory(is_int)), + Parameter('left', annotation=required(is_int)), + Parameter('right', annotation=required(is_int)), Parameter('clip_lower', annotation=optional(is_int, default=0)), Parameter('clip_upper', annotation=optional(is_int, default=10)), ] @@ -303,6 +318,124 @@ class Beta(Alpha): assert obj.e == 4 +def test_variadic_argument_reordering(): + class Test(Annotable): + a = is_int + b = is_int + args = varargs(is_int) + + class Test2(Test): + c = is_int + args = varargs(is_int) + + with pytest.raises(TypeError, match="missing a required argument: 'c'"): + Test2(1, 2) + + a = Test2(1, 2, 3) + assert a.a == 1 + assert a.b == 2 + assert a.c == 3 + assert a.args == () + + b = Test2(*range(5)) + assert b.a == 0 + assert b.b == 1 + assert b.c == 2 + assert b.args == (3, 4) + + msg = "only one variadic positional \\*args parameter is allowed" + with pytest.raises(TypeError, match=msg): + + class Test3(Test): + another_args = varargs(is_int) + + +def test_variadic_keyword_argument_reordering(): + class Test(Annotable): + a = is_int + b = is_int + options = varkwds(is_int) + + class Test2(Test): + c = is_int + options = varkwds(is_int) + + with pytest.raises(TypeError, match="missing a required argument: 'c'"): + Test2(1, 2) + + a = Test2(1, 2, c=3) + assert a.a == 1 + assert a.b == 2 + assert a.c == 3 + assert a.options == {} + + b = Test2(1, 2, c=3, d=4, e=5) + assert b.a == 1 + assert b.b == 2 + assert b.c == 3 + assert b.options == {'d': 4, 'e': 5} + + msg = "only one variadic keywords \\*\\*kwargs parameter is allowed" + with pytest.raises(TypeError, match=msg): + + class Test3(Test): + another_options = varkwds(is_int) + + +def test_variadic_argument(): + class Test(Annotable): + a = is_int + b = is_int + args = varargs(is_int) + + assert Test(1, 2).args == () + assert Test(1, 2, 3).args == (3,) + assert Test(1, 2, 3, 4, 5).args == (3, 4, 5) + + +def test_variadic_keyword_argument(): + class Test(Annotable): + first = is_int + second = is_int + options = varkwds(is_int) + + assert Test(1, 2).options == {} + assert Test(1, 2, a=3).options == {'a': 3} + assert Test(1, 2, a=3, b=4, c=5).options == {'a': 3, 'b': 4, 'c': 5} + + +def test_concrete_copy_with_variadic_argument(): + class Test(Annotable): + a = is_int + b = is_int + args = varargs(is_int) + + t = Test(1, 2, 3, 4, 5) + assert t.a == 1 + assert t.b == 2 + assert t.args == (3, 4, 5) + + u = t.copy(a=6, args=(8, 9, 10)) + assert u.a == 6 + assert u.b == 2 + assert u.args == (8, 9, 10) + + +def test_concrete_pickling_variadic_arguments(): + v = VariadicArgs(1, 2, 3, 4, 5) + assert v.args == (1, 2, 3, 4, 5) + assert_pickle_roundtrip(v) + + v = VariadicKeywords(a=3, b=4, c=5) + assert v.kwargs == {'a': 3, 'b': 4, 'c': 5} + assert_pickle_roundtrip(v) + + v = VariadicArgsAndKeywords(1, 2, 3, 4, 5, a=3, b=4, c=5) + assert v.args == (1, 2, 3, 4, 5) + assert v.kwargs == {'a': 3, 'b': 4, 'c': 5} + assert_pickle_roundtrip(v) + + def test_dont_copy_default_argument(): default = tuple() @@ -559,7 +692,7 @@ class Value(Annotable): class Reduction(Value): output_shape = "scalar" - class variadic(Value): + class Variadic(Value): @attribute.default def output_shape(self): if self.arg > 10: @@ -571,11 +704,11 @@ def output_shape(self): assert r.output_shape == "scalar" assert "output_shape" not in r.__slots__ - v = variadic(1) + v = Variadic(1) assert v.output_shape == "scalar" assert "output_shape" in v.__slots__ - v = variadic(100) + v = Variadic(100) assert v.output_shape == "columnar" assert "output_shape" in v.__slots__ diff --git a/ibis/common/tests/test_validators.py b/ibis/common/tests/test_validators.py index 0be6ed7af611..f69241ddc746 100644 --- a/ibis/common/tests/test_validators.py +++ b/ibis/common/tests/test_validators.py @@ -181,7 +181,10 @@ def test_callable_with(): def func(a, b): return str(a) + b - def func_with_kwargs(a, b, c=1): + def func_with_args(a, b, *args): + return sum((a, b) + args) + + def func_with_kwargs(a, b, c=1, **kwargs): return str(a) + b + str(c) def func_with_mandatory_kwargs(*, c): @@ -201,6 +204,9 @@ def func_with_mandatory_kwargs(*, c): with pytest.raises(TypeError, match=msg): callable_with([instance_of(int)] * 4, instance_of(str), func_with_kwargs) + wrapped = callable_with([instance_of(int)] * 4, instance_of(int), func_with_args) + assert wrapped(1, 2, 3, 4) == 10 + wrapped = callable_with( [instance_of(int), instance_of(str)], instance_of(str), func )