Skip to content

Commit

Permalink
feat(common): add support for variadic positional and variadic keywor…
Browse files Browse the repository at this point in the history
…d annotations
  • Loading branch information
kszucs authored and cpcloud committed Feb 20, 2023
1 parent b368b04 commit baea1fa
Show file tree
Hide file tree
Showing 5 changed files with 288 additions and 63 deletions.
102 changes: 77 additions & 25 deletions ibis/common/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import inspect
from typing import Any

from ibis.common.validators import Validator, any_, option
from ibis.common.validators import Validator, any_, frozendict_of, option, tuple_of
from ibis.util import DotDict

EMPTY = inspect.Parameter.empty # marker for missing argument
Expand Down Expand Up @@ -70,8 +70,15 @@ def initialize(self, this):
class Argument(Annotation):
"""Base class for all fields which should be passed as arguments."""

__slots__ = ('_kind',)

def __init__(self, validator=None, default=EMPTY, kind=POSITIONAL_OR_KEYWORD):
self._kind = kind
self._default = default
self._validator = validator

@classmethod
def mandatory(cls, validator=None):
def required(cls, validator=None):
"""Annotation to mark a mandatory argument."""
return cls(validator)

Expand All @@ -89,6 +96,17 @@ def optional(cls, validator=None, default=None):
validator = option(validator, default=default)
return cls(validator, default=None)

@classmethod
def varargs(cls, validator=None):
"""Annotation to mark a variable length positional argument."""
validator = None if validator is None else tuple_of(validator)
return cls(validator, kind=VAR_POSITIONAL)

@classmethod
def varkwds(cls, validator=None):
validator = None if validator is None else frozendict_of(any_, validator)
return cls(validator, kind=VAR_KEYWORD)


class Parameter(inspect.Parameter):
"""Augmented Parameter class to additionally hold a validator object."""
Expand All @@ -102,7 +120,7 @@ def __init__(self, name, annotation):
)
super().__init__(
name,
kind=POSITIONAL_OR_KEYWORD,
kind=annotation._kind,
default=annotation._default,
annotation=annotation._validator,
)
Expand Down Expand Up @@ -150,22 +168,34 @@ def merge(cls, *signatures, **annotations):

# mandatory fields without default values must preceed the optional
# ones in the function signature, the partial ordering will be kept
var_args, var_kwargs = [], []
new_args, new_kwargs = [], []
inherited_args, inherited_kwargs = [], []
old_args, old_kwargs = [], []

for name, param in params.items():
if name in inherited:
if param.kind == VAR_POSITIONAL:
var_args.append(param)
elif param.kind == VAR_KEYWORD:
var_kwargs.append(param)
elif name in inherited:
if param.default is EMPTY:
inherited_args.append(param)
old_args.append(param)
else:
inherited_kwargs.append(param)
old_kwargs.append(param)
else:
if param.default is EMPTY:
new_args.append(param)
else:
new_kwargs.append(param)

return cls(inherited_args + new_args + new_kwargs + inherited_kwargs)
if len(var_args) > 1:
raise TypeError('only one variadic positional *args parameter is allowed')
if len(var_kwargs) > 1:
raise TypeError('only one variadic keywords **kwargs parameter is allowed')

return cls(
old_args + new_args + var_args + new_kwargs + old_kwargs + var_kwargs
)

@classmethod
def from_callable(cls, fn, validators=None, return_validator=None):
Expand Down Expand Up @@ -199,25 +229,24 @@ def from_callable(cls, fn, validators=None, return_validator=None):

parameters = []
for param in sig.parameters.values():
if param.kind in {
VAR_POSITIONAL,
VAR_KEYWORD,
POSITIONAL_ONLY,
KEYWORD_ONLY,
}:
if param.kind in {POSITIONAL_ONLY, KEYWORD_ONLY}:
raise TypeError(f"unsupported parameter kind {param.kind} in {fn}")

if param.name in validators:
validator = validators[param.name]
elif param.annotation is EMPTY:
validator = any_
else:
elif param.annotation is not EMPTY:
validator = Validator.from_annotation(
param.annotation, module=fn.__module__
)

if param.default is EMPTY:
annot = Argument.mandatory(validator)
else:
validator = None

if param.kind is VAR_POSITIONAL:
annot = Argument.varargs(validator)
elif param.kind is VAR_KEYWORD:
annot = Argument.varkwds(validator)
elif param.default is EMPTY:
annot = Argument.required(validator)
else:
annot = Argument.default(param.default, validator)

