Skip to content

Commit

Permalink
chore(signature-model): remove pydantic and attrs signature models
Browse files Browse the repository at this point in the history
  • Loading branch information
Goldziher committed Jul 7, 2023
1 parent 6f38bbe commit 5f79d7e
Show file tree
Hide file tree
Showing 24 changed files with 184 additions and 1,220 deletions.
2 changes: 1 addition & 1 deletion litestar/_kwargs/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from inspect import isasyncgen, isgenerator
from typing import TYPE_CHECKING, Any

from litestar._signature.utils import get_signature_model
from litestar._signature import get_signature_model
from litestar.utils.compat import async_next

__all__ = ("Dependency", "create_dependency_batches", "map_dependencies_recursively", "resolve_dependency")
Expand Down
6 changes: 3 additions & 3 deletions litestar/_signature/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .models.base import SignatureModel
from .utils import create_signature_model, get_signature_model
from .model import SignatureModel
from .utils import get_signature_model

__all__ = ("create_signature_model", "SignatureModel", "get_signature_model")
__all__ = ("SignatureModel", "get_signature_model")
Original file line number Diff line number Diff line change
@@ -1,28 +1,46 @@
# ruff: noqa: UP006
from __future__ import annotations

import re
from typing import TYPE_CHECKING, Any, Optional, Union
from typing import TYPE_CHECKING, Any, ClassVar, Literal, Optional, Sequence, Set, TypedDict, Union, cast

from msgspec import NODEFAULT, Meta, Struct, ValidationError, convert, defstruct
from msgspec.structs import asdict
from pydantic import ValidationError as PydanticValidationError
from typing_extensions import Annotated

from litestar.params import DependencyKwarg, KwargDefinition
from litestar._signature.utils import create_type_overrides, validate_signature_dependencies
from litestar.enums import ScopeType
from litestar.exceptions import InternalServerException, ValidationException
from litestar.params import DependencyKwarg, KwargDefinition, ParameterKwarg
from litestar.serialization import dec_hook
from litestar.typing import FieldDefinition # noqa: TCH
from litestar.utils import make_non_optional_union
from litestar.utils.dataclass import simple_asdict
from litestar.utils.typing import unwrap_union

from .base import SignatureModel

if TYPE_CHECKING:
from typing_extensions import NotRequired

from litestar.connection import ASGIConnection
from litestar.types import AnyCallable
from litestar.utils.signature import ParsedSignature

from .base import ErrorMessage

__all__ = ("MsgspecSignatureModel",)
__all__ = (
"ErrorMessage",
"SignatureModel",
)


class ErrorMessage(TypedDict):
# key may not be set in some cases, like when a query param is set but
# doesn't match the required length during `attrs` validation
# in this case, we don't show a key at all as it will be empty
key: NotRequired[str]
message: str
source: NotRequired[Literal["cookie", "body", "header", "query"]]


MSGSPEC_CONSTRAINT_FIELDS = (
"gt",
Expand All @@ -38,9 +56,68 @@
ERR_RE = re.compile(r"`\$\.(.+)`$")


class MsgspecSignatureModel(SignatureModel, Struct):
class SignatureModel(Struct):
"""Model that represents a function signature that uses a msgspec specific type or types."""

dependency_name_set: ClassVar[Set[str]]
return_annotation: ClassVar[Any]
fields: ClassVar[dict[str, FieldDefinition]]

@classmethod
def _create_exception(cls, connection: ASGIConnection, messages: list[ErrorMessage]) -> Exception:
"""Create an exception class - either a ValidationException or an InternalServerException, depending on whether
the failure is in client provided values or injected dependencies.
Args:
connection: An ASGI connection instance.
messages: A list of error messages.
Returns:
An Exception
"""
method = connection.method if hasattr(connection, "method") else ScopeType.WEBSOCKET # pyright: ignore
if client_errors := [
err_message
for err_message in messages
if ("key" in err_message and err_message["key"] not in cls.dependency_name_set) or "key" not in err_message
]:
return ValidationException(detail=f"Validation failed for {method} {connection.url}", extra=client_errors)
return InternalServerException()

@classmethod
def _build_error_message(cls, keys: Sequence[str], exc_msg: str, connection: ASGIConnection) -> ErrorMessage:
"""Build an error message.
Args:
keys: A list of keys.
exc_msg: A message.
connection: An ASGI connection instance.
Returns:
An ErrorMessage
"""

message: ErrorMessage = {"message": exc_msg.split(" - ")[0]}

if not keys:
return message

message["key"] = key = ".".join(keys)

if key in connection.query_params:
message["source"] = cast("Literal['cookie', 'body', 'header', 'query']", "query")

elif key in cls.fields and isinstance(cls.fields[key].kwarg_definition, ParameterKwarg):
if cast(ParameterKwarg, cls.fields[key].kwarg_definition).cookie:
source = "cookie"
elif cast(ParameterKwarg, cls.fields[key].kwarg_definition).header:
source = "header"
else:
source = "query"
message["source"] = cast("Literal['cookie', 'body', 'header', 'query']", source)

return message

@classmethod
def parse_values_from_connection_kwargs(cls, connection: ASGIConnection, **kwargs: Any) -> dict[str, Any]:
"""Extract values from the connection instance and return a dict of parsed values.
Expand All @@ -62,13 +139,13 @@ def parse_values_from_connection_kwargs(cls, connection: ASGIConnection, **kwarg
except PydanticValidationError as e:
for exc in e.errors():
keys = [str(loc) for loc in exc["loc"]]
message = super()._build_error_message(keys=keys, exc_msg=exc["msg"], connection=connection)
message = cls._build_error_message(keys=keys, exc_msg=exc["msg"], connection=connection)
messages.append(message)
raise cls._create_exception(messages=messages, connection=connection) from e
except ValidationError as e:
match = ERR_RE.search(str(e))
keys = [str(match.group(1)) if match else "n/a"]
message = super()._build_error_message(keys=keys, exc_msg=str(e), connection=connection)
message = cls._build_error_message(keys=keys, exc_msg=str(e), connection=connection)
messages.append(message)
raise cls._create_exception(messages=messages, connection=connection) from e

Expand All @@ -83,12 +160,20 @@ def to_dict(self) -> dict[str, Any]:
@classmethod
def create(
cls,
fn_name: str,
fn_module: str | None,
dependency_name_set: set[str],
fn: AnyCallable,
parsed_signature: ParsedSignature,
dependency_names: set[str],
type_overrides: dict[str, Any],
has_data_dto: bool = False,
) -> type[SignatureModel]:
fn_name = (
fn_name if (fn_name := getattr(fn, "__name__", "anonymous")) and fn_name != "<lambda>" else "anonymous"
)

dependency_names = validate_signature_dependencies(
dependency_name_set=dependency_name_set, fn_name=fn_name, parsed_signature=parsed_signature
)
type_overrides = create_type_overrides(parsed_signature, has_data_dto)

struct_fields: list[tuple[str, Any, Any]] = []

for field_definition in parsed_signature.parameters.values():
Expand Down Expand Up @@ -125,8 +210,8 @@ def create(
return defstruct( # type:ignore[return-value]
f"{fn_name}_signature_model",
struct_fields,
bases=(MsgspecSignatureModel,),
module=fn_module,
bases=(cls,),
module=getattr(fn, "__module__", None),
namespace={
"return_annotation": parsed_signature.return_type.annotation,
"dependency_name_set": dependency_names,
Expand Down
3 changes: 0 additions & 3 deletions litestar/_signature/models/__init__.py

This file was deleted.

Loading

0 comments on commit 5f79d7e

Please sign in to comment.