Skip to content

Commit

Permalink
feat(internal): handle multiple messages for msgspec
Browse files Browse the repository at this point in the history
  • Loading branch information
Goldziher committed Jul 15, 2023
1 parent 6a7af34 commit f22961c
Show file tree
Hide file tree
Showing 16 changed files with 260 additions and 192 deletions.
12 changes: 6 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ repos:
exclude: "test_apps|tools|docs|tests/examples|tests/docker_service_fixtures"
additional_dependencies:
[
msgspec>=0.17.0,
polyfactory,
aiosqlite,
annotated_types,
async_timeout,
Expand All @@ -91,14 +91,13 @@ repos:
jsbeautifier,
mako,
mongomock_motor,
msgspec,
multidict,
opentelemetry-instrumentation-asgi,
opentelemetry-sdk,
oracledb,
piccolo,
picologging,
polyfactory,
prometheus_client,
psycopg,
pydantic>=2,
pydantic_extra_types,
Expand All @@ -122,6 +121,7 @@ repos:
types-pyyaml,
types-redis,
uvicorn,
prometheus_client,
]
- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.317
Expand All @@ -130,7 +130,7 @@ repos:
exclude: "test_apps|tools|docs|_openapi|tests/examples|tests/docker_service_fixtures"
additional_dependencies:
[
msgspec>=0.17.0,
polyfactory,
aiosqlite,
annotated_types,
async_timeout,
Expand All @@ -151,14 +151,13 @@ repos:
jsbeautifier,
mako,
mongomock_motor,
msgspec,
multidict,
opentelemetry-instrumentation-asgi,
opentelemetry-sdk,
oracledb,
piccolo,
picologging,
polyfactory,
prometheus_client,
psycopg,
pydantic>=2,
pydantic_extra_types,
Expand All @@ -182,6 +181,7 @@ repos:
types-pyyaml,
types-redis,
uvicorn,
prometheus_client,
]
- repo: local
hooks:
Expand Down
2 changes: 1 addition & 1 deletion litestar/_kwargs/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ async def resolve_dependency(
signature_model = get_signature_model(dependency.provide)
dependency_kwargs = (
signature_model.parse_values_from_connection_kwargs(connection=connection, **kwargs)
if signature_model.fields
if signature_model._fields
else {}
)
value = await dependency.provide(**dependency_kwargs)
Expand Down
4 changes: 2 additions & 2 deletions litestar/_kwargs/kwargs_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def create_for_signature_model(
An instance of KwargsModel
"""

field_definitions = signature_model.fields
field_definitions = signature_model._fields

cls._validate_raw_kwargs(
path_parameters=path_parameters,
Expand Down Expand Up @@ -405,7 +405,7 @@ def _create_dependency_graph(cls, key: str, dependencies: dict[str, Provide]) ->
list.
"""
provide = dependencies[key]
sub_dependency_keys = [k for k in get_signature_model(provide).fields if k in dependencies]
sub_dependency_keys = [k for k in get_signature_model(provide)._fields if k in dependencies]
return Dependency(
key=key,
provide=provide,
Expand Down
2 changes: 1 addition & 1 deletion litestar/_openapi/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def get_recursive_handler_parameters(
)
]

dependency_fields = dependency_providers[field_name].signature_model.fields
dependency_fields = dependency_providers[field_name].signature_model._fields
return create_parameter_for_handler(
route_handler=route_handler,
handler_fields=dependency_fields,
Expand Down
2 changes: 1 addition & 1 deletion litestar/_openapi/path_item.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def create_path_item(
route_handler, _ = handler_tuple

if route_handler.include_in_schema:
handler_fields = route_handler.signature_model.fields if route_handler.signature_model else {}
handler_fields = route_handler.signature_model._fields if route_handler.signature_model else {}
parameters = (
create_parameter_for_handler(
route_handler=route_handler,
Expand Down
66 changes: 41 additions & 25 deletions litestar/_signature/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
from typing_extensions import Annotated

from litestar._signature.utils import create_type_overrides, validate_signature_dependencies
from litestar.enums import ScopeType
from litestar.enums import ParamType, ScopeType
from litestar.exceptions import InternalServerException, ValidationException
from litestar.params import DependencyKwarg, KwargDefinition, ParameterKwarg
from litestar.serialization import ExtendedMsgSpecValidationError, dec_hook
from litestar.serialization import dec_hook
from litestar.serialization._msgspec_utils import ExtendedMsgSpecValidationError
from litestar.typing import FieldDefinition # noqa
from litestar.utils import make_non_optional_union
from litestar.utils.dataclass import simple_asdict
Expand All @@ -38,7 +39,7 @@ class ErrorMessage(TypedDict):
# in this case, we don't show a key at all as it will be empty
key: NotRequired[str]
message: str
source: NotRequired[Literal["cookie", "body", "header", "query"]]
source: NotRequired[Literal["body"] | ParamType]


MSGSPEC_CONSTRAINT_FIELDS = (
Expand All @@ -59,9 +60,9 @@ class SignatureModel(Struct):
"""Model that represents a function signature that uses a msgspec specific type or types."""

# NOTE: we have to use Set and Dict here because python 3.8 goes haywire if we use 'set' and 'dict'
dependency_name_set: ClassVar[Set[str]]
fields: ClassVar[Dict[str, FieldDefinition]]
return_annotation: ClassVar[Any]
_dependency_name_set: ClassVar[Set[str]]
_fields: ClassVar[Dict[str, FieldDefinition]]
_return_annotation: ClassVar[Any]

@classmethod
def _create_exception(cls, connection: ASGIConnection, messages: list[ErrorMessage]) -> Exception:
Expand All @@ -79,7 +80,7 @@ def _create_exception(cls, connection: ASGIConnection, messages: list[ErrorMessa
if client_errors := [
err_message
for err_message in messages
if ("key" in err_message and err_message["key"] not in cls.dependency_name_set) or "key" not in err_message
if ("key" in err_message and err_message["key"] not in cls._dependency_name_set) or "key" not in err_message
]:
return ValidationException(detail=f"Validation failed for {method} {connection.url}", extra=client_errors)
return InternalServerException()
Expand All @@ -103,21 +104,35 @@ def _build_error_message(cls, keys: Sequence[str], exc_msg: str, connection: ASG
return message

message["key"] = key = ".".join(keys)

if key in connection.query_params:
message["source"] = cast("Literal['cookie', 'body', 'header', 'query']", "query")

elif key in cls.fields and isinstance(cls.fields[key].kwarg_definition, ParameterKwarg):
if cast(ParameterKwarg, cls.fields[key].kwarg_definition).cookie:
source = "cookie"
elif cast(ParameterKwarg, cls.fields[key].kwarg_definition).header:
source = "header"
if keys[0].startswith("data"):
message["key"] = message["key"].replace("data.", "")
message["source"] = "body"
elif key in connection.query_params:
message["source"] = ParamType.QUERY

elif key in cls._fields and isinstance(cls._fields[key].kwarg_definition, ParameterKwarg):
if cast(ParameterKwarg, cls._fields[key].kwarg_definition).cookie:
message["source"] = ParamType.COOKIE
elif cast(ParameterKwarg, cls._fields[key].kwarg_definition).header:
message["source"] = ParamType.HEADER
else:
source = "query"
message["source"] = cast("Literal['cookie', 'body', 'header', 'query']", source)
message["source"] = ParamType.QUERY

return message

@classmethod
def _collect_errors(cls, **kwargs: Any) -> list[tuple[str, Exception]]:
exceptions: list[tuple[str, Exception]] = []
for field_name in cls._fields:
try:
raw_value = kwargs[field_name]
annotation = cls.__annotations__[field_name]
convert(raw_value, type=annotation, strict=False, dec_hook=dec_hook, str_keys=True)
except Exception as e: # noqa: BLE001
exceptions.append((field_name, e))

return exceptions

@classmethod
def parse_values_from_connection_kwargs(cls, connection: ASGIConnection, **kwargs: Any) -> dict[str, Any]:
"""Extract values from the connection instance and return a dict of parsed values.
Expand All @@ -143,10 +158,11 @@ def parse_values_from_connection_kwargs(cls, connection: ASGIConnection, **kwarg
messages.append(message)
raise cls._create_exception(messages=messages, connection=connection) from e
except ValidationError as e:
match = ERR_RE.search(str(e))
keys = [str(match.group(1)) if match else "n/a"]
message = cls._build_error_message(keys=keys, exc_msg=str(e), connection=connection)
messages.append(message)
for field_name, exc in cls._collect_errors(**kwargs): # type: ignore[assignment]
match = ERR_RE.search(str(exc))
keys = [field_name, str(match.group(1))] if match else [field_name]
message = cls._build_error_message(keys=keys, exc_msg=str(e), connection=connection)
messages.append(message)
raise cls._create_exception(messages=messages, connection=connection) from e

def to_dict(self) -> dict[str, Any]:
Expand Down Expand Up @@ -221,9 +237,9 @@ def create(
bases=(cls,),
module=getattr(fn, "__module__", None),
namespace={
"return_annotation": parsed_signature.return_type.annotation,
"dependency_name_set": dependency_names,
"fields": parsed_signature.parameters,
"_return_annotation": parsed_signature.return_type.annotation,
"_dependency_name_set": dependency_names,
"_fields": parsed_signature.parameters,
},
kw_only=True,
)
21 changes: 21 additions & 0 deletions litestar/serialization/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from .msgspec_hooks import (
dec_hook,
decode_json,
decode_media_type,
decode_msgpack,
default_serializer,
encode_json,
encode_msgpack,
get_serializer,
)

__all__ = (
"dec_hook",
"decode_json",
"decode_media_type",
"decode_msgpack",
"default_serializer",
"encode_json",
"encode_msgpack",
"get_serializer",
)
9 changes: 9 additions & 0 deletions litestar/serialization/_msgspec_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from typing import Any

from msgspec import ValidationError


class ExtendedMsgSpecValidationError(ValidationError):
def __init__(self, errors: list[dict[str, Any]]) -> None:
self.errors = errors
super().__init__(errors)
105 changes: 105 additions & 0 deletions litestar/serialization/_pydantic_serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from __future__ import annotations

from typing import Any, Callable, TypeVar, cast
from uuid import UUID

from msgspec import ValidationError

from litestar.serialization._msgspec_utils import ExtendedMsgSpecValidationError
from litestar.utils import is_class_and_subclass, is_pydantic_model_class

__all__ = (
"create_pydantic_decoders",
"create_pydantic_encoders",
)

T = TypeVar("T")


def create_pydantic_decoders() -> list[tuple[Callable[[Any], bool], Callable[[Any, Any], Any]]]:
decoders: list[tuple[Callable[[Any], bool], Callable[[Any, Any], Any]]] = []
try:
import pydantic

def _dec_pydantic(type_: type[pydantic.BaseModel], value: Any) -> pydantic.BaseModel:
try:
return (
type_.model_validate(value, strict=False)
if hasattr(type_, "model_validate")
else type_.parse_obj(value)
)
except pydantic.ValidationError as e:
raise ExtendedMsgSpecValidationError(errors=cast(list[dict[str, Any]], e.errors())) from e

decoders.append((is_pydantic_model_class, _dec_pydantic))

def _dec_pydantic_uuid(
type_: type[pydantic.UUID1] | type[pydantic.UUID3] | type[pydantic.UUID4] | type[pydantic.UUID5], val: Any
) -> type[pydantic.UUID1] | type[pydantic.UUID3] | type[pydantic.UUID4] | type[pydantic.UUID5]:
if isinstance(val, str):
val = type_(val)

elif isinstance(val, (bytes, bytearray)):
try:
val = type_(val.decode())
except ValueError:
# 16 bytes in big-endian order as the bytes argument fail
# the above check
val = type_(bytes=val)
elif isinstance(val, UUID):
val = type_(str(val))

if not isinstance(val, type_):
raise ValidationError(f"Invalid UUID: {val!r}")

if type_._required_version != val.version: # type: ignore
raise ValidationError(f"Invalid UUID version: {val!r}")

return cast(
"type[pydantic.UUID1] | type[pydantic.UUID3] | type[pydantic.UUID4] | type[pydantic.UUID5]", val
)

def _is_pydantic_uuid(value: Any) -> bool:
return is_class_and_subclass(value, (pydantic.UUID1, pydantic.UUID3, pydantic.UUID4, pydantic.UUID5))

decoders.append((_is_pydantic_uuid, _dec_pydantic_uuid))
return decoders
except ImportError:
return decoders


def create_pydantic_encoders() -> dict[Any, Callable[[Any], Any]]:
try:
import pydantic

encoders: dict[Any, Callable[[Any], Any]] = {
pydantic.EmailStr: str,
pydantic.NameEmail: str,
pydantic.ByteSize: lambda val: val.real,
}

if pydantic.VERSION.startswith("1"): # pragma: no cover
encoders.update(
{
pydantic.BaseModel: lambda model: model.dict(),
pydantic.SecretField: str,
pydantic.StrictBool: int,
pydantic.color.Color: str, # pyright: ignore
pydantic.ConstrainedBytes: lambda val: val.decode("utf-8"),
pydantic.ConstrainedDate: lambda val: val.isoformat(),
}
)
else:
from pydantic_extra_types import color

encoders.update(
{
pydantic.BaseModel: lambda model: model.model_dump(mode="json"),
color.Color: str,
pydantic.types.SecretStr: lambda val: "**********" if val else "",
pydantic.types.SecretBytes: lambda val: "**********" if val else "",
}
)
return encoders
except ImportError:
return {}
Loading

0 comments on commit f22961c

Please sign in to comment.