Skip to content

Commit

Permalink
RFC: msgspec signature model.
Browse files Browse the repository at this point in the history
  • Loading branch information
peterschutt committed Jun 21, 2023
1 parent 35bf05f commit 7dd6301
Show file tree
Hide file tree
Showing 6 changed files with 152 additions and 17 deletions.
12 changes: 6 additions & 6 deletions litestar/_signature/models/base.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, ClassVar, TypedDict
from abc import abstractmethod
from typing import TYPE_CHECKING, Any, ClassVar, Dict, Set, TypedDict

from litestar._signature.field import SignatureField
from litestar.enums import ScopeType
from litestar.exceptions import InternalServerException, ValidationException

if TYPE_CHECKING:
from litestar._signature.field import SignatureField
from litestar.connection import ASGIConnection
from litestar.utils.signature import ParsedSignature

Expand All @@ -19,12 +19,12 @@ class ErrorMessage(TypedDict):
message: str


class SignatureModel(ABC):
class SignatureModel:
"""Base model for Signature modelling."""

dependency_name_set: ClassVar[set[str]]
dependency_name_set: ClassVar[Set[str]]
return_annotation: ClassVar[Any]
fields: ClassVar[dict[str, SignatureField]]
fields: ClassVar[Dict[str, SignatureField]]

@classmethod
def _create_exception(cls, connection: ASGIConnection, messages: list[ErrorMessage]) -> Exception:
Expand Down
132 changes: 132 additions & 0 deletions litestar/_signature/models/msgspec_signature_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
from __future__ import annotations

import re
from typing import TYPE_CHECKING, Any

from msgspec import NODEFAULT, Meta, Struct, ValidationError, convert, defstruct
from msgspec.structs import asdict
from typing_extensions import Annotated

from litestar._signature.field import SignatureField
from litestar.params import DependencyKwarg
from litestar.types.empty import Empty
from litestar.utils.dataclass import simple_asdict
from litestar.serialization import dec_hook

from .base import SignatureModel

if TYPE_CHECKING:
from litestar.connection import ASGIConnection
from litestar.utils.signature import ParsedSignature

__all__ = ("MsgspecSignatureModel",)


ERR_RE = re.compile(r"`\$\.(.+)`$")


class MsgspecSignatureModel(SignatureModel, Struct):
"""Model that represents a function signature that uses a msgspec specific type or types."""

@classmethod
def parse_values_from_connection_kwargs(cls, connection: ASGIConnection, **kwargs: Any) -> dict[str, Any]:
"""Extract values from the connection instance and return a dict of parsed values.
Args:
connection: The ASGI connection instance.
**kwargs: A dictionary of kwargs.
Raises:
ValidationException: If validation failed.
InternalServerException: If another exception has been raised.
Returns:
A dictionary of parsed values
"""
try:
return convert(kwargs, cls, strict=False, dec_hook=dec_hook).to_dict()
except ValidationError as e:
message = str(e)
match = ERR_RE.search(message)
key = str(match.group(1)) if match else "n/a"
raise cls._create_exception(messages=[{"key": key, "message": message}], connection=connection) from e

def to_dict(self) -> dict[str, Any]:
"""Normalize access to the signature model's dictionary method, because different backends use different methods
for this.
Returns: A dictionary of string keyed values.
"""
return asdict(self)

@classmethod
def populate_signature_fields(cls) -> None:
...

@classmethod
def create(
cls,
fn_name: str,
fn_module: str | None,
parsed_signature: ParsedSignature,
dependency_names: set[str],
type_overrides: dict[str, Any],
) -> type[SignatureModel]:
struct_fields: list[tuple[str, Any, Any]] = []
signature_fields: dict[str, SignatureField] = {}

for parameter in parsed_signature.parameters.values():
annotation = type_overrides.get(parameter.name, parameter.parsed_type.annotation)

field_extra: dict[str, Any] = {"parsed_parameter": parameter}
meta_kwargs: dict[str, Any] = {}

