Skip to content

Commit

Permalink
refactor(common): improve error messages raised during validation
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs authored and cpcloud committed Sep 10, 2023
1 parent 75982d4 commit f95613a
Show file tree
Hide file tree
Showing 27 changed files with 626 additions and 196 deletions.
2 changes: 1 addition & 1 deletion ibis/backends/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,7 @@ def test_dropna_invalid(alltypes):
):
alltypes.dropna(subset=["invalid_col"])

with pytest.raises(ValidationError, match=r"'invalid' doesn't match"):
with pytest.raises(ValidationError):
alltypes.dropna(how="invalid")


Expand Down
217 changes: 185 additions & 32 deletions ibis/common/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import functools
import inspect
import types
from typing import TYPE_CHECKING, Callable
from typing import Any as AnyType

from ibis.common.bases import Immutable, Slotted
Expand All @@ -15,7 +16,10 @@
TupleOf,
)
from ibis.common.patterns import pattern as ensure_pattern
from ibis.common.typing import get_type_hints
from ibis.common.typing import format_typehint, get_type_hints

if TYPE_CHECKING:
from collections.abc import Sequence

EMPTY = inspect.Parameter.empty # marker for missing argument
KEYWORD_ONLY = inspect.Parameter.KEYWORD_ONLY
Expand All @@ -29,7 +33,65 @@


class ValidationError(Exception):
...
__slots__ = ()


class AttributeValidationError(ValidationError):
__slots__ = ("name", "value", "pattern")

def __init__(self, name: str, value: AnyType, pattern: Pattern):
self.name = name
self.value = value
self.pattern = pattern

def __str__(self):
return f"Failed to validate attribute `{self.name}`: {self.value!r} is not {self.pattern.describe()}"


class ReturnValidationError(ValidationError):
__slots__ = ("func", "value", "pattern")

def __init__(self, func: Callable, value: AnyType, pattern: Pattern):
self.func = func
self.value = value
self.pattern = pattern

def __str__(self):
return f"Failed to validate return value of `{self.func.__name__}`: {self.value!r} is not {self.pattern.describe()}"


class SignatureValidationError(ValidationError):
__slots__ = ("msg", "sig", "func", "args", "kwargs", "errors")

def __init__(
self,
msg: str,
sig: Signature,
func: Callable,
args: tuple[AnyType, ...],
kwargs: dict[str, AnyType],
errors: Sequence[tuple[str, AnyType, Pattern]] = (),
):
self.msg = msg
self.sig = sig
self.func = func
self.args = args
self.kwargs = kwargs
self.errors = errors

def __str__(self):
args = tuple(repr(arg) for arg in self.args)
args += tuple(f"{k}={v!r}" for k, v in self.kwargs.items())
call = f"{self.func.__name__}({', '.join(args)})"

errors = ""
for name, value, pattern in self.errors:
errors += f"\n `{name}`: {value!r} is not {pattern.describe()}"

sig = f"{self.func.__name__}{self.sig}"
cause = str(self.__cause__) if self.__cause__ else ""

return self.msg.format(sig=sig, call=call, cause=cause, errors=errors)


class Annotation(Slotted, Immutable):
Expand All @@ -40,11 +102,29 @@ class Annotation(Slotted, Immutable):

__slots__ = ()

def validate(self, arg, context=None):
result = self.pattern.match(arg, context)
if result is NoMatch:
raise ValidationError(f"{arg!r} doesn't match {self.pattern!r}")
def validate(self, name: str, value: AnyType, this: AnyType) -> AnyType:
"""Validate the field.
Parameters
----------
name
The name of the attribute.
value
The value of the attribute.
this
The instance of the class the attribute is defined on.
Returns
-------
The validated value for the field.
"""
result = self.pattern.match(value, this)
if result is NoMatch:
raise AttributeValidationError(
name=name,
value=value,
pattern=self.pattern,
)
return result


