Skip to content

Commit

Permalink
feat(common): support positional only and keyword only arguments in a…
Browse files Browse the repository at this point in the history
…nnotations
  • Loading branch information
kszucs authored and cpcloud committed Feb 20, 2023
1 parent baea1fa commit 340dca1
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 29 deletions.
36 changes: 20 additions & 16 deletions ibis/common/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,23 +78,23 @@ def __init__(self, validator=None, default=EMPTY, kind=POSITIONAL_OR_KEYWORD):
self._validator = validator

@classmethod
def required(cls, validator=None):
def required(cls, validator=None, kind=POSITIONAL_OR_KEYWORD):
"""Annotation to mark a mandatory argument."""
return cls(validator)
return cls(validator=validator, kind=kind)

@classmethod
def default(cls, default, validator=None):
def default(cls, default, validator=None, kind=POSITIONAL_OR_KEYWORD):
"""Annotation to allow missing arguments with a default value."""
return cls(validator, default=default)
return cls(validator=validator, default=default, kind=kind)

@classmethod
def optional(cls, validator=None, default=None):
def optional(cls, validator=None, default=None, kind=POSITIONAL_OR_KEYWORD):
"""Annotation to allow and treat `None` values as missing arguments."""
if validator is None:
validator = option(any_, default=default)
else:
validator = option(validator, default=default)
return cls(validator, default=None)
return cls(validator=validator, default=None, kind=kind)

@classmethod
def varargs(cls, validator=None):
Expand Down Expand Up @@ -174,8 +174,12 @@ def merge(cls, *signatures, **annotations):

for name, param in params.items():
if param.kind == VAR_POSITIONAL:
if var_args:
raise TypeError('only one variadic *args parameter is allowed')
var_args.append(param)
elif param.kind == VAR_KEYWORD:
if var_kwargs:
raise TypeError('only one variadic **kwargs parameter is allowed')
var_kwargs.append(param)
elif name in inherited:
if param.default is EMPTY:
Expand All @@ -188,11 +192,6 @@ def merge(cls, *signatures, **annotations):
else:
new_kwargs.append(param)

if len(var_args) > 1:
raise TypeError('only one variadic positional *args parameter is allowed')
if len(var_kwargs) > 1:
raise TypeError('only one variadic keywords **kwargs parameter is allowed')

return cls(
old_args + new_args + var_args + new_kwargs + old_kwargs + var_kwargs
)
Expand Down Expand Up @@ -229,9 +228,6 @@ def from_callable(cls, fn, validators=None, return_validator=None):

parameters = []
for param in sig.parameters.values():
if param.kind in {POSITIONAL_ONLY, KEYWORD_ONLY}:
raise TypeError(f"unsupported parameter kind {param.kind} in {fn}")

if param.name in validators:
validator = validators[param.name]
elif param.annotation is not EMPTY:
Expand All @@ -246,9 +242,9 @@ def from_callable(cls, fn, validators=None, return_validator=None):
elif param.kind is VAR_KEYWORD:
annot = Argument.varkwds(validator)
elif param.default is EMPTY:
annot = Argument.required(validator)
annot = Argument.required(validator, kind=param.kind)
else:
annot = Argument.default(param.default, validator)
annot = Argument.default(param.default, validator, kind=param.kind)

parameters.append(Parameter(param.name, annot))

Expand Down Expand Up @@ -288,6 +284,10 @@ def unbind(self, this: Any):
args.extend(value)
elif param.kind is VAR_KEYWORD:
kwargs.update(value)
elif param.kind is KEYWORD_ONLY:
kwargs[name] = value
elif param.kind is POSITIONAL_ONLY:
args.append(value)
else:
raise TypeError(f"unsupported parameter kind {param.kind}")
return tuple(args), kwargs
Expand Down Expand Up @@ -435,9 +435,13 @@ def annotated(_1=None, _2=None, _3=None, **kwargs):

@functools.wraps(func)
def wrapped(*args, **kwargs):
# 1. Validate the passed arguments
values = sig.validate(*args, **kwargs)
# 2. Reconstruction of the original arguments
args, kwargs = sig.unbind(values)
# 3. Call the function with the validated arguments
result = func(*args, **kwargs)
# 4. Validate the return value
return sig.validate_return(result)

