From f95613a5a462117385679857a887d2aa363cd011 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Sat, 19 Aug 2023 17:49:22 +0200 Subject: [PATCH] refactor(common): improve error messages raised during validation --- ibis/backends/tests/test_generic.py | 2 +- ibis/common/annotations.py | 217 +++++++++++++++--- ibis/common/grounds.py | 24 +- ibis/common/patterns.py | 197 ++++++++++------ .../error.txt | 3 + .../missing_a_required_argument.txt | 3 + .../too_many_positional_arguments.txt | 3 + .../parameter_is_positional_only.txt | 3 + .../test_error_message/error_message.txt | 8 + .../error_message_py311.txt | 8 + ibis/common/tests/test_annotations.py | 143 +++++++----- ibis/common/tests/test_grounds.py | 29 ++- ibis/common/tests/test_patterns.py | 103 +++++++-- ibis/common/typing.py | 14 ++ ibis/expr/datatypes/core.py | 9 +- ibis/expr/datatypes/tests/test_core.py | 6 - ibis/expr/operations/core.py | 2 +- .../missing_a_required_argument.txt | 3 + .../too_many_positional_arguments.txt | 3 + .../got_an_unexpected_keyword.txt | 3 + .../multiple_values_for_argument.txt | 3 + .../call4-invalid_dtype/invalid_dtype.txt | 4 + .../unable_to_normalize.txt | 1 + ibis/expr/operations/tests/test_core.py | 4 +- ibis/expr/operations/tests/test_generic.py | 21 ++ ibis/tests/expr/test_table.py | 2 +- ibis/tests/expr/test_window_frames.py | 4 +- 27 files changed, 626 insertions(+), 196 deletions(-) create mode 100644 ibis/common/tests/snapshots/test_annotations/test_annotated_function_without_decoration/error.txt create mode 100644 ibis/common/tests/snapshots/test_annotations/test_signature_from_callable_with_keyword_only_arguments/missing_a_required_argument.txt create mode 100644 ibis/common/tests/snapshots/test_annotations/test_signature_from_callable_with_keyword_only_arguments/too_many_positional_arguments.txt create mode 100644 ibis/common/tests/snapshots/test_annotations/test_signature_from_callable_with_positional_only_arguments/parameter_is_positional_only.txt create mode 100644 ibis/common/tests/snapshots/test_grounds/test_error_message/error_message.txt create mode 100644 ibis/common/tests/snapshots/test_grounds/test_error_message/error_message_py311.txt create mode 100644 ibis/expr/operations/tests/snapshots/test_generic/test_error_message_when_constructing_literal/call0-missing_a_required_argument/missing_a_required_argument.txt create mode 100644 ibis/expr/operations/tests/snapshots/test_generic/test_error_message_when_constructing_literal/call1-too_many_positional_arguments/too_many_positional_arguments.txt create mode 100644 ibis/expr/operations/tests/snapshots/test_generic/test_error_message_when_constructing_literal/call2-got_an_unexpected_keyword/got_an_unexpected_keyword.txt create mode 100644 ibis/expr/operations/tests/snapshots/test_generic/test_error_message_when_constructing_literal/call3-multiple_values_for_argument/multiple_values_for_argument.txt create mode 100644 ibis/expr/operations/tests/snapshots/test_generic/test_error_message_when_constructing_literal/call4-invalid_dtype/invalid_dtype.txt create mode 100644 ibis/expr/operations/tests/snapshots/test_generic/test_error_message_when_constructing_literal/call5-unable_to_normalize/unable_to_normalize.txt diff --git a/ibis/backends/tests/test_generic.py b/ibis/backends/tests/test_generic.py index c7d0993cca5c..2a95cb2c34cc 100644 --- a/ibis/backends/tests/test_generic.py +++ b/ibis/backends/tests/test_generic.py @@ -482,7 +482,7 @@ def test_dropna_invalid(alltypes): ): alltypes.dropna(subset=["invalid_col"]) - with pytest.raises(ValidationError, match=r"'invalid' doesn't match"): + with pytest.raises(ValidationError): alltypes.dropna(how="invalid") diff --git a/ibis/common/annotations.py b/ibis/common/annotations.py index d5a234d3d7d4..1dbec9f47070 100644 --- a/ibis/common/annotations.py +++ b/ibis/common/annotations.py @@ -3,6 +3,7 @@ import functools import inspect import types +from typing import TYPE_CHECKING, Callable from typing import Any as AnyType from ibis.common.bases import Immutable, Slotted @@ -15,7 +16,10 @@ TupleOf, ) from ibis.common.patterns import pattern as ensure_pattern -from ibis.common.typing import get_type_hints +from ibis.common.typing import format_typehint, get_type_hints + +if TYPE_CHECKING: + from collections.abc import Sequence EMPTY = inspect.Parameter.empty # marker for missing argument KEYWORD_ONLY = inspect.Parameter.KEYWORD_ONLY @@ -29,7 +33,65 @@ class ValidationError(Exception): - ... + __slots__ = () + + +class AttributeValidationError(ValidationError): + __slots__ = ("name", "value", "pattern") + + def __init__(self, name: str, value: AnyType, pattern: Pattern): + self.name = name + self.value = value + self.pattern = pattern + + def __str__(self): + return f"Failed to validate attribute `{self.name}`: {self.value!r} is not {self.pattern.describe()}" + + +class ReturnValidationError(ValidationError): + __slots__ = ("func", "value", "pattern") + + def __init__(self, func: Callable, value: AnyType, pattern: Pattern): + self.func = func + self.value = value + self.pattern = pattern + + def __str__(self): + return f"Failed to validate return value of `{self.func.__name__}`: {self.value!r} is not {self.pattern.describe()}" + + +class SignatureValidationError(ValidationError): + __slots__ = ("msg", "sig", "func", "args", "kwargs", "errors") + + def __init__( + self, + msg: str, + sig: Signature, + func: Callable, + args: tuple[AnyType, ...], + kwargs: dict[str, AnyType], + errors: Sequence[tuple[str, AnyType, Pattern]] = (), + ): + self.msg = msg + self.sig = sig + self.func = func + self.args = args + self.kwargs = kwargs + self.errors = errors + + def __str__(self): + args = tuple(repr(arg) for arg in self.args) + args += tuple(f"{k}={v!r}" for k, v in self.kwargs.items()) + call = f"{self.func.__name__}({', '.join(args)})" + + errors = "" + for name, value, pattern in self.errors: + errors += f"\n `{name}`: {value!r} is not {pattern.describe()}" + + sig = f"{self.func.__name__}{self.sig}" + cause = str(self.__cause__) if self.__cause__ else "" + + return self.msg.format(sig=sig, call=call, cause=cause, errors=errors) class Annotation(Slotted, Immutable): @@ -40,11 +102,29 @@ class Annotation(Slotted, Immutable): __slots__ = () - def validate(self, arg, context=None): - result = self.pattern.match(arg, context) - if result is NoMatch: - raise ValidationError(f"{arg!r} doesn't match {self.pattern!r}") + def validate(self, name: str, value: AnyType, this: AnyType) -> AnyType: + """Validate the field. + + Parameters + ---------- + name + The name of the attribute. + value + The value of the attribute. + this + The instance of the class the attribute is defined on. + Returns + ------- + The validated value for the field. + """ + result = self.pattern.match(value, this) + if result is NoMatch: + raise AttributeValidationError( + name=name, + value=value, + pattern=self.pattern, + ) return result @@ -69,11 +149,22 @@ class Attribute(Annotation): def __init__(self, pattern: Pattern = _any, default: AnyType = EMPTY): super().__init__(pattern=ensure_pattern(pattern), default=default) - def initialize(self, this: AnyType) -> AnyType: - """Compute the default value of the field. + def has_default(self): + """Check if the field has a default value. + + Returns + ------- + bool + """ + return self.default is not EMPTY + + def get_default(self, name: str, this: AnyType) -> AnyType: + """Get the default value of the field. Parameters ---------- + name + The name of the attribute. this The instance of the class the attribute is defined on. @@ -81,13 +172,11 @@ def initialize(self, this: AnyType) -> AnyType: ------- The default value for the field. """ - if self.default is EMPTY: - return EMPTY - elif callable(self.default): + if callable(self.default): value = self.default(this) else: value = self.default - return self.validate(value, this) + return self.validate(name, value, this) def __call__(self, default): """Needed to support the decorator syntax.""" @@ -180,6 +269,26 @@ def __init__(self, name, annotation): annotation=annotation, ) + def __str__(self): + formatted = self._name + + if self._annotation is not EMPTY: + typehint = format_typehint(self._annotation.typehint) + formatted = f"{formatted}: {typehint}" + + if self._default is not EMPTY: + if self._annotation is not EMPTY: + formatted = f"{formatted} = {self._default!r}" + else: + formatted = f"{formatted}={self._default!r}" + + if self._kind == VAR_POSITIONAL: + formatted = "*" + formatted + elif self._kind == VAR_KEYWORD: + formatted = "**" + formatted + + return formatted + class Signature(inspect.Signature): """Validatable signature. @@ -339,11 +448,13 @@ def unbind(self, this: dict[str, Any]) -> tuple[tuple[Any, ...], dict[str, Any]] raise TypeError(f"unsupported parameter kind {param.kind}") return tuple(args), kwargs - def validate(self, *args, **kwargs): + def validate(self, func, args, kwargs): """Validate the arguments against the signature. Parameters ---------- + func : Callable + Callable to validate the arguments for. args : tuple Positional arguments. kwargs : dict @@ -354,39 +465,77 @@ def validate(self, *args, **kwargs): validated : dict Dictionary of validated arguments. """ - # bind the signature to the passed arguments and apply the patterns - # before passing the arguments, so self.__init__() receives already - # validated arguments as keywords - bound = self.bind(*args, **kwargs) - bound.apply_defaults() - - this = {} + try: + bound = self.bind(*args, **kwargs) + bound.apply_defaults() + except TypeError as err: + raise SignatureValidationError( + "{call} {cause}\n\nExpected signature: {sig}", + sig=self, + func=func, + args=args, + kwargs=kwargs, + ) from err + + this, errors = {}, [] for name, value in bound.arguments.items(): param = self.parameters[name] - # TODO(kszucs): provide more error context on failure - this[name] = param.annotation.validate(value, this) + pattern = param.annotation.pattern + + result = pattern.match(value, this) + if result is NoMatch: + errors.append((name, value, pattern)) + else: + this[name] = result + + if errors: + raise SignatureValidationError( + "{call} has failed due to the following errors:{errors}\n\nExpected signature: {sig}", + sig=self, + func=func, + args=args, + kwargs=kwargs, + errors=errors, + ) return this - def validate_nobind(self, **kwargs): + def validate_nobind(self, func, kwargs): """Validate the arguments against the signature without binding.""" - this = {} + this, errors = {}, [] for name, param in self.parameters.items(): value = kwargs.get(name, param.default) if value is EMPTY: raise TypeError(f"missing required argument `{name!r}`") - this[name] = param.annotation.validate(value, kwargs) + + pattern = param.annotation.pattern + result = pattern.match(value, this) + if result is NoMatch: + errors.append((name, value, pattern)) + else: + this[name] = result + + if errors: + raise SignatureValidationError( + "{call} has failed due to the following errors:{errors}\n\nExpected signature: {sig}", + sig=self, + func=func, + args=(), + kwargs=kwargs, + errors=errors, + ) + return this - def validate_return(self, value, context): + def validate_return(self, func, value): """Validate the return value of a function. Parameters ---------- + func : Callable + Callable to validate the return value for. value : Any Return value of the function. - context : dict - Context dictionary. Returns ------- @@ -396,9 +545,13 @@ def validate_return(self, value, context): if self.return_annotation is EMPTY: return value - result = self.return_annotation.match(value, context) + result = self.return_annotation.match(value, {}) if result is NoMatch: - raise ValidationError(f"{value!r} doesn't match {self}") + raise ReturnValidationError( + func=func, + value=value, + pattern=self.return_annotation, + ) return result @@ -476,13 +629,13 @@ def annotated(_1=None, _2=None, _3=None, **kwargs): @functools.wraps(func) def wrapped(*args, **kwargs): # 1. Validate the passed arguments - values = sig.validate(*args, **kwargs) + values = sig.validate(func, args, kwargs) # 2. Reconstruction of the original arguments args, kwargs = sig.unbind(values) # 3. Call the function with the validated arguments result = func(*args, **kwargs) # 4. Validate the return value - return sig.validate_return(result, {}) + return sig.validate_return(func, result) wrapped.__signature__ = sig diff --git a/ibis/common/grounds.py b/ibis/common/grounds.py index 8f7fb6dbacfd..91994108c4a3 100644 --- a/ibis/common/grounds.py +++ b/ibis/common/grounds.py @@ -12,7 +12,6 @@ from typing_extensions import Self, dataclass_transform from ibis.common.annotations import ( - EMPTY, Annotation, Argument, Attribute, @@ -115,14 +114,14 @@ class Annotable(Base, metaclass=AnnotableMeta): @classmethod def __create__(cls, *args: Any, **kwargs: Any) -> Self: - # construct the instance by passing the validated keyword arguments - kwargs = cls.__signature__.validate(*args, **kwargs) + # construct the instance by passing only validated keyword arguments + kwargs = cls.__signature__.validate(cls, args, kwargs) return super().__create__(**kwargs) @classmethod def __recreate__(cls, kwargs: Any) -> Self: # bypass signature binding by requiring keyword arguments only - kwargs = cls.__signature__.validate_nobind(**kwargs) + kwargs = cls.__signature__.validate_nobind(cls, kwargs) return super().__create__(**kwargs) def __init__(self, **kwargs: Any) -> None: @@ -131,16 +130,17 @@ def __init__(self, **kwargs: Any) -> None: object.__setattr__(self, name, value) # initialize the remaining attributes for name, field in self.__attributes__.items(): - if (default := field.initialize(self)) is not EMPTY: - object.__setattr__(self, name, default) + if field.has_default(): + object.__setattr__(self, name, field.get_default(name, self)) def __setattr__(self, name, value) -> None: # first try to look up the argument then the attribute if param := self.__signature__.parameters.get(name): - value = param.annotation.validate(value, self) - elif field := self.__attributes__.get(name): - value = field.validate(value, self) - super().__setattr__(name, value) + value = param.annotation.validate(name, value, self) + # then try to look up the attribute + elif annot := self.__attributes__.get(name): + value = annot.validate(name, value, self) + return super().__setattr__(name, value) def __repr__(self) -> str: args = (f"{n}={getattr(self, n)!r}" for n in self.__argnames__) @@ -204,8 +204,8 @@ def __init__(self, **kwargs: Any) -> None: # initialize the remaining attributes for name, field in self.__attributes__.items(): - if (default := field.initialize(self)) is not EMPTY: - object.__setattr__(self, name, default) + if field.has_default(): + object.__setattr__(self, name, field.get_default(name, self)) def __reduce__(self): # assuming immutability and idempotency of the __init__ method, we can diff --git a/ibis/common/patterns.py b/ibis/common/patterns.py index 3168cc833453..227890aa67fc 100644 --- a/ibis/common/patterns.py +++ b/ibis/common/patterns.py @@ -12,7 +12,7 @@ from typing import ( Annotated, ForwardRef, - Generic, # noqa: F401 + Generic, Literal, Optional, TypeVar, @@ -32,6 +32,7 @@ Sentinel, UnionType, _ClassInfo, + format_typehint, get_bound_typevars, get_type_params, ) @@ -128,6 +129,8 @@ def from_typehint(cls, annot: type, allow_coercion: bool = True) -> Pattern: return LazyInstanceOf(annot) else: raise TypeError(f"Cannot create validator from annotation {annot!r}") + elif origin is CoercedTo: + return CoercedTo(args[0]) elif origin is Literal: # for literal types we check the value against the literal values return IsIn(args) @@ -237,6 +240,9 @@ def is_match(self, value: AnyType, context: dict[str, AnyType]) -> bool: """ return self.match(value, context) is not NoMatch + def describe(self, plural=False): + return "matching {self!r}" + @abstractmethod def __eq__(self, other: Pattern) -> bool: ... @@ -504,20 +510,6 @@ def namespace(cls, module) -> Namespace: _ = Variable("_") -class Always(Slotted, Singleton, Pattern): - """Pattern that matches everything.""" - - def match(self, value, context): - return value - - -class Never(Slotted, Singleton, Pattern): - """Pattern that matches nothing.""" - - def match(self, value, context): - return NoMatch - - class Is(Slotted, Pattern): """Pattern that matches a value against a reference value. @@ -547,6 +539,13 @@ def match(self, value, context): _any = Any() +class Nothing(Slotted, Singleton, Pattern): + """Pattern that no values.""" + + def match(self, value, context): + return NoMatch + + class Capture(Slotted, Pattern): """Pattern that captures a value in the context. @@ -617,6 +616,12 @@ def __init__(self, predicate): assert callable(predicate) super().__init__(predicate=predicate) + def describe(self, plural=False): + if plural: + return f"values that satisfy {self.predicate.__name__}()" + else: + return f"a value that satisfies {self.predicate.__name__}()" + def match(self, value, context): if self.predicate(value): return value @@ -667,7 +672,7 @@ class Namespace: InstanceOf(type=) >>> >>> ns.Negate(5) - Object(type=CoercedTo(target=), args=(EqualTo(value=5),), kwargs=FrozenDict({})) + Object(type=CoercedTo(target=, func=>), args=(EqualTo(value=5),), kwargs=FrozenDict({})) """ __slots__ = ("module", "pattern") @@ -684,39 +689,6 @@ def __getattr__(self, name: str) -> Pattern: return self.pattern(getattr(self.module, name)) -class Apply(Slotted, Pattern): - """Pattern that applies a function to the value. - - The function must accept a single argument. - - Parameters - ---------- - func - The function to apply. - - Examples - -------- - >>> from ibis.common.patterns import Apply, match - >>> - >>> match("a" @ Apply(lambda x: x + 1), 5) - 6 - """ - - __slots__ = ("func",) - func: Callable - - def __init__(self, func): - assert callable(func) - super().__init__(func=func) - - def match(self, value, context): - return self.func(value) - - def __call__(self, *args, **kwargs): - """Convenience method to create a Call pattern.""" - return Call(self.func, *args, **kwargs) - - class EqualTo(Slotted, Pattern): """Pattern that checks a value equals to the given value. @@ -738,6 +710,9 @@ def match(self, value, context): else: return NoMatch + def describe(self, plural=False): + return repr(self.value) + class Option(Slotted, Pattern): """Pattern that matches `None` or a value that passes the inner validator. @@ -755,6 +730,12 @@ class Option(Slotted, Pattern): def __init__(self, pat, default=None): super().__init__(pattern=pattern(pat), default=default) + def describe(self, plural=False): + if plural: + return f"optional {self.pattern.describe(plural=True)}" + else: + return f"either None or {self.pattern.describe(plural=False)}" + def match(self, value, context): if value is None: if self.default is None: @@ -765,6 +746,22 @@ def match(self, value, context): return self.pattern.match(value, context) +def _describe_type(typ, plural=False): + if isinstance(typ, tuple): + *rest, last = typ + rest = ", ".join(_describe_type(t, plural=plural) for t in rest) + last = _describe_type(last, plural=plural) + return f"{rest} or {last}" if rest else last + + name = format_typehint(typ) + if plural: + return f"{name}s" + elif name[0].lower() in "aeiou": + return f"an {name}" + else: + return f"a {name}" + + class TypeOf(Slotted, Pattern): """Pattern that matches a value that is of a given type.""" @@ -774,6 +771,9 @@ class TypeOf(Slotted, Pattern): def __init__(self, typ): super().__init__(type=typ) + def describe(self, plural=False): + return f"exactly {_describe_type(self.type, plural=plural)}" + def match(self, value, context): if type(value) is self.type: return value @@ -795,6 +795,12 @@ class SubclassOf(Slotted, Pattern): def __init__(self, typ): super().__init__(type=typ) + def describe(self, plural=False): + if plural: + return f"subclasses of {self.type.__name__}" + else: + return f"a subclass of {self.type.__name__}" + def match(self, value, context): if issubclass(value, self.type): return value @@ -817,6 +823,9 @@ class InstanceOf(Slotted, Singleton, Pattern): def __init__(self, typ): super().__init__(type=typ) + def describe(self, plural=False): + return _describe_type(self.type, plural=plural) + def match(self, value, context): if isinstance(value, self.type): return value @@ -855,7 +864,7 @@ class GenericInstanceOf(Slotted, Pattern): >>> assert p.match(MyNumber(1), {}) is NoMatch """ - __slots__ = ("origin", "fields") + __slots__ = ("type", "origin", "fields") origin: type fields: FrozenDict[str, Pattern] @@ -871,7 +880,10 @@ def __init__(self, typ): ) fields[attr] = Pattern.from_typehint(type_, allow_coercion=False) - super().__init__(origin=origin, fields=frozendict(fields)) + super().__init__(type=typ, origin=origin, fields=frozendict(fields)) + + def describe(self, plural=False): + return _describe_type(self.type, plural=plural) def match(self, value, context): if not isinstance(value, self.origin): @@ -913,8 +925,7 @@ def match(self, value, context): return NoMatch -# TODO(kszucs): to support As[int] or CoercedTo[int] syntax -class CoercedTo(Slotted, Pattern): +class CoercedTo(Slotted, Pattern, Generic[T_co]): """Force a value to have a particular Python type. If a Coercible subclass is passed, the `__coerce__` method will be used to @@ -927,23 +938,24 @@ class CoercedTo(Slotted, Pattern): The type to coerce to. """ - __slots__ = ("target",) - target: type - - def __new__(cls, target): - if issubclass(target, Coercible): - return super().__new__(cls) - else: - return Apply(target) + __slots__ = ("target", "func") + target: T_co def __init__(self, target): - assert isinstance(target, type) - super().__init__(target=target) + func = target.__coerce__ if issubclass(target, Coercible) else target + super().__init__(target=target, func=func) + + def describe(self, plural=False): + target = _describe_type(self.target, plural=False) + if plural: + return f"coercibles to {target}" + else: + return f"coercible to {target}" def match(self, value, context): try: - value = self.target.__coerce__(value) - except CoercionError: + value = self.func(value) + except (TypeError, CoercionError): return NoMatch if isinstance(value, self.target): @@ -1005,6 +1017,12 @@ def __init__(self, target): params = frozendict(get_type_params(target)) super().__init__(origin=origin, params=params, checker=checker) + def describe(self, plural=False): + if plural: + return f"coercibles to {self.checker.describe(plural=False)}" + else: + return f"coercible to {self.checker.describe(plural=False)}" + def match(self, value, context): try: value = self.origin.__coerce__(value, **self.params) @@ -1032,6 +1050,12 @@ class Not(Slotted, Pattern): def __init__(self, inner): super().__init__(pattern=pattern(inner)) + def describe(self, plural=False): + if plural: + return f"anything except {self.pattern.describe(plural=True)}" + else: + return f"anything except {self.pattern.describe(plural=False)}" + def match(self, value, context): if self.pattern.match(value, context) is NoMatch: return value @@ -1056,6 +1080,12 @@ def __init__(self, *pats): patterns = tuple(map(pattern, pats)) super().__init__(patterns=patterns) + def describe(self, plural=False): + *rest, last = self.patterns + rest = ", ".join(p.describe(plural=plural) for p in rest) + last = last.describe(plural=plural) + return f"{rest} or {last}" if rest else last + def match(self, value, context): for pattern in self.patterns: result = pattern.match(value, context) @@ -1082,6 +1112,12 @@ def __init__(self, *pats): patterns = tuple(map(pattern, pats)) super().__init__(patterns=patterns) + def describe(self, plural=False): + *rest, last = self.patterns + rest = ", ".join(p.describe(plural=plural) for p in rest) + last = last.describe(plural=plural) + return f"{rest} then {last}" if rest else last + def match(self, value, context): for pattern in self.patterns: value = pattern.match(value, context) @@ -1121,6 +1157,19 @@ def __init__( at_most = exactly super().__init__(at_least=at_least, at_most=at_most) + def describe(self, plural=False): + if self.at_least is not None and self.at_most is not None: + if self.at_least == self.at_most: + return f"with length exactly {self.at_least}" + else: + return f"with length between {self.at_least} and {self.at_most}" + elif self.at_least is not None: + return f"with length at least {self.at_least}" + elif self.at_most is not None: + return f"with length at most {self.at_most}" + else: + return "with any length" + def match(self, value, context): length = len(value) if self.at_least is not None and length < self.at_least: @@ -1145,6 +1194,9 @@ class Contains(Slotted, Pattern): def __init__(self, needle): super().__init__(needle=needle) + def describe(self, plural=False): + return f"containing {self.needle!r}" + def match(self, value, context): if self.needle in value: return value @@ -1167,6 +1219,9 @@ class IsIn(Slotted, Pattern): def __init__(self, haystack): super().__init__(haystack=frozenset(haystack)) + def describe(self, plural=False): + return f"in {set(self.haystack)!r}" + def match(self, value, context): if value in self.haystack: return value @@ -1219,6 +1274,11 @@ def __new__( def __init__(self, item, type=tuple): super().__init__(item=pattern(item), type=type) + def describe(self, plural=False): + typ = _describe_type(self.type, plural=plural) + item = self.item.describe(plural=True) + return f"{typ} of {item}" + def match(self, values, context): if not is_iterable(values): return NoMatch @@ -1326,6 +1386,13 @@ def __init__(self, fields): fields = tuple(map(pattern, fields)) super().__init__(fields=fields) + def describe(self, plural=False): + fields = ", ".join(f.describe(plural=False) for f in self.fields) + if plural: + return f"tuples of ({fields})" + else: + return f"a tuple of ({fields})" + def match(self, values, context): if not is_iterable(values): return NoMatch diff --git a/ibis/common/tests/snapshots/test_annotations/test_annotated_function_without_decoration/error.txt b/ibis/common/tests/snapshots/test_annotations/test_annotated_function_without_decoration/error.txt new file mode 100644 index 000000000000..5c9855d211e7 --- /dev/null +++ b/ibis/common/tests/snapshots/test_annotations/test_annotated_function_without_decoration/error.txt @@ -0,0 +1,3 @@ +test(1, 2) missing a required argument: 'c' + +Expected signature: test(a: None, b: None, c: None) \ No newline at end of file diff --git a/ibis/common/tests/snapshots/test_annotations/test_signature_from_callable_with_keyword_only_arguments/missing_a_required_argument.txt b/ibis/common/tests/snapshots/test_annotations/test_signature_from_callable_with_keyword_only_arguments/missing_a_required_argument.txt new file mode 100644 index 000000000000..37f01ade44cc --- /dev/null +++ b/ibis/common/tests/snapshots/test_annotations/test_signature_from_callable_with_keyword_only_arguments/missing_a_required_argument.txt @@ -0,0 +1,3 @@ +test(2, 3) missing a required argument: 'c' + +Expected signature: test(a: int, b: int, *, c: float, d: float = 0.0) \ No newline at end of file diff --git a/ibis/common/tests/snapshots/test_annotations/test_signature_from_callable_with_keyword_only_arguments/too_many_positional_arguments.txt b/ibis/common/tests/snapshots/test_annotations/test_signature_from_callable_with_keyword_only_arguments/too_many_positional_arguments.txt new file mode 100644 index 000000000000..4af3834542c8 --- /dev/null +++ b/ibis/common/tests/snapshots/test_annotations/test_signature_from_callable_with_keyword_only_arguments/too_many_positional_arguments.txt @@ -0,0 +1,3 @@ +test(2, 3, 4) too many positional arguments + +Expected signature: test(a: int, b: int, *, c: float, d: float = 0.0) \ No newline at end of file diff --git a/ibis/common/tests/snapshots/test_annotations/test_signature_from_callable_with_positional_only_arguments/parameter_is_positional_only.txt b/ibis/common/tests/snapshots/test_annotations/test_signature_from_callable_with_positional_only_arguments/parameter_is_positional_only.txt new file mode 100644 index 000000000000..76ba59c37250 --- /dev/null +++ b/ibis/common/tests/snapshots/test_annotations/test_signature_from_callable_with_positional_only_arguments/parameter_is_positional_only.txt @@ -0,0 +1,3 @@ +test(1, b=2) 'b' parameter is positional only, but was passed as a keyword + +Expected signature: test(a: int, b: int, /, c: int = 1) \ No newline at end of file diff --git a/ibis/common/tests/snapshots/test_grounds/test_error_message/error_message.txt b/ibis/common/tests/snapshots/test_grounds/test_error_message/error_message.txt new file mode 100644 index 000000000000..96127f5e5246 --- /dev/null +++ b/ibis/common/tests/snapshots/test_grounds/test_error_message/error_message.txt @@ -0,0 +1,8 @@ +Example('1', '2', '3', '4', '5', []) has failed due to the following errors: + `a`: '1' is not an int + `b`: '2' is not an int + `d`: '4' is not either None or a float + `e`: '5' is not a tuple of ints + `f`: [] is not coercible to an int + +Expected signature: Example(a: int, b: int = 0, c: str = 'foo', d: Optional[float] = None, e: tuple = (1, 2, 3), f: CoercedTo[int] = 1) \ No newline at end of file diff --git a/ibis/common/tests/snapshots/test_grounds/test_error_message/error_message_py311.txt b/ibis/common/tests/snapshots/test_grounds/test_error_message/error_message_py311.txt new file mode 100644 index 000000000000..9bdc9ecec553 --- /dev/null +++ b/ibis/common/tests/snapshots/test_grounds/test_error_message/error_message_py311.txt @@ -0,0 +1,8 @@ +Example('1', '2', '3', '4', '5', []) has failed due to the following errors: + `a`: '1' is not an int + `b`: '2' is not an int + `d`: '4' is not either None or a float + `e`: '5' is not a tuple of ints + `f`: [] is not coercible to an int + +Expected signature: Example(a: int, b: int = 0, c: str = 'foo', d: Optional[float] = None, e: tuple[int, ...] = (1, 2, 3), f: CoercedTo[int] = 1) \ No newline at end of file diff --git a/ibis/common/tests/test_annotations.py b/ibis/common/tests/test_annotations.py index 3a12c710f875..1d01c0845d70 100644 --- a/ibis/common/tests/test_annotations.py +++ b/ibis/common/tests/test_annotations.py @@ -89,9 +89,7 @@ def test_argument_repr(): def test_default_argument(): annotation = Argument(pattern=lambda x, context: int(x), default=3) - assert annotation.validate(1) == 1 - with pytest.raises(TypeError): - annotation.validate(None) + assert annotation.pattern.match(1, {}) == 1 @pytest.mark.parametrize( @@ -100,7 +98,7 @@ def test_default_argument(): ) def test_optional_argument(default, expected): annotation = optional(default=default) - assert annotation.validate(None) == expected + assert annotation.pattern.match(None, {}) == expected @pytest.mark.parametrize( @@ -116,26 +114,25 @@ def test_optional_argument(default, expected): ], ) def test_valid_optional(argument, value, expected): - assert argument.validate(value) == expected + assert argument.pattern.match(value, {}) == expected -def test_invalid_optional_argument(): - with pytest.raises(ValidationError): - optional(is_int).validate("lynx") - - -def test_initialized(): +def test_attribute_default_value(): class Foo: a = 10 + assert not Attribute().has_default() + field = Attribute(default=lambda self: self.a + 10) + assert field.has_default() assert field == field - assert field.initialize(Foo) == 20 + assert field.get_default("b", Foo) == 20 field2 = Attribute(pattern=lambda x, this: str(x), default=lambda self: self.a) + assert field2.has_default() assert field != field2 - assert field2.initialize(Foo) == "10" + assert field2.get_default("b", Foo) == "10" def test_parameter(): @@ -147,16 +144,13 @@ def fn(x, this): assert p.annotation is annot assert p.default is inspect.Parameter.empty - assert p.annotation.validate("2", {"other": 1}) == 3 - - with pytest.raises(TypeError): - p.annotation.validate({}, valid=inspect.Parameter.empty) + assert p.annotation.pattern.match("2", {"other": 1}) == 3 ofn = optional(fn) op = Parameter("test", annotation=ofn) assert op.annotation.pattern == Option(fn, default=None) assert op.default is None - assert op.annotation.validate(None, {"other": 1}) is None + assert op.annotation.pattern.match(None, {"other": 1}) is None with pytest.raises(TypeError, match="annotation must be an instance of Argument"): Parameter("wrong", annotation=Attribute(lambda x, context: x)) @@ -173,9 +167,15 @@ def add_other(x, this): this = Parameter("this", annotation=Argument(add_other)) sig = Signature(parameters=[other, this]) - assert sig.validate(1, 2) == {"other": 1, "this": 3} - assert sig.validate(other=1, this=2) == {"other": 1, "this": 3} - assert sig.validate(this=2, other=1) == {"other": 1, "this": 3} + assert sig.validate(None, args=(1, 2), kwargs={}) == {"other": 1, "this": 3} + assert sig.validate(None, args=(), kwargs=dict(other=1, this=2)) == { + "other": 1, + "this": 3, + } + assert sig.validate(None, args=(), kwargs=dict(this=2, other=1)) == { + "other": 1, + "this": 3, + } def test_signature_from_callable(): @@ -183,12 +183,12 @@ def test(a: int, b: int, c: int = 1): ... sig = Signature.from_callable(test) - assert sig.validate(2, 3) == {"a": 2, "b": 3, "c": 1} + assert sig.validate(test, args=(2, 3), kwargs={}) == {"a": 2, "b": 3, "c": 1} with pytest.raises(ValidationError): - sig.validate(2, 3, "4") + sig.validate(test, args=(2, 3, "4"), kwargs={}) - args, kwargs = sig.unbind(sig.validate(2, 3)) + args, kwargs = sig.unbind(sig.validate(test, args=(2, 3), kwargs={})) assert args == (2, 3, 1) assert kwargs == {} @@ -198,53 +198,74 @@ def test(a: int, b: int, *args: int): ... sig = Signature.from_callable(test) - assert sig.validate(2, 3) == {"a": 2, "b": 3, "args": ()} - assert sig.validate(2, 3, 4) == {"a": 2, "b": 3, "args": (4,)} - assert sig.validate(2, 3, 4, 5) == {"a": 2, "b": 3, "args": (4, 5)} + assert sig.validate(test, args=(2, 3), kwargs={}) == {"a": 2, "b": 3, "args": ()} + assert sig.validate(test, args=(2, 3, 4), kwargs={}) == { + "a": 2, + "b": 3, + "args": (4,), + } + assert sig.validate(test, args=(2, 3, 4, 5), kwargs={}) == { + "a": 2, + "b": 3, + "args": (4, 5), + } assert sig.parameters["a"].annotation.typehint is int assert sig.parameters["b"].annotation.typehint is int assert sig.parameters["args"].annotation.typehint is int with pytest.raises(ValidationError): - sig.validate(2, 3, 4, "5") + sig.validate(test, args=(2, 3, 4, "5"), kwargs={}) - args, kwargs = sig.unbind(sig.validate(2, 3, 4, 5)) + args, kwargs = sig.unbind(sig.validate(test, args=(2, 3, 4, 5), kwargs={})) assert args == (2, 3, 4, 5) assert kwargs == {} -def test_signature_from_callable_with_positional_only_arguments(): +def test_signature_from_callable_with_positional_only_arguments(snapshot): def test(a: int, b: int, /, c: int = 1): ... sig = Signature.from_callable(test) - assert sig.validate(2, 3) == {"a": 2, "b": 3, "c": 1} - assert sig.validate(2, 3, 4) == {"a": 2, "b": 3, "c": 4} - assert sig.validate(2, 3, c=4) == {"a": 2, "b": 3, "c": 4} + assert sig.validate(test, args=(2, 3), kwargs={}) == {"a": 2, "b": 3, "c": 1} + assert sig.validate(test, args=(2, 3, 4), kwargs={}) == {"a": 2, "b": 3, "c": 4} + assert sig.validate(test, args=(2, 3), kwargs=dict(c=4)) == {"a": 2, "b": 3, "c": 4} - msg = "'b' parameter is positional only, but was passed as a keyword" - with pytest.raises(TypeError, match=msg): - sig.validate(1, b=2) + with pytest.raises(ValidationError) as excinfo: + sig.validate(test, args=(1,), kwargs=dict(b=2)) + snapshot.assert_match(str(excinfo.value), "parameter_is_positional_only.txt") - args, kwargs = sig.unbind(sig.validate(2, 3)) + args, kwargs = sig.unbind(sig.validate(test, args=(2, 3), kwargs={})) assert args == (2, 3, 1) assert kwargs == {} -def test_signature_from_callable_with_keyword_only_arguments(): +def test_signature_from_callable_with_keyword_only_arguments(snapshot): def test(a: int, b: int, *, c: float, d: float = 0.0): ... sig = Signature.from_callable(test) - assert sig.validate(2, 3, c=4.0) == {"a": 2, "b": 3, "c": 4.0, "d": 0.0} - assert sig.validate(2, 3, c=4.0, d=5.0) == {"a": 2, "b": 3, "c": 4.0, "d": 5.0} - - with pytest.raises(TypeError, match="missing a required argument: 'c'"): - sig.validate(2, 3) - with pytest.raises(TypeError, match="too many positional arguments"): - sig.validate(2, 3, 4) - - args, kwargs = sig.unbind(sig.validate(2, 3, c=4.0)) + assert sig.validate(test, args=(2, 3), kwargs=dict(c=4.0)) == { + "a": 2, + "b": 3, + "c": 4.0, + "d": 0.0, + } + assert sig.validate(test, args=(2, 3), kwargs=dict(c=4.0, d=5.0)) == { + "a": 2, + "b": 3, + "c": 4.0, + "d": 5.0, + } + + with pytest.raises(ValidationError) as excinfo: + sig.validate(test, args=(2, 3), kwargs={}) + snapshot.assert_match(str(excinfo.value), "missing_a_required_argument.txt") + + with pytest.raises(ValidationError) as excinfo: + sig.validate(test, args=(2, 3, 4), kwargs={}) + snapshot.assert_match(str(excinfo.value), "too_many_positional_arguments.txt") + + args, kwargs = sig.unbind(sig.validate(test, args=(2, 3), kwargs=dict(c=4.0))) assert args == (2, 3) assert kwargs == {"c": 4.0, "d": 0.0} @@ -260,7 +281,7 @@ def add_other(x, this): this = Parameter("this", annotation=Argument(add_other)) sig = Signature(parameters=[other, this]) - params = sig.validate(1, this=2) + params = sig.validate(None, args=(1,), kwargs=dict(this=2)) args, kwargs = sig.unbind(params) assert args == (1, 3) @@ -280,14 +301,14 @@ def add_other(x, this): @pytest.mark.parametrize("d", [(), (5, 6, 7)]) def test_signature_unbind_with_empty_variadic(d): - params = sig.validate(1, 2, 3, d, e=4) + params = sig.validate(None, args=(1, 2, 3, d), kwargs=dict(e=4)) assert params == {"a": 1.0, "b": 2.0, "c": 3.0, "d": d, "e": 4.0} args, kwargs = sig.unbind(params) assert args == (1.0, 2.0, 3.0, tuple(map(float, d)), 4.0) assert kwargs == {} - params_again = sig.validate(*args, **kwargs) + params_again = sig.validate(None, args=args, kwargs=kwargs) assert params_again == params @@ -400,11 +421,11 @@ def test(a: Annotated[str, short_str, endswith_d], b: Union[int, float]): assert test("abcd", 1) == ("abcd", 1) assert test("---d", 1.0) == ("---d", 1.0) - with pytest.raises(ValidationError, match="doesn't match"): + with pytest.raises(ValidationError): test("---c", 1) - with pytest.raises(ValidationError, match="doesn't match"): + with pytest.raises(ValidationError): test("123", 1) - with pytest.raises(ValidationError, match="'qweqwe' doesn't match"): + with pytest.raises(ValidationError): test("abcd", "qweqwe") @@ -417,13 +438,14 @@ def test(a, b, c): assert test.__signature__.parameters.keys() == {"a", "b", "c"} -def test_annotated_function_without_decoration(): +def test_annotated_function_without_decoration(snapshot): def test(a, b, c): return a + b + c func = annotated(test) - with pytest.raises(TypeError): + with pytest.raises(ValidationError) as excinfo: func(1, 2) + snapshot.assert_match(str(excinfo.value), "error.txt") assert func(1, 2, c=3) == 6 @@ -450,3 +472,14 @@ def test(a: float, b: float, **kwargs: int): with pytest.raises(ValidationError): test(1.0, 2.0, c=3, d=4, e=5, f=6.0) + + +def test_multiple_validation_failures(): + @annotated + def test(a: float, b: float, *args: int, **kwargs: int): + ... + + with pytest.raises(ValidationError) as excinfo: + test(1.0, 2.0, 3.0, 4, c=5.0, d=6) + + assert len(excinfo.value.errors) == 2 diff --git a/ibis/common/tests/test_grounds.py b/ibis/common/tests/test_grounds.py index 7e1f6aaba87c..74eccdc8637e 100644 --- a/ibis/common/tests/test_grounds.py +++ b/ibis/common/tests/test_grounds.py @@ -2,6 +2,7 @@ import copy import pickle +import sys import weakref from collections.abc import Mapping, Sequence from typing import Callable, Generic, Optional, TypeVar, Union @@ -29,6 +30,7 @@ ) from ibis.common.patterns import ( Any, + As, CoercedTo, Coercible, InstanceOf, @@ -508,7 +510,7 @@ class Test2(Test): c = is_int args = varargs(is_int) - with pytest.raises(TypeError, match="missing a required argument: 'c'"): + with pytest.raises(ValidationError, match="missing a required argument: 'c'"): Test2(1, 2) a = Test2(1, 2, 3) @@ -540,7 +542,7 @@ class Test2(Test): c = is_int options = varkwargs(is_int) - with pytest.raises(TypeError, match="missing a required argument: 'c'"): + with pytest.raises(ValidationError, match="missing a required argument: 'c'"): Test2(1, 2) a = Test2(1, 2, c=3) @@ -802,7 +804,7 @@ class Flexible(Annotable): def test_annotable_attribute(): - with pytest.raises(TypeError, match="too many positional arguments"): + with pytest.raises(ValidationError, match="too many positional arguments"): BaseValue(1, 2) v = BaseValue(1) @@ -1052,3 +1054,24 @@ class Example(Annotable): assert Example(None).value is None assert Example(1).value == 1 assert isinstance(Example(1).value, MyInt) + + +def test_error_message(snapshot): + class Example(Annotable): + a: int + b: int = 0 + c: str = "foo" + d: Optional[float] = None + e: tuple[int, ...] = (1, 2, 3) + f: As[int] = 1 + + with pytest.raises(ValidationError) as exc_info: + Example("1", "2", "3", "4", "5", []) + + # assert "Failed" in str(exc_info.value) + + if sys.version_info >= (3, 11): + target = "error_message_py311.txt" + else: + target = "error_message.txt" + snapshot.assert_match(str(exc_info.value), target) diff --git a/ibis/common/tests/test_patterns.py b/ibis/common/tests/test_patterns.py index 9c12bcd749a6..6788b1c34008 100644 --- a/ibis/common/tests/test_patterns.py +++ b/ibis/common/tests/test_patterns.py @@ -27,7 +27,6 @@ from ibis.common.graph import Node as GraphNode from ibis.common.patterns import ( AllOf, - Always, Any, AnyOf, Between, @@ -52,11 +51,11 @@ Length, ListOf, MappingOf, - Never, Node, NoMatch, NoneOf, Not, + Nothing, Object, Option, Pattern, @@ -114,14 +113,8 @@ def test_immutability_of_patterns(): p.types = [str] -def test_always(): - p = Always() - assert p.match(1, context={}) == 1 - assert p.match(2, context={}) == 2 - - -def test_never(): - p = Never() +def test_nothing(): + p = Nothing() assert p.match(1, context={}) is NoMatch assert p.match(2, context={}) is NoMatch @@ -178,39 +171,78 @@ def test_capture(): def test_option(): p = Option(InstanceOf(str)) + assert Option(str) == p assert p.match(None, context={}) is None assert p.match("foo", context={}) == "foo" assert p.match(1, context={}) is NoMatch + assert p.describe() == "either None or a str" + assert p.describe(plural=True) == "optional strs" + + p = Option(int, default=-1) + assert p.match(None, context={}) == -1 + assert p.match(1, context={}) == 1 + assert p.match(1.0, context={}) is NoMatch + assert p.describe() == "either None or an int" + assert p.describe(plural=True) == "optional ints" def test_check(): - p = Check(lambda x: x == 10) + def checker(x): + return x == 10 + + p = Check(checker) assert p.match(10, context={}) == 10 assert p.match(11, context={}) is NoMatch + assert p.describe() == "a value that satisfies checker()" + assert p.describe(plural=True) == "values that satisfy checker()" def test_equal_to(): p = EqualTo(10) assert p.match(10, context={}) == 10 assert p.match(11, context={}) is NoMatch + assert p.describe() == "10" + assert p.describe(plural=True) == "10" + + p = EqualTo("10") + assert p.match(10, context={}) is NoMatch + assert p.match("10", context={}) == "10" + assert p.describe() == "'10'" + assert p.describe(plural=True) == "'10'" def test_type_of(): p = TypeOf(int) assert p.match(1, context={}) == 1 assert p.match("foo", context={}) is NoMatch + assert p.describe() == "exactly an int" + assert p.describe(plural=True) == "exactly ints" def test_subclass_of(): p = SubclassOf(Pattern) assert p.match(Double, context={}) == Double assert p.match(int, context={}) is NoMatch + assert p.describe() == "a subclass of Pattern" + assert p.describe(plural=True) == "subclasses of Pattern" def test_instance_of(): p = InstanceOf(int) assert p.match(1, context={}) == 1 assert p.match("foo", context={}) is NoMatch + assert p.describe() == "an int" + assert p.describe(plural=True) == "ints" + + p = InstanceOf((int, str)) + assert p.match(1, context={}) == 1 + assert p.match("foo", context={}) == "foo" + assert p.match(1.0, context={}) is NoMatch + assert p.describe() == "an int or a str" + assert p.describe(plural=True) == "ints or strs" + + p = InstanceOf((int, str, float)) + assert p.describe() == "an int, a str or a float" def test_lazy_instance_of(): @@ -233,6 +265,7 @@ class My(Generic[T, S]): def test_generic_instance_of_with_covariant_typevar(): p = Pattern.from_typehint(My[int, AnyType]) assert p.match(My(1, 2, "3"), context={}) == My(1, 2, "3") + assert p.describe() == "a My[int, Any]" assert match(My[int, AnyType], v := My(1, 2, "3")) == v assert match(My[int, int], v := My(1, 2, "3")) == v @@ -327,6 +360,8 @@ def __eq__(self, other): p = Pattern.from_typehint(Literal[String]) r = p.match("foo", context={}) assert r == Literal("foo", Scalar()) + expected = "coercible to a .Literal[.String]" + assert p.describe() == expected def test_not(): @@ -336,6 +371,8 @@ def test_not(): assert p == p1 assert p.match(1, context={}) is NoMatch assert p.match("foo", context={}) == "foo" + assert p.describe() == "anything except an int" + assert p.describe(plural=True) == "anything except ints" def test_any_of(): @@ -346,6 +383,11 @@ def test_any_of(): assert p.match(1, context={}) == 1 assert p.match("foo", context={}) == "foo" assert p.match(1.0, context={}) is NoMatch + assert p.describe() == "an int or a str" + assert p.describe(plural=True) == "ints or strs" + + p = AnyOf(InstanceOf(int), InstanceOf(str), InstanceOf(float)) + assert p.describe() == "an int, a str or a float" def test_all_of(): @@ -358,6 +400,14 @@ def negative(x): assert p == p1 assert p.match(1, context={}) is NoMatch assert p.match(-1, context={}) == -1 + assert p.match(1.0, context={}) is NoMatch + assert p.describe() == "an int then a value that satisfies negative()" + + p = AllOf(InstanceOf(int), CoercedTo(float), CoercedTo(str)) + assert p.match(1, context={}) == "1.0" + assert p.match(1.0, context={}) is NoMatch + assert p.match("1", context={}) is NoMatch + assert p.describe() == "an int, coercible to a float then coercible to a str" def test_none_of(): @@ -368,6 +418,7 @@ def negative(x): assert p.match(1.0, context={}) == 1.0 assert p.match(-1.0, context={}) is NoMatch assert p.match(1, context={}) is NoMatch + assert p.describe() == "anything except an int or a value that satisfies negative()" def test_length(): @@ -379,14 +430,17 @@ def test_length(): p = Length(exactly=3) assert p.match([1, 2, 3], context={}) == [1, 2, 3] assert p.match([1, 2], context={}) is NoMatch + assert p.describe() == "with length exactly 3" p = Length(at_least=3) assert p.match([1, 2, 3], context={}) == [1, 2, 3] assert p.match([1, 2], context={}) is NoMatch + assert p.describe() == "with length at least 3" p = Length(at_most=3) assert p.match([1, 2, 3], context={}) == [1, 2, 3] assert p.match([1, 2, 3, 4], context={}) is NoMatch + assert p.describe() == "with length at most 3" p = Length(at_least=3, at_most=5) assert p.match([1, 2], context={}) is NoMatch @@ -394,18 +448,31 @@ def test_length(): assert p.match([1, 2, 3, 4], context={}) == [1, 2, 3, 4] assert p.match([1, 2, 3, 4, 5], context={}) == [1, 2, 3, 4, 5] assert p.match([1, 2, 3, 4, 5, 6], context={}) is NoMatch + assert p.describe() == "with length between 3 and 5" def test_contains(): p = Contains(1) assert p.match([1, 2, 3], context={}) == [1, 2, 3] assert p.match([2, 3], context={}) is NoMatch + assert p.match({1, 2, 3}, context={}) == {1, 2, 3} + assert p.match({2, 3}, context={}) is NoMatch + assert p.describe() == "containing 1" + assert p.describe(plural=True) == "containing 1" + + p = Contains("1") + assert p.match([1, 2, 3], context={}) is NoMatch + assert p.match(["1", 2, 3], context={}) == ["1", 2, 3] + assert p.match("123", context={}) == "123" + assert p.describe() == "containing '1'" def test_isin(): p = IsIn([1, 2, 3]) assert p.match(1, context={}) == 1 assert p.match(4, context={}) is NoMatch + assert p.describe() == "in {1, 2, 3}" + assert p.describe(plural=True) == "in {1, 2, 3}" def test_sequence_of(): @@ -415,6 +482,8 @@ def test_sequence_of(): assert p.match([1, 2], context={}) is NoMatch assert p.match(1, context={}) is NoMatch assert p.match("string", context={}) is NoMatch + assert p.describe() == "a list of strs" + assert p.describe(plural=True) == "lists of strs" def test_generic_sequence_of(): @@ -447,6 +516,8 @@ def test_list_of(): assert p.match(["foo", "bar"], context={}) == ["foo", "bar"] assert p.match([1, 2], context={}) is NoMatch assert p.match(1, context={}) is NoMatch + assert p.describe() == "a list of strs" + assert p.describe(plural=True) == "lists of strs" def test_tuple_of(): @@ -454,12 +525,16 @@ def test_tuple_of(): assert p.match(("foo", 1, 1.0), context={}) == ("foo", 1, 1.0) assert p.match(["foo", 1, 1.0], context={}) == ("foo", 1, 1.0) assert p.match(1, context={}) is NoMatch + assert p.describe() == "a tuple of (a str, an int, a float)" + assert p.describe(plural=True) == "tuples of (a str, an int, a float)" p = TupleOf(InstanceOf(str)) assert p == SequenceOf(InstanceOf(str), tuple) assert p.match(("foo", "bar"), context={}) == ("foo", "bar") assert p.match(["foo"], context={}) == ("foo",) assert p.match(1, context={}) is NoMatch + assert p.describe() == "a tuple of strs" + assert p.describe(plural=True) == "tuples of strs" def test_mapping_of(): @@ -556,7 +631,7 @@ def func_with_required_keyword_only_kwargs(*, c): wrapped = p.match(func, context={}) assert wrapped(1, "st") == "1st" - with pytest.raises(ValidationError, match="2 doesn't match InstanceOf"): + with pytest.raises(ValidationError): wrapped(1, 2) p = CallableWith([InstanceOf(int)]) @@ -959,7 +1034,7 @@ def __coerce__(cls, obj): if isinstance(obj, cls): return obj else: - raise TypeError("raise on coercion") + raise ValueError("raise on coercion") class PlusOneChild(PlusOne): @@ -983,7 +1058,7 @@ def test_pattern_coercible_bypass_coercion(): # bypass coercion since it's already an instance of SomethingRaise assert s.match(PlusOneRaise(10), context={}) == PlusOneRaise(10) # but actually call __coerce__ if it's not an instance - with pytest.raises(TypeError, match="raise on coercion"): + with pytest.raises(ValueError, match="raise on coercion"): s.match(10, context={}) diff --git a/ibis/common/typing.py b/ibis/common/typing.py index 2a2e403a6249..df70e7500a90 100644 --- a/ibis/common/typing.py +++ b/ibis/common/typing.py @@ -1,5 +1,6 @@ from __future__ import annotations +import re import sys from itertools import zip_longest from types import ModuleType # noqa: F401 @@ -194,6 +195,19 @@ def evaluate_annotations( } +def format_typehint(typ: Any) -> str: + if isinstance(typ, type): + return typ.__name__ + elif isinstance(typ, TypeVar): + if typ.__bound__ is None: + return str(typ) + else: + return format_typehint(typ.__bound__) + else: + # remove the module name from the typehint, including generics + return re.sub(r"(\w+\.)+", "", str(typ)) + + class DefaultTypeVars: """Enable using default type variables in generic classes (PEP-0696).""" diff --git a/ibis/expr/datatypes/core.py b/ibis/expr/datatypes/core.py index a72b2272bd6c..fbae85cbcce2 100644 --- a/ibis/expr/datatypes/core.py +++ b/ibis/expr/datatypes/core.py @@ -121,7 +121,7 @@ def __coerce__(cls, value, **kwargs): return value try: return dtype(value) - except TypeError as e: + except (TypeError, RuntimeError) as e: raise CoercionError("Unable to coerce to a DataType") from e def __call__(self, **kwargs): @@ -165,7 +165,12 @@ def from_typehint(cls, typ, nullable=True) -> Self: if origin_type is None: if isinstance(typ, type): - if issubclass(typ, DataType): + if issubclass(typ, Parametric): + raise TypeError( + f"Cannot construct a parametric {typ.__name__} datatype based " + "on the type itself" + ) + elif issubclass(typ, DataType): return typ(nullable=nullable) elif typ is type(None): return null diff --git a/ibis/expr/datatypes/tests/test_core.py b/ibis/expr/datatypes/tests/test_core.py index b3c501fee1aa..37fcbcb386fc 100644 --- a/ibis/expr/datatypes/tests/test_core.py +++ b/ibis/expr/datatypes/tests/test_core.py @@ -69,8 +69,6 @@ def test_dtype(spec, expected): (dt.Boolean, dt.boolean), (dt.Date, dt.date), (dt.Time, dt.time), - (dt.Decimal, dt.decimal), - (dt.Timestamp, dt.timestamp), ], ) def test_dtype_from_classes(klass, expected): @@ -145,8 +143,6 @@ class BarStruct: l: dt.Boolean # noqa: E741 m: dt.Date n: dt.Time - o: dt.Timestamp - q: dt.Decimal r: dt.Array[dt.Int16] s: dt.Map[dt.String, dt.Int16] @@ -167,8 +163,6 @@ class BarStruct: "l": dt.boolean, "m": dt.date, "n": dt.time, - "o": dt.timestamp, - "q": dt.decimal, "r": dt.Array(dt.int16), "s": dt.Map(dt.string, dt.int16), } diff --git a/ibis/expr/operations/core.py b/ibis/expr/operations/core.py index 728e0a3bea4b..890e88974367 100644 --- a/ibis/expr/operations/core.py +++ b/ibis/expr/operations/core.py @@ -75,7 +75,7 @@ def __coerce__( try: try: - dtype = dt.dtype(T) + dtype = dt.DataType.from_typehint(T) except TypeError: dtype = dt.infer(value) return Literal(value, dtype=dtype) diff --git a/ibis/expr/operations/tests/snapshots/test_generic/test_error_message_when_constructing_literal/call0-missing_a_required_argument/missing_a_required_argument.txt b/ibis/expr/operations/tests/snapshots/test_generic/test_error_message_when_constructing_literal/call0-missing_a_required_argument/missing_a_required_argument.txt new file mode 100644 index 000000000000..7033424e9d68 --- /dev/null +++ b/ibis/expr/operations/tests/snapshots/test_generic/test_error_message_when_constructing_literal/call0-missing_a_required_argument/missing_a_required_argument.txt @@ -0,0 +1,3 @@ +Literal(1) missing a required argument: 'dtype' + +Expected signature: Literal(value: Any, dtype: DataType) \ No newline at end of file diff --git a/ibis/expr/operations/tests/snapshots/test_generic/test_error_message_when_constructing_literal/call1-too_many_positional_arguments/too_many_positional_arguments.txt b/ibis/expr/operations/tests/snapshots/test_generic/test_error_message_when_constructing_literal/call1-too_many_positional_arguments/too_many_positional_arguments.txt new file mode 100644 index 000000000000..8860b39b90a4 --- /dev/null +++ b/ibis/expr/operations/tests/snapshots/test_generic/test_error_message_when_constructing_literal/call1-too_many_positional_arguments/too_many_positional_arguments.txt @@ -0,0 +1,3 @@ +Literal(1, Int8(nullable=True), 'foo') too many positional arguments + +Expected signature: Literal(value: Any, dtype: DataType) \ No newline at end of file diff --git a/ibis/expr/operations/tests/snapshots/test_generic/test_error_message_when_constructing_literal/call2-got_an_unexpected_keyword/got_an_unexpected_keyword.txt b/ibis/expr/operations/tests/snapshots/test_generic/test_error_message_when_constructing_literal/call2-got_an_unexpected_keyword/got_an_unexpected_keyword.txt new file mode 100644 index 000000000000..9122a7c2691d --- /dev/null +++ b/ibis/expr/operations/tests/snapshots/test_generic/test_error_message_when_constructing_literal/call2-got_an_unexpected_keyword/got_an_unexpected_keyword.txt @@ -0,0 +1,3 @@ +Literal(1, Int8(nullable=True), name='foo') got an unexpected keyword argument 'name' + +Expected signature: Literal(value: Any, dtype: DataType) \ No newline at end of file diff --git a/ibis/expr/operations/tests/snapshots/test_generic/test_error_message_when_constructing_literal/call3-multiple_values_for_argument/multiple_values_for_argument.txt b/ibis/expr/operations/tests/snapshots/test_generic/test_error_message_when_constructing_literal/call3-multiple_values_for_argument/multiple_values_for_argument.txt new file mode 100644 index 000000000000..1cb3ba43bde7 --- /dev/null +++ b/ibis/expr/operations/tests/snapshots/test_generic/test_error_message_when_constructing_literal/call3-multiple_values_for_argument/multiple_values_for_argument.txt @@ -0,0 +1,3 @@ +Literal(1, Int8(nullable=True), dtype=Int16(nullable=True)) multiple values for argument 'dtype' + +Expected signature: Literal(value: Any, dtype: DataType) \ No newline at end of file diff --git a/ibis/expr/operations/tests/snapshots/test_generic/test_error_message_when_constructing_literal/call4-invalid_dtype/invalid_dtype.txt b/ibis/expr/operations/tests/snapshots/test_generic/test_error_message_when_constructing_literal/call4-invalid_dtype/invalid_dtype.txt new file mode 100644 index 000000000000..3c82299246e0 --- /dev/null +++ b/ibis/expr/operations/tests/snapshots/test_generic/test_error_message_when_constructing_literal/call4-invalid_dtype/invalid_dtype.txt @@ -0,0 +1,4 @@ +Literal(1, 4) has failed due to the following errors: + `dtype`: 4 is not coercible to a DataType + +Expected signature: Literal(value: Any, dtype: DataType) \ No newline at end of file diff --git a/ibis/expr/operations/tests/snapshots/test_generic/test_error_message_when_constructing_literal/call5-unable_to_normalize/unable_to_normalize.txt b/ibis/expr/operations/tests/snapshots/test_generic/test_error_message_when_constructing_literal/call5-unable_to_normalize/unable_to_normalize.txt new file mode 100644 index 000000000000..58caccc21cef --- /dev/null +++ b/ibis/expr/operations/tests/snapshots/test_generic/test_error_message_when_constructing_literal/call5-unable_to_normalize/unable_to_normalize.txt @@ -0,0 +1 @@ +Unable to normalize 'e' to Int8(nullable=True) \ No newline at end of file diff --git a/ibis/expr/operations/tests/test_core.py b/ibis/expr/operations/tests/test_core.py index c0b9717c1005..4065298bb18a 100644 --- a/ibis/expr/operations/tests/test_core.py +++ b/ibis/expr/operations/tests/test_core.py @@ -198,10 +198,10 @@ def test_too_many_or_too_few_args_not_allowed(): class DummyOp(ops.Value): arg: ops.Value - with pytest.raises(TypeError): + with pytest.raises(ValidationError): DummyOp(1, 2) - with pytest.raises(TypeError): + with pytest.raises(ValidationError): DummyOp() diff --git a/ibis/expr/operations/tests/test_generic.py b/ibis/expr/operations/tests/test_generic.py index 1d9f54361c62..3e0a17495a1e 100644 --- a/ibis/expr/operations/tests/test_generic.py +++ b/ibis/expr/operations/tests/test_generic.py @@ -1,5 +1,6 @@ from __future__ import annotations +from functools import partial from typing import Union import pytest @@ -7,6 +8,7 @@ import ibis.expr.datashape as ds import ibis.expr.datatypes as dt import ibis.expr.operations as ops +from ibis.common.annotations import ValidationError from ibis.common.patterns import NoMatch, match @@ -76,3 +78,22 @@ def test_coerced_to_interval_value(): expected = ops.Literal(3661, dt.Interval("s")) assert match(ops.Value[dt.Interval], pd.Timedelta("1h 1m 1s")) == expected + + +@pytest.mark.parametrize( + ("call", "error"), + [ + (partial(ops.Literal, 1), "missing_a_required_argument"), + (partial(ops.Literal, 1, dt.int8, "foo"), "too_many_positional_arguments"), + (partial(ops.Literal, 1, dt.int8, name="foo"), "got_an_unexpected_keyword"), + ( + partial(ops.Literal, 1, dt.int8, dtype=dt.int16), + "multiple_values_for_argument", + ), + (partial(ops.Literal, 1, 4), "invalid_dtype"), + ], +) +def test_error_message_when_constructing_literal(call, error, snapshot): + with pytest.raises(ValidationError) as exc: + call() + snapshot.assert_match(str(exc.value), f"{error}.txt") diff --git a/ibis/tests/expr/test_table.py b/ibis/tests/expr/test_table.py index 6252fdfe0f1a..6e234adc223a 100644 --- a/ibis/tests/expr/test_table.py +++ b/ibis/tests/expr/test_table.py @@ -1860,7 +1860,7 @@ def test_pivot_wider(): def test_invalid_deferred(): t = ibis.table(dict(value="int", lagged_value="int"), name="t") - with pytest.raises(ValidationError, match="doesn't match"): + with pytest.raises(ValidationError): ibis.greatest(t.value, ibis._.lagged_value) diff --git a/ibis/tests/expr/test_window_frames.py b/ibis/tests/expr/test_window_frames.py index 6a657c78e749..ca81290686f2 100644 --- a/ibis/tests/expr/test_window_frames.py +++ b/ibis/tests/expr/test_window_frames.py @@ -56,7 +56,7 @@ def test_window_builder_rows(): assert w0.start is None assert w0.end is None - with pytest.raises(TypeError): + with pytest.raises(ValidationError): w0.rows(5) w1 = w0.rows(5, 10) @@ -104,7 +104,7 @@ def test_window_builder_range(): assert w0.start is None assert w0.end is None - with pytest.raises(TypeError): + with pytest.raises(ValidationError): w0.range(5) w1 = w0.range(5, 10)