Skip to content

Commit

Permalink
feat: msgspec signature model.
Browse files Browse the repository at this point in the history
Linting fixes.

Update tests/unit/test_kwargs/test_path_params.py

fix signature namespace issue

support min_length and max_length

support min_length and max_length

handle constraints on union types

fix error message tests
  • Loading branch information
peterschutt authored and Goldziher committed Jul 5, 2023
1 parent a1b7753 commit 2b43b07
Show file tree
Hide file tree
Showing 10 changed files with 342 additions and 263 deletions.
14 changes: 7 additions & 7 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ repos:
- id: unasyncd
additional_dependencies: ["ruff"]
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: "v0.0.276"
rev: "v0.0.277"
hooks:
- id: ruff
args: ["--fix"]
Expand Down Expand Up @@ -65,7 +65,7 @@ repos:
exclude: "test_apps|tools|docs|tests/examples|tests/docker_service_fixtures"
additional_dependencies:
[
"polyfactory>=2.3.2",
"git+https://github.com/jcrist/msgspec.git",
aiosqlite,
annotated_types,
async_timeout,
Expand All @@ -87,13 +87,14 @@ repos:
jsbeautifier,
mako,
mongomock_motor,
msgspec,
multidict,
opentelemetry-instrumentation-asgi,
opentelemetry-sdk,
oracledb,
piccolo,
picologging,
polyfactory,
prometheus_client,
psycopg,
pydantic,
pytest,
Expand All @@ -116,7 +117,6 @@ repos:
types-pyyaml,
types-redis,
uvicorn,
prometheus_client,
]
- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.316
Expand All @@ -125,7 +125,7 @@ repos:
exclude: "test_apps|tools|docs|_openapi|tests/examples|tests/docker_service_fixtures"
additional_dependencies:
[
"polyfactory>=2.3.2",
"git+https://github.com/jcrist/msgspec.git",
aiosqlite,
annotated_types,
async_timeout,
Expand All @@ -147,13 +147,14 @@ repos:
jsbeautifier,
mako,
mongomock_motor,
msgspec,
multidict,
opentelemetry-instrumentation-asgi,
opentelemetry-sdk,
oracledb,
piccolo,
picologging,
polyfactory,
prometheus_client,
psycopg,
pydantic,
pytest,
Expand All @@ -176,7 +177,6 @@ repos:
types-pyyaml,
types-redis,
uvicorn,
prometheus_client,
]
- repo: local
hooks:
Expand Down
63 changes: 24 additions & 39 deletions litestar/_signature/models/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# ruff: noqa: UP006
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, ClassVar, Literal, Sequence, TypedDict, cast
from abc import abstractmethod
from typing import TYPE_CHECKING, Any, ClassVar, Literal, Sequence, Set, TypedDict, cast

from litestar.enums import ScopeType
from litestar.exceptions import InternalServerException, ValidationException
Expand All @@ -14,7 +15,10 @@
from litestar.typing import FieldDefinition
from litestar.utils.signature import ParsedSignature

__all__ = ("SignatureModel",)
__all__ = (
"ErrorMessage",
"SignatureModel",
)


class ErrorMessage(TypedDict):
Expand All @@ -26,10 +30,10 @@ class ErrorMessage(TypedDict):
source: NotRequired[Literal["cookie", "body", "header", "query"]]


class SignatureModel(ABC):
class SignatureModel:
"""Base model for Signature modelling."""

dependency_name_set: ClassVar[set[str]]
dependency_name_set: ClassVar[Set[str]]
return_annotation: ClassVar[Any]
fields: ClassVar[dict[str, FieldDefinition]]