Expand All @@ -69,25 +149,34 @@ class Attribute(Annotation):
def __init__(self, pattern: Pattern = _any, default: AnyType = EMPTY):
super().__init__(pattern=ensure_pattern(pattern), default=default)

def initialize(self, this: AnyType) -> AnyType:
"""Compute the default value of the field.
def has_default(self):
"""Check if the field has a default value.
Returns
-------
bool
"""
return self.default is not EMPTY

def get_default(self, name: str, this: AnyType) -> AnyType:
"""Get the default value of the field.
Parameters
----------
name
The name of the attribute.
this
The instance of the class the attribute is defined on.
Returns
-------
The default value for the field.
"""
if self.default is EMPTY:
return EMPTY
elif callable(self.default):
if callable(self.default):
value = self.default(this)
else:
value = self.default
return self.validate(value, this)
return self.validate(name, value, this)

def __call__(self, default):
"""Needed to support the decorator syntax."""
Expand Down Expand Up @@ -180,6 +269,26 @@ def __init__(self, name, annotation):
annotation=annotation,
)

def __str__(self):
formatted = self._name

if self._annotation is not EMPTY:
typehint = format_typehint(self._annotation.typehint)
formatted = f"{formatted}: {typehint}"

if self._default is not EMPTY:
if self._annotation is not EMPTY:
formatted = f"{formatted} = {self._default!r}"
else:
formatted = f"{formatted}={self._default!r}"

if self._kind == VAR_POSITIONAL:
formatted = "*" + formatted
elif self._kind == VAR_KEYWORD:
formatted = "**" + formatted

return formatted


class Signature(inspect.Signature):
"""Validatable signature.
Expand Down Expand Up @@ -339,11 +448,13 @@ def unbind(self, this: dict[str, Any]) -> tuple[tuple[Any, ...], dict[str, Any]]
raise TypeError(f"unsupported parameter kind {param.kind}")
return tuple(args), kwargs

def validate(self, *args, **kwargs):
def validate(self, func, args, kwargs):
"""Validate the arguments against the signature.
Parameters
----------
func : Callable
Callable to validate the arguments for.
args : tuple
Positional arguments.
kwargs : dict
Expand All @@ -354,39 +465,77 @@ def validate(self, *args, **kwargs):
validated : dict
Dictionary of validated arguments.
"""
# bind the signature to the passed arguments and apply the patterns
# before passing the arguments, so self.__init__() receives already
# validated arguments as keywords
bound = self.bind(*args, **kwargs)
bound.apply_defaults()

this = {}
try:
bound = self.bind(*args, **kwargs)
bound.apply_defaults()
except TypeError as err:
raise SignatureValidationError(
"{call} {cause}\n\nExpected signature: {sig}",
sig=self,
func=func,
args=args,
kwargs=kwargs,
) from err

this, errors = {}, []
for name, value in bound.arguments.items():
param = self.parameters[name]
# TODO(kszucs): provide more error context on failure
this[name] = param.annotation.validate(value, this)
pattern = param.annotation.pattern

result = pattern.match(value, this)
if result is NoMatch:
errors.append((name, value, pattern))
else:
this[name] = result

if errors:
raise SignatureValidationError(
"{call} has failed due to the following errors:{errors}\n\nExpected signature: {sig}",
sig=self,
func=func,
args=args,
kwargs=kwargs,
errors=errors,
)

return this

def validate_nobind(self, **kwargs):
def validate_nobind(self, func, kwargs):
"""Validate the arguments against the signature without binding."""
this = {}
this, errors = {}, []
for name, param in self.parameters.items():
value = kwargs.get(name, param.default)
if value is EMPTY:
raise TypeError(f"missing required argument `{name!r}`")
this[name] = param.annotation.validate(value, kwargs)

pattern = param.annotation.pattern
result = pattern.match(value, this)
if result is NoMatch:
errors.append((name, value, pattern))
else:
this[name] = result

