Skip to content

Commit

Permalink
Exception handling.
Browse files Browse the repository at this point in the history
  • Loading branch information
peterschutt committed Jun 30, 2023
1 parent f2300ef commit b7cb2e2
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 96 deletions.
53 changes: 24 additions & 29 deletions litestar/_signature/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
from litestar.connection import ASGIConnection
from litestar.utils.signature import ParsedSignature

__all__ = ("SignatureModel",)
__all__ = (
"ErrorMessage",
"SignatureModel",
)


class ErrorMessage(TypedDict):
Expand Down Expand Up @@ -53,9 +56,7 @@ def _create_exception(cls, connection: ASGIConnection, messages: list[ErrorMessa
if ("key" in err_message and err_message["key"] not in cls.dependency_name_set) or "key" not in err_message
]:
return ValidationException(detail=f"Validation failed for {method} {connection.url}", extra=client_errors)
return InternalServerException(
detail=f"A dependency failed validation for {method} {connection.url}", extra=messages
)
return InternalServerException

@classmethod
def _build_error_message(cls, keys: Sequence[str], exc_msg: str, connection: ASGIConnection) -> ErrorMessage:
Expand All @@ -70,31 +71,25 @@ def _build_error_message(cls, keys: Sequence[str], exc_msg: str, connection: ASG
An ErrorMessage
"""

message: ErrorMessage = {"message": exc_msg}

if len(keys) > 1:
key_start = 0

if keys[0] == "data":
key_start = 1
message["source"] = "body"

message["key"] = ".".join(keys[key_start:])
elif keys:
key = keys[0]
message["key"] = key

if key in connection.query_params:
message["source"] = cast("Literal['cookie', 'body', 'header', 'query']", "query")

elif key in cls.fields and isinstance(cls.fields[key].kwarg_model, ParameterKwarg):
if cast(ParameterKwarg, cls.fields[key].kwarg_model).cookie:
source = "cookie"
elif cast(ParameterKwarg, cls.fields[key].kwarg_model).header:
source = "header"
else:
source = "query"
message["source"] = cast("Literal['cookie', 'body', 'header', 'query']", source)
message: ErrorMessage = {"message": exc_msg.split(" - ")[0]}

if not keys:
return message

message["key"] = key = ".".join(keys)

if key in connection.query_params:
message["source"] = "query"
elif key in cls.fields and isinstance(cls.fields[key].kwarg_model, ParameterKwarg):
kwarg_model = cast("ParameterKwarg", cls.fields[key].kwarg_model)
if kwarg_model.cookie:
message["source"] = "cookie"
elif kwarg_model.header:
message["source"] = "header"
else:
message["source"] = "query"
else:
message["source"] = "body"

return message

Expand Down
32 changes: 21 additions & 11 deletions litestar/_signature/models/msgspec_signature_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from msgspec import NODEFAULT, Meta, Struct, ValidationError, convert, defstruct
from msgspec.structs import asdict
from pydantic import ValidationError as PydanticValidationError
from typing_extensions import Annotated

from litestar._signature.field import SignatureField
Expand All @@ -17,6 +18,14 @@

from .base import SignatureModel

if TYPE_CHECKING:
from litestar.connection import ASGIConnection
from litestar.utils.signature import ParsedSignature

from .base import ErrorMessage

__all__ = ("MsgspecSignatureModel",)

MSGSPEC_CONSTRAINT_FIELDS = (
"gt",
"ge",
Expand All @@ -28,13 +37,6 @@
"max_length",
)

if TYPE_CHECKING:
from litestar.connection import ASGIConnection
from litestar.utils.signature import ParsedSignature

__all__ = ("MsgspecSignatureModel",)


ERR_RE = re.compile(r"`\$\.(.+)`$")


Expand All @@ -56,13 +58,21 @@ def parse_values_from_connection_kwargs(cls, connection: ASGIConnection, **kwarg
Returns:
A dictionary of parsed values
"""
messages: list[ErrorMessage] = []
try:
return convert(kwargs, cls, strict=False, dec_hook=dec_hook).to_dict()
except PydanticValidationError as e:
for exc in e.errors():
keys = [str(loc) for loc in exc["loc"]]
message = super()._build_error_message(keys=keys, exc_msg=exc["msg"], connection=connection)
messages.append(message)
raise cls._create_exception(messages=messages, connection=connection) from e
except ValidationError as e:
message = str(e)
match = ERR_RE.search(message)
key = str(match.group(1)) if match else "n/a"
raise cls._create_exception(messages=[{"key": key, "message": message}], connection=connection) from e
match = ERR_RE.search(str(e))
keys = [str(match.group(1)) if match else "n/a"]
message = super()._build_error_message(keys=keys, exc_msg=str(e), connection=connection)
messages.append(message)
raise cls._create_exception(messages=messages, connection=connection) from e

def to_dict(self) -> dict[str, Any]:
"""Normalize access to the signature model's dictionary method, because different backends use different methods
Expand Down
6 changes: 1 addition & 5 deletions litestar/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
SecretField,
StrictBool,
)
from pydantic import ValidationError as PydanticValidationError
from pydantic.color import Color
from pydantic.json import decimal_encoder

Expand Down Expand Up @@ -155,10 +154,7 @@ def _dec_pydantic_uuid(type_: type[PydanticUUIDType], val: Any) -> PydanticUUIDT


def _dec_pydantic(type_: type[BaseModel], value: Any) -> BaseModel:
try:
return type_.parse_obj(value)
except PydanticValidationError as e:
raise ValidationError from e
return type_.parse_obj(value)


def dec_hook(type_: Any, value: Any) -> Any: # pragma: no cover
Expand Down
20 changes: 10 additions & 10 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ importlib-resources = { version = ">=5.12.0", python = "<3.9" }
jinja2 = { version = ">=3.1.2", optional = true }
jsbeautifier = { version = "*", optional = true }
mako = { version = ">=1.2.4", optional = true }
msgspec = { git = "https://github.com/jcrist/msgspec.git", rev = "a1d43b35dfa93203e71bec9999958cdd0cecb593" }
msgspec = { git = "https://github.com/jcrist/msgspec.git", rev = "e793b5089befcccdc4e9ec8ec558b9e4067e68b0" }
multidict = ">=6.0.2"
opentelemetry-instrumentation-asgi = { version = "*", optional = true }
picologging = { version = "*", optional = true }
Expand Down
52 changes: 12 additions & 40 deletions tests/unit/test_signature/test_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,51 +137,27 @@ def test(dep: int, param: int, optional_dep: Optional[int] = Dependency()) -> No
) as client:
response = client.get("/?param=13")

assert response.json() == {
"detail": "Internal Server Error",
"extra": [{"key": "dep", "message": error_message}],
"status_code": HTTP_500_INTERNAL_SERVER_ERROR,
}
assert response.json() == {"detail": "Internal Server Error", "status_code": HTTP_500_INTERNAL_SERVER_ERROR}


@pytest.mark.parametrize(
"preferred_validation_backend, error_message",
(
pytest.param("attrs", "invalid literal for int() with base 10: 'thirteen'", marks=pytest.mark.skip, id="attrs"),
pytest.param("msgspec", "Expected `int`, got `str` - at `$.param`", id="msgspec"),
),
)
def test_validation_failure_raises_400(
preferred_validation_backend: Literal["attrs", "pydantic"], error_message: Any
) -> None:
def test_validation_failure_raises_400() -> None:
dependencies = {"dep": Provide(lambda: 13, sync_to_thread=False)}

@get("/")
def test(dep: int, param: int, optional_dep: Optional[int] = Dependency()) -> None:
...

with create_test_client(
route_handlers=[test], dependencies=dependencies, _preferred_validation_backend=preferred_validation_backend
) as client:
with create_test_client(route_handlers=[test], dependencies=dependencies) as client:
response = client.get("/?param=thirteen")

assert response.json() == {
"detail": "Validation failed for GET http://testserver.local/?param=thirteen",
"extra": [{"key": "param", "message": error_message}],
"extra": [{"key": "param", "message": "Expected `int`, got `str`", "source": "query"}],
"status_code": 400,
}


@pytest.mark.parametrize(
"backend,expected_error_msg",
[
pytest.param(
"pydantic", "invalid literal for int() with base 10: 'thirteen'", marks=pytest.mark.skip, id="pydantic"
),
pytest.param("msgspec", "Expected `int`, got `str` - at `$.param`", id="msgspec"),
],
)
def test_client_backend_error_precedence_over_server_error(backend: str, expected_error_msg: Any) -> None:
def test_client_backend_error_precedence_over_server_error() -> None:
dependencies = {
"dep": Provide(lambda: "thirteen", sync_to_thread=False),
"optional_dep": Provide(lambda: "thirty-one", sync_to_thread=False),
Expand All @@ -191,16 +167,12 @@ def test_client_backend_error_precedence_over_server_error(backend: str, expecte
def test(dep: int, param: int, optional_dep: Optional[int] = Dependency()) -> None:
...

with create_test_client(
route_handlers=[test],
dependencies=dependencies,
_preferred_validation_backend=backend, # type: ignore[arg-type]
) as client:
with create_test_client(route_handlers=[test], dependencies=dependencies) as client:
response = client.get("/?param=thirteen")

assert response.json() == {
"detail": "Validation failed for GET http://testserver.local/?param=thirteen",
"extra": [{"key": "param", "message": expected_error_msg}],
"extra": [{"key": "param", "message": "Expected `int`, got `str`", "source": "query"}],
"status_code": 400,
}

Expand Down Expand Up @@ -305,7 +277,7 @@ class Parent(BaseModel):
child: Child
other_child: OtherChild

def fn(model: Parent) -> None:
def fn(data: Parent) -> None:
pass

model = create_signature_model(
Expand All @@ -317,13 +289,13 @@ def fn(model: Parent) -> None:

with pytest.raises(ValidationException) as exc_info:
model.parse_values_from_connection_kwargs(
connection=RequestFactory().get(), model={"child": {}, "other_child": {}}
connection=RequestFactory().get(), data={"child": {}, "other_child": {}}
)

assert isinstance(exc_info.value.extra, list)
assert exc_info.value.extra[0]["key"] == "model.child.val"
assert exc_info.value.extra[1]["key"] == "model.child.other_val"
assert exc_info.value.extra[2]["key"] == "model.other_child.val"
assert exc_info.value.extra[0]["key"] == "child.val"
assert exc_info.value.extra[1]["key"] == "child.other_val"
assert exc_info.value.extra[2]["key"] == "other_child.val"


def test_invalid_input_pydantic() -> None:
Expand Down

0 comments on commit b7cb2e2

Please sign in to comment.