Skip to content

Commit

Permalink
feat(common): support Callable arguments and return types in `Valid…
Browse files Browse the repository at this point in the history
…ator.from_annotable()`
  • Loading branch information
kszucs authored and cpcloud committed Feb 17, 2023
1 parent 560474e commit ae57c36
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 9 deletions.
7 changes: 6 additions & 1 deletion ibis/common/tests/test_validators.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import sys
from typing import Dict, List, Literal, Optional, Tuple, Union
from typing import Callable, Dict, List, Literal, Optional, Tuple, Union

import pytest
from typing_extensions import Annotated
Expand Down Expand Up @@ -110,6 +110,11 @@ def endswith_d(x, this):
(Dict[str, float], dict_of(instance_of(str), instance_of(float))),
(frozendict[str, int], frozendict_of(instance_of(str), instance_of(int))),
(Literal["alpha", "beta", "gamma"], isin(("alpha", "beta", "gamma"))),
(
Callable[[str, int], str],
callable_with((instance_of(str), instance_of(int)), instance_of(str)),
),
(Callable, instance_of(Callable)),
],
)
def test_validator_from_annotation(annot, expected):
Expand Down
20 changes: 12 additions & 8 deletions ibis/common/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,15 @@ def from_annotation(cls, annot, module=None):
(inner,) = map(cls.from_annotation, get_args(annot))
return sequence_of(inner, type=origin_type)
elif issubclass(origin_type, Mapping):
key_type, value_type = map(cls.from_annotation, get_args(annot))
return mapping_of(key_type, value_type, type=origin_type)
key_inner, value_inner = map(cls.from_annotation, get_args(annot))
return mapping_of(key_inner, value_inner, type=origin_type)
elif issubclass(origin_type, Callable):
# TODO(kszucs): add a more comprehensive callable_with rule here
return instance_of(Callable)
if args := get_args(annot):
arg_inners = map(cls.from_annotation, args[0])
return_inner = cls.from_annotation(args[1])
return callable_with(tuple(arg_inners), return_inner)
else:
return instance_of(Callable)
else:
raise NotImplementedError(
f"Cannot create validator from annotation {annot} {origin_type}"
Expand Down Expand Up @@ -264,13 +268,13 @@ def mapping_of(key_inner, value_inner, arg, *, type, **kwargs):


@validator
def callable_with(args_inner, return_inner, value, **kwargs):
def callable_with(arg_inners, return_inner, value, **kwargs):
from ibis.common.annotations import annotated

if not callable(value):
raise IbisTypeError("Argument must be a callable")

fn = annotated(args_inner, return_inner, value)
fn = annotated(arg_inners, return_inner, value)

has_varargs = False
positional, keyword_only = [], []
Expand All @@ -286,9 +290,9 @@ def callable_with(args_inner, return_inner, value, **kwargs):
raise IbisTypeError(
"Callable has mandatory keyword-only arguments which cannot be specified"
)
elif len(positional) > len(args_inner):
elif len(positional) > len(arg_inners):
raise IbisTypeError("Callable has more positional arguments than expected")
elif len(positional) < len(args_inner) and not has_varargs:
elif len(positional) < len(arg_inners) and not has_varargs:
raise IbisTypeError("Callable has less positional arguments than expected")
else:
return fn
Expand Down

0 comments on commit ae57c36

Please sign in to comment.