if errors:
raise SignatureValidationError(
"{call} has failed due to the following errors:{errors}\n\nExpected signature: {sig}",
sig=self,
func=func,
args=(),
kwargs=kwargs,
errors=errors,
)

return this

def validate_return(self, value, context):
def validate_return(self, func, value):
"""Validate the return value of a function.
Parameters
----------
func : Callable
Callable to validate the return value for.
value : Any
Return value of the function.
context : dict
Context dictionary.
Returns
-------
Expand All @@ -396,9 +545,13 @@ def validate_return(self, value, context):
if self.return_annotation is EMPTY:
return value

result = self.return_annotation.match(value, context)
result = self.return_annotation.match(value, {})
if result is NoMatch:
raise ValidationError(f"{value!r} doesn't match {self}")
raise ReturnValidationError(
func=func,
value=value,
pattern=self.return_annotation,
)

return result

Expand Down Expand Up @@ -476,13 +629,13 @@ def annotated(_1=None, _2=None, _3=None, **kwargs):
@functools.wraps(func)
def wrapped(*args, **kwargs):
# 1. Validate the passed arguments
values = sig.validate(*args, **kwargs)
values = sig.validate(func, args, kwargs)
# 2. Reconstruction of the original arguments
args, kwargs = sig.unbind(values)
# 3. Call the function with the validated arguments
result = func(*args, **kwargs)
# 4. Validate the return value
return sig.validate_return(result, {})
return sig.validate_return(func, result)

wrapped.__signature__ = sig

Expand Down
24 changes: 12 additions & 12 deletions ibis/common/grounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from typing_extensions import Self, dataclass_transform

from ibis.common.annotations import (
EMPTY,
Annotation,
Argument,
Attribute,
Expand Down Expand Up @@ -115,14 +114,14 @@ class Annotable(Base, metaclass=AnnotableMeta):

@classmethod
def __create__(cls, *args: Any, **kwargs: Any) -> Self:
# construct the instance by passing the validated keyword arguments
kwargs = cls.__signature__.validate(*args, **kwargs)
# construct the instance by passing only validated keyword arguments
kwargs = cls.__signature__.validate(cls, args, kwargs)
return super().__create__(**kwargs)

@classmethod
def __recreate__(cls, kwargs: Any) -> Self:
# bypass signature binding by requiring keyword arguments only
kwargs = cls.__signature__.validate_nobind(**kwargs)
kwargs = cls.__signature__.validate_nobind(cls, kwargs)
return super().__create__(**kwargs)

def __init__(self, **kwargs: Any) -> None:
Expand All @@ -131,16 +130,17 @@ def __init__(self, **kwargs: Any) -> None:
object.__setattr__(self, name, value)
# initialize the remaining attributes
for name, field in self.__attributes__.items():
if (default := field.initialize(self)) is not EMPTY:
object.__setattr__(self, name, default)
if field.has_default():
object.__setattr__(self, name, field.get_default(name, self))

def __setattr__(self, name, value) -> None:
# first try to look up the argument then the attribute
if param := self.__signature__.parameters.get(name):
value = param.annotation.validate(value, self)
elif field := self.__attributes__.get(name):
value = field.validate(value, self)
super().__setattr__(name, value)
value = param.annotation.validate(name, value, self)
# then try to look up the attribute
elif annot := self.__attributes__.get(name):
value = annot.validate(name, value, self)
return super().__setattr__(name, value)

def __repr__(self) -> str:
args = (f"{n}={getattr(self, n)!r}" for n in self.__argnames__)
Expand Down Expand Up @@ -204,8 +204,8 @@ def __init__(self, **kwargs: Any) -> None:

# initialize the remaining attributes
for name, field in self.__attributes__.items():
if (default := field.initialize(self)) is not EMPTY:
object.__setattr__(self, name, default)
if field.has_default():
object.__setattr__(self, name, field.get_default(name, self))

def __reduce__(self):
# assuming immutability and idempotency of the __init__ method, we can
Expand Down
Loading

0 comments on commit f95613a

Please sign in to comment.