Expand All @@ -52,9 +56,7 @@ def _create_exception(cls, connection: ASGIConnection, messages: list[ErrorMessa
if ("key" in err_message and err_message["key"] not in cls.dependency_name_set) or "key" not in err_message
]:
return ValidationException(detail=f"Validation failed for {method} {connection.url}", extra=client_errors)
return InternalServerException(
detail=f"A dependency failed validation for {method} {connection.url}", extra=messages
)
return InternalServerException()

@classmethod
def _build_error_message(cls, keys: Sequence[str], exc_msg: str, connection: ASGIConnection) -> ErrorMessage:
Expand All @@ -69,31 +71,24 @@ def _build_error_message(cls, keys: Sequence[str], exc_msg: str, connection: ASG
An ErrorMessage
"""

message: ErrorMessage = {"message": exc_msg}
message: ErrorMessage = {"message": exc_msg.split(" - ")[0]}

if len(keys) > 1:
key_start = 0
if not keys:
return message

if keys[0] == "data":
key_start = 1
message["source"] = "body"
message["key"] = key = ".".join(keys)

message["key"] = ".".join(keys[key_start:])
elif keys:
key = keys[0]
message["key"] = key
if key in connection.query_params:
message["source"] = cast("Literal['cookie', 'body', 'header', 'query']", "query")

if key in connection.query_params:
message["source"] = cast("Literal['cookie', 'body', 'header', 'query']", "query")

elif key in cls.fields and isinstance(cls.fields[key].kwarg_definition, ParameterKwarg):
if cast(ParameterKwarg, cls.fields[key].kwarg_definition).cookie:
source = "cookie"
elif cast(ParameterKwarg, cls.fields[key].kwarg_definition).header:
source = "header"
else:
source = "query"
message["source"] = cast("Literal['cookie', 'body', 'header', 'query']", source)
elif key in cls.fields and isinstance(cls.fields[key].kwarg_definition, ParameterKwarg):
if cast(ParameterKwarg, cls.fields[key].kwarg_definition).cookie:
source = "cookie"
elif cast(ParameterKwarg, cls.fields[key].kwarg_definition).header:
source = "header"
else:
source = "query"
message["source"] = cast("Literal['cookie', 'body', 'header', 'query']", source)

return message

Expand Down Expand Up @@ -124,16 +119,6 @@ def to_dict(self) -> dict[str, Any]:
"""
raise NotImplementedError

@classmethod
@abstractmethod
def populate_field_definitions(cls) -> None:
"""Populate the class signature fields.
Returns:
None.
"""
raise NotImplementedError

@classmethod
@abstractmethod
def create(
Expand Down
136 changes: 136 additions & 0 deletions litestar/_signature/models/msgspec_signature_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
from __future__ import annotations

import re
from typing import TYPE_CHECKING, Any, Optional, Union

from msgspec import NODEFAULT, Meta, Struct, ValidationError, convert, defstruct
from msgspec.structs import asdict
from pydantic import ValidationError as PydanticValidationError
from typing_extensions import Annotated

from litestar.params import DependencyKwarg, KwargDefinition
from litestar.serialization import dec_hook
from litestar.utils import make_non_optional_union
from litestar.utils.dataclass import simple_asdict
from litestar.utils.typing import unwrap_union

from .base import SignatureModel

if TYPE_CHECKING:
from litestar.connection import ASGIConnection
from litestar.utils.signature import ParsedSignature

from .base import ErrorMessage

__all__ = ("MsgspecSignatureModel",)

MSGSPEC_CONSTRAINT_FIELDS = (
"gt",
"ge",
"lt",
"le",
"multiple_of",
"pattern",
"min_length",
"max_length",
)

ERR_RE = re.compile(r"`\$\.(.+)`$")


class MsgspecSignatureModel(SignatureModel, Struct):
"""Model that represents a function signature that uses a msgspec specific type or types."""

@classmethod
def parse_values_from_connection_kwargs(cls, connection: ASGIConnection, **kwargs: Any) -> dict[str, Any]:
"""Extract values from the connection instance and return a dict of parsed values.
Args:
connection: The ASGI connection instance.
**kwargs: A dictionary of kwargs.
Raises:
ValidationException: If validation failed.
InternalServerException: If another exception has been raised.
Returns:
A dictionary of parsed values
"""
messages: list[ErrorMessage] = []
try:
return convert(kwargs, cls, strict=False, dec_hook=dec_hook).to_dict()
except PydanticValidationError as e:
for exc in e.errors():
keys = [str(loc) for loc in exc["loc"]]
message = super()._build_error_message(keys=keys, exc_msg=exc["msg"], connection=connection)
messages.append(message)
raise cls._create_exception(messages=messages, connection=connection) from e
except ValidationError as e:
match = ERR_RE.search(str(e))
keys = [str(match.group(1)) if match else "n/a"]
message = super()._build_error_message(keys=keys, exc_msg=str(e), connection=connection)
messages.append(message)
raise cls._create_exception(messages=messages, connection=connection) from e

def to_dict(self) -> dict[str, Any]:
"""Normalize access to the signature model's dictionary method, because different backends use different methods
for this.
Returns: A dictionary of string keyed values.
"""
return asdict(self)

@classmethod
def create(
cls,
fn_name: str,
fn_module: str | None,
parsed_signature: ParsedSignature,
dependency_names: set[str],
type_overrides: dict[str, Any],
) -> type[SignatureModel]:
struct_fields: list[tuple[str, Any, Any]] = []

for field_definition in parsed_signature.parameters.values():
annotation = type_overrides.get(field_definition.name, field_definition.annotation)

meta_kwargs: dict[str, Any] = {}

if isinstance(field_definition.kwarg_definition, KwargDefinition):
meta_kwargs.update(
{k: v for k in MSGSPEC_CONSTRAINT_FIELDS if (v := getattr(field_definition.kwarg_definition, k))}
)
meta_kwargs["extra"] = simple_asdict(field_definition.kwarg_definition)
elif isinstance(field_definition.kwarg_definition, DependencyKwarg):
annotation = annotation if not field_definition.kwarg_definition.skip_validation else Any

default = field_definition.default if field_definition.has_default else NODEFAULT

meta = Meta(**meta_kwargs)
if field_definition.is_optional:
annotated_type = Optional[Annotated[make_non_optional_union(field_definition.annotation), meta]] # type: ignore
elif field_definition.is_union and meta_kwargs.keys() & MSGSPEC_CONSTRAINT_FIELDS:
# unwrap inner types of a union and apply constraints to each individual type
# see https://github.com/jcrist/msgspec/issues/447
annotated_type = Union[
tuple(
Annotated[inner_type, meta] for inner_type in unwrap_union(field_definition.annotation)
) # pyright: ignore
]
else:
annotated_type = Annotated[annotation, meta]

struct_fields.append((field_definition.name, annotated_type, default))

return defstruct( # type:ignore[return-value]
f"{fn_name}_signature_model",
struct_fields,
bases=(MsgspecSignatureModel,),
module=fn_module,
namespace={
"return_annotation": parsed_signature.return_type.annotation,
"dependency_name_set": dependency_names,
"fields": parsed_signature.parameters,
},
kw_only=True,
)
8 changes: 2 additions & 6 deletions litestar/_signature/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import pydantic

from litestar._signature.models.pydantic_signature_model import PydanticSignatureModel
from litestar._signature.models.msgspec_signature_model import MsgspecSignatureModel
from litestar.constants import SKIP_VALIDATION_NAMES
from litestar.exceptions import ImproperlyConfiguredException
from litestar.params import DependencyKwarg
Expand Down Expand Up @@ -111,11 +111,7 @@ def _get_signature_model_type(
preferred_validation_backend: Literal["pydantic", "attrs"],
parsed_signature: ParsedSignature,
) -> type[SignatureModel]:
if preferred_validation_backend == "attrs" or _any_attrs_annotation(parsed_signature):
from litestar._signature.models.attrs_signature_model import AttrsSignatureModel

return AttrsSignatureModel
return PydanticSignatureModel
return MsgspecSignatureModel


def _should_skip_validation(field_definition: FieldDefinition) -> bool:
Expand Down
Loading

0 comments on commit 2b43b07

Please sign in to comment.