wrapped.__signature__ = sig
Expand Down
48 changes: 39 additions & 9 deletions ibis/common/tests/test_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,10 @@ def test(a: int, b: int, c: int = 1):
with pytest.raises(TypeError):
sig.validate(2, 3, "4")

args, kwargs = sig.unbind(sig.validate(2, 3))
assert args == (2, 3, 1)
assert kwargs == {}


def test_signature_from_callable_with_varargs():
def test(a: int, b: int, *args: int):
Expand All @@ -136,19 +140,45 @@ def test(a: int, b: int, *args: int):
with pytest.raises(TypeError):
sig.validate(2, 3, 4, "5")

args, kwargs = sig.unbind(sig.validate(2, 3, 4, 5))
assert args == (2, 3, 4, 5)
assert kwargs == {}

def test_signature_from_callable_unsupported_argument_kinds():
def test(a: int, b: int, *, c: int):
pass

with pytest.raises(TypeError, match="unsupported parameter kind KEYWORD_ONLY"):
Signature.from_callable(test)

def test_signature_from_callable_with_positional_only_arguments():
def test(a: int, b: int, /, c: int = 1):
pass
return a + b + c

sig = Signature.from_callable(test)
assert sig.validate(2, 3) == {'a': 2, 'b': 3, 'c': 1}
assert sig.validate(2, 3, 4) == {'a': 2, 'b': 3, 'c': 4}
assert sig.validate(2, 3, c=4) == {'a': 2, 'b': 3, 'c': 4}

msg = "'b' parameter is positional only, but was passed as a keyword"
with pytest.raises(TypeError, match=msg):
sig.validate(1, b=2)

args, kwargs = sig.unbind(sig.validate(2, 3))
assert args == (2, 3, 1)
assert kwargs == {}


def test_signature_from_callable_with_keyword_only_arguments():
def test(a: int, b: int, *, c: float, d: float = 0.0):
return a + b + c

sig = Signature.from_callable(test)
assert sig.validate(2, 3, c=4.0) == {'a': 2, 'b': 3, 'c': 4.0, 'd': 0.0}
assert sig.validate(2, 3, c=4.0, d=5.0) == {'a': 2, 'b': 3, 'c': 4.0, 'd': 5.0}

with pytest.raises(TypeError, match="missing a required argument: 'c'"):
sig.validate(2, 3)
with pytest.raises(TypeError, match="too many positional arguments"):
sig.validate(2, 3, 4)

with pytest.raises(TypeError, match="unsupported parameter kind POSITIONAL_ONLY"):
Signature.from_callable(test)
args, kwargs = sig.unbind(sig.validate(2, 3, c=4.0))
assert args == (2, 3)
assert kwargs == {'c': 4.0, 'd': 0.0}


def test_signature_unbind():
Expand Down
4 changes: 2 additions & 2 deletions ibis/common/tests/test_grounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ class Test2(Test):
assert b.c == 2
assert b.args == (3, 4)

msg = "only one variadic positional \\*args parameter is allowed"
msg = "only one variadic \\*args parameter is allowed"
with pytest.raises(TypeError, match=msg):

class Test3(Test):
Expand Down Expand Up @@ -375,7 +375,7 @@ class Test2(Test):
assert b.c == 3
assert b.options == {'d': 4, 'e': 5}

msg = "only one variadic keywords \\*\\*kwargs parameter is allowed"
msg = "only one variadic \\*\\*kwargs parameter is allowed"
with pytest.raises(TypeError, match=msg):

class Test3(Test):
Expand Down
6 changes: 4 additions & 2 deletions ibis/common/tests/test_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,10 +190,12 @@ def func_with_kwargs(a, b, c=1, **kwargs):
def func_with_mandatory_kwargs(*, c):
return c

with pytest.raises(TypeError, match="Argument must be a callable"):
msg = "Argument must be a callable"
with pytest.raises(TypeError, match=msg):
callable_with([instance_of(int), instance_of(str)], 10, "string")

with pytest.raises(TypeError, match="unsupported parameter kind KEYWORD_ONLY"):
msg = "Callable has mandatory keyword-only arguments which cannot be specified"
with pytest.raises(TypeError, match=msg):
callable_with([instance_of(int)], instance_of(str), func_with_mandatory_kwargs)

msg = "Callable has more positional arguments than expected"
Expand Down

0 comments on commit 340dca1

Please sign in to comment.