if kwargs_container := parameter.kwarg_container:
field_extra["kwargs_model"] = kwargs_container
if isinstance(kwargs_container, DependencyKwarg):
annotation = annotation if not kwargs_container.skip_validation else Any
default = kwargs_container.default if kwargs_container.default is not Empty else NODEFAULT
else:
param_dict = simple_asdict(kwargs_container)
field_extra.update(param_dict)
meta_kwargs = {
k: v
for k in (
"gt",
"ge",
"lt",
"le",
"multiple_of",
"pattern",
"min_length",
"max_length",
)
if (v := getattr(kwargs_container, k))
}
default = NODEFAULT
else:
default = parameter.default if parameter.has_default else NODEFAULT

struct_fields.append(
(parameter.name, Annotated[annotation, Meta(extra=field_extra, **meta_kwargs)], default)
)
signature_fields[parameter.name] = SignatureField.create(
field_type=annotation,
name=parameter.name,
default_value=Empty if default is NODEFAULT else default,
kwarg_model=kwargs_container,
extra=field_extra,
)

return defstruct( # type:ignore[return-value]
f"{fn_name}_signature_model",
struct_fields,
bases=(MsgspecSignatureModel,),
module=fn_module or "",
namespace={
"return_annotation": parsed_signature.return_type.annotation,
"dependency_name_set": dependency_names,
"fields": signature_fields,
},
kw_only=True,
)
8 changes: 2 additions & 6 deletions litestar/_signature/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import pydantic

from litestar._signature.models.pydantic_signature_model import PydanticSignatureModel
from litestar._signature.models.msgspec_signature_model import MsgspecSignatureModel
from litestar.constants import SKIP_VALIDATION_NAMES
from litestar.exceptions import ImproperlyConfiguredException
from litestar.params import DependencyKwarg
Expand Down Expand Up @@ -110,11 +110,7 @@ def _get_signature_model_type(
preferred_validation_backend: Literal["pydantic", "attrs"],
parsed_signature: ParsedSignature,
) -> type[SignatureModel]:
if preferred_validation_backend == "attrs" or _any_attrs_annotation(parsed_signature):
from litestar._signature.models.attrs_signature_model import AttrsSignatureModel

return AttrsSignatureModel
return PydanticSignatureModel
return MsgspecSignatureModel


def _should_skip_validation(parameter: ParsedParameter) -> bool:
Expand Down
8 changes: 7 additions & 1 deletion litestar/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from pathlib import Path, PurePath
from re import Pattern
from typing import TYPE_CHECKING, Any, Callable, Mapping, TypeVar, overload
from uuid import UUID

import msgspec
from pydantic import (
Expand Down Expand Up @@ -135,9 +136,14 @@ def dec_hook(type_: Any, value: Any) -> Any: # pragma: no cover
Returns:
A ``msgspec``-supported type
"""

from litestar.datastructures.state import State

if isinstance(value, type_):
return value
if issubclass(type_, BaseModel):
return type_.parse_obj(value)
if issubclass(type_, (Path, PurePath)):
if issubclass(type_, (Path, PurePath, State, UUID)):
return type_(value)
raise TypeError(f"Unsupported type: {type(value)!r}")

Expand Down
6 changes: 3 additions & 3 deletions tests/e2e/test_routing/test_path_resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,9 +245,9 @@ def test_support_for_path_type_parameters() -> None:
def lower_handler(string_param: str) -> str:
return string_param

@get(path="/{string_param:str}/{parth_param:path}")
def upper_handler(string_param: str, parth_param: Path) -> str:
return string_param + str(parth_param)
@get(path="/{string_param:str}/{path_param:path}")
def upper_handler(string_param: str, path_param: Path) -> str:
return string_param + str(path_param)

with create_test_client([lower_handler, upper_handler]) as client:
response = client.get("/abc")
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/test_kwargs/test_path_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,11 @@ def test_method(
assert user_id
assert order_id

with create_test_client(test_method) as client:
with create_test_client(test_method, debug=True) as client:
response = client.get(
f"{params_dict['version']}/{params_dict['service_id']}/{params_dict['user_id']}/{params_dict['order_id']}"
)
print(response.text)
if should_raise:
assert response.status_code == HTTP_400_BAD_REQUEST, response.json()
else:
Expand Down

0 comments on commit 7dd6301

Please sign in to comment.