Expand Down Expand Up @@ -250,7 +279,18 @@ def unbind(self, this: Any):
Tuple of positional and keyword arguments.
"""
# does the reverse of bind, but doesn't apply defaults
return {name: getattr(this, name) for name in self.parameters}
args, kwargs = [], {}
for name, param in self.parameters.items():
value = getattr(this, name)
if param.kind is POSITIONAL_OR_KEYWORD:
args.append(value)
elif param.kind is VAR_POSITIONAL:
args.extend(value)
elif param.kind is VAR_KEYWORD:
kwargs.update(value)
else:
raise TypeError(f"unsupported parameter kind {param.kind}")
return tuple(args), kwargs

def validate(self, *args, **kwargs):
"""Validate the arguments against the signature.
Expand Down Expand Up @@ -278,7 +318,16 @@ def validate(self, *args, **kwargs):
param = self.parameters[name]
# TODO(kszucs): provide more error context on failure
this[name] = param.validate(value, this=this)
return this

def validate_nobind(self, **kwargs):
"""Validate the arguments against the signature without binding."""
this = DotDict()
for name, param in self.parameters.items():
value = kwargs.get(name, param.default)
if value is EMPTY:
raise TypeError(f"missing required argument `{name!r}`")
this[name] = param.validate(value, this=kwargs)
return this

def validate_return(self, value):
Expand All @@ -303,8 +352,10 @@ def validate_return(self, value):
# aliases for convenience
attribute = Attribute
argument = Argument
mandatory = Argument.mandatory
required = Argument.required
optional = Argument.optional
varargs = Argument.varargs
varkwds = Argument.varkwds
default = Argument.default


Expand Down Expand Up @@ -384,9 +435,10 @@ def annotated(_1=None, _2=None, _3=None, **kwargs):

@functools.wraps(func)
def wrapped(*args, **kwargs):
kwargs = sig.validate(*args, **kwargs)
result = sig.validate_return(func(**kwargs))
return result
values = sig.validate(*args, **kwargs)
args, kwargs = sig.unbind(values)
result = func(*args, **kwargs)
return sig.validate_return(result)

wrapped.__signature__ = sig

Expand Down
15 changes: 11 additions & 4 deletions ibis/common/grounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,14 @@ def __new__(metacls, clsname, bases, dct, **kwargs):
if name in dct:
dct[name] = Argument.default(dct[name], validator)
else:
dct[name] = Argument.mandatory(validator)
dct[name] = Argument.required(validator)

# collect the newly defined annotations
slots = list(dct.pop('__slots__', []))
namespace, arguments = {}, {}
for name, attrib in dct.items():
if isinstance(attrib, Validator):
attrib = Argument.mandatory(attrib)
attrib = Argument.required(attrib)

if isinstance(attrib, Argument):
arguments[name] = attrib
Expand Down Expand Up @@ -96,6 +96,12 @@ def __create__(cls, *args, **kwargs) -> Annotable:
kwargs = cls.__signature__.validate(*args, **kwargs)
return super().__create__(**kwargs)

@classmethod
def __recreate__(cls, kwargs) -> Annotable:
# bypass signature binding by requiring keyword arguments only
kwargs = cls.__signature__.validate_nobind(**kwargs)
return super().__create__(**kwargs)

def __init__(self, **kwargs) -> None:
# set the already validated arguments
for name, value in kwargs.items():
Expand Down Expand Up @@ -221,7 +227,8 @@ def __precomputed_hash__(self):
def __reduce__(self):
# assuming immutability and idempotency of the __init__ method, we can
# reconstruct the instance from the arguments without additional attributes
return (self.__class__, self.__args__)
state = dict(zip(self.__argnames__, self.__args__))
return (self.__recreate__, (state,))

def __hash__(self):
return self.__precomputed_hash__
Expand All @@ -240,4 +247,4 @@ def argnames(self):
def copy(self, **overrides):
kwargs = dict(zip(self.__argnames__, self.__args__))
kwargs.update(overrides)
return self.__class__(**kwargs)
return self.__recreate__(kwargs)
69 changes: 48 additions & 21 deletions ibis/common/tests/test_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def test_parameter():
def fn(x, this):
return int(x) + this['other']

annot = Argument.mandatory(fn)
annot = Argument.required(fn)
p = Parameter('test', annotation=annot)

assert p.annotation is fn
Expand All @@ -104,8 +104,8 @@ def to_int(x, this):
def add_other(x, this):
return int(x) + this['other']

other = Parameter('other', annotation=Argument.mandatory(to_int))
this = Parameter('this', annotation=Argument.mandatory(add_other))
other = Parameter('other', annotation=Argument.required(to_int))
this = Parameter('this', annotation=Argument.required(add_other))

sig = Signature(parameters=[other, this])
assert sig.validate(1, 2) == {'other': 1, 'this': 3}
Expand All @@ -124,19 +124,20 @@ def test(a: int, b: int, c: int = 1):
sig.validate(2, 3, "4")


def test_signature_from_callable_unsupported_argument_kinds():
def test(a: int, b: int, *args):
pass
def test_signature_from_callable_with_varargs():
def test(a: int, b: int, *args: int):
return a + b + sum(args)

with pytest.raises(TypeError, match="unsupported parameter kind VAR_POSITIONAL"):
Signature.from_callable(test)
sig = Signature.from_callable(test)
assert sig.validate(2, 3) == {'a': 2, 'b': 3, 'args': ()}
assert sig.validate(2, 3, 4) == {'a': 2, 'b': 3, 'args': (4,)}
assert sig.validate(2, 3, 4, 5) == {'a': 2, 'b': 3, 'args': (4, 5)}

def test(a: int, b: int, **kwargs):
pass
with pytest.raises(TypeError):
sig.validate(2, 3, 4, "5")

with pytest.raises(TypeError, match="unsupported parameter kind VAR_KEYWORD"):
Signature.from_callable(test)

def test_signature_from_callable_unsupported_argument_kinds():
def test(a: int, b: int, *, c: int):
pass

Expand All @@ -157,14 +158,15 @@ def to_int(x, this):
def add_other(x, this):
return int(x) + this['other']

other = Parameter('other', annotation=Argument.mandatory(to_int))
this = Parameter('this', annotation=Argument.mandatory(add_other))
other = Parameter('other', annotation=Argument.required(to_int))
this = Parameter('this', annotation=Argument.required(add_other))

sig = Signature(parameters=[other, this])
params = sig.validate(1, this=2)

kwargs = sig.unbind(params)
assert kwargs == {"other": 1, "this": 3}
args, kwargs = sig.unbind(params)
assert args == (1, 3)
assert kwargs == {}


def as_float(x, this):
Expand All @@ -175,8 +177,8 @@ def as_tuple_of_floats(x, this):
return tuple(float(i) for i in x)


a = Parameter('a', annotation=Argument.mandatory(validator=as_float))
b = Parameter('b', annotation=Argument.mandatory(validator=as_float))
a = Parameter('a', annotation=Argument.required(validator=as_float))
b = Parameter('b', annotation=Argument.required(validator=as_float))
c = Parameter('c', annotation=Argument.default(default=0, validator=as_float))
d = Parameter(
'd', annotation=Argument.default(default=tuple(), validator=as_tuple_of_floats)
Expand All @@ -190,10 +192,11 @@ def test_signature_unbind_with_empty_variadic(d):
params = sig.validate(1, 2, 3, d, e=4)
assert params == {'a': 1.0, 'b': 2.0, 'c': 3.0, 'd': d, 'e': 4.0}

kwargs = sig.unbind(params)
assert kwargs == {'a': 1.0, 'b': 2.0, 'c': 3.0, 'd': d, 'e': 4.0}
args, kwargs = sig.unbind(params)
assert args == (1.0, 2.0, 3.0, tuple(map(float, d)), 4.0)
assert kwargs == {}

params_again = sig.validate(**kwargs)
params_again = sig.validate(*args, **kwargs)
assert params_again == params


Expand Down Expand Up @@ -333,3 +336,27 @@ def test(a, b, c):
func(1, 2)

assert func(1, 2, c=3) == 6


def test_annotated_function_with_varargs():
@annotated
def test(a: float, b: float, *args: int):
return sum((a, b) + args)

assert test(1.0, 2.0, 3, 4) == 10.0
assert test(1.0, 2.0, 3, 4, 5) == 15.0

with pytest.raises(TypeError):
test(1.0, 2.0, 3, 4, 5, 6.0)


def test_annotated_function_with_varkwds():
@annotated
def test(a: float, b: float, **kwargs: int):
return sum((a, b) + tuple(kwargs.values()))

assert test(1.0, 2.0, c=3, d=4) == 10.0
assert test(1.0, 2.0, c=3, d=4, e=5) == 15.0

with pytest.raises(TypeError):
test(1.0, 2.0, c=3, d=4, e=5, f=6.0)
Loading

0 comments on commit baea1fa

Please sign in to comment.