diff --git a/djantic/fields.py b/djantic/fields.py index 25608e1..ee6f0df 100644 --- a/djantic/fields.py +++ b/djantic/fields.py @@ -1,9 +1,9 @@ import logging +import typing from datetime import date, datetime, time, timedelta from decimal import Decimal from enum import Enum -from typing import Any, Dict, List, Union, Optional -import typing +from typing import Any, Dict, List, Optional, Union from uuid import UUID from django.utils.functional import Promise @@ -200,17 +200,16 @@ def ModelSchemaField(field: Any, schema_name: str) -> tuple: max_length=max_length, ) - field_is_optional = all([ - getattr(field, "null", None), - field.is_relation, - # A list that is null, is the empty list. So there is no need - # to make it nullable. - typing.get_origin(python_type) is not list - ]) + field_is_optional = all( + [ + getattr(field, "null", None), + field.is_relation, + # A list that is null, is the empty list. So there is no need + # to make it nullable. + typing.get_origin(python_type) is not list, + ] + ) if field_is_optional: python_type = Optional[python_type] - return ( - python_type, - field_info - ) + return (python_type, field_info) diff --git a/djantic/main.py b/djantic/main.py index f6ece98..992bf39 100644 --- a/djantic/main.py +++ b/djantic/main.py @@ -3,8 +3,7 @@ from enum import Enum from functools import reduce from itertools import chain -from typing import Any, Dict, List, Optional, no_type_check, Union -from typing_extensions import get_origin, get_args +from typing import Any, Dict, List, Optional, Union, no_type_check from django.core.serializers.json import DjangoJSONEncoder from django.db.models import Manager, Model @@ -13,8 +12,9 @@ from django.utils.encoding import force_str from django.utils.functional import Promise from pydantic import BaseModel, create_model -from pydantic.errors import PydanticUserError from pydantic._internal._model_construction import ModelMetaclass +from pydantic.errors import PydanticUserError +from typing_extensions import get_args, get_origin if sys.version_info >= (3, 10): from types import UnionType @@ -54,7 +54,6 @@ def __new__(mcs, name: str, bases: tuple, namespace: dict, **kwargs): and issubclass(base, ModelSchema) and base == ModelSchema ): - config = namespace["model_config"] include = config.get("include", None) exclude = config.get("exclude", None) @@ -103,7 +102,6 @@ def __new__(mcs, name: str, bases: tuple, namespace: dict, **kwargs): python_type = None pydantic_field = None if field_name in annotations and field_name in namespace: - python_type = annotations.pop(field_name) pydantic_field = namespace[field_name] if ( @@ -143,10 +141,10 @@ def __new__(mcs, name: str, bases: tuple, namespace: dict, **kwargs): def _is_optional_field(annotation) -> bool: args = get_args(annotation) return ( - (get_origin(annotation) is Union or get_origin(annotation) is UnionType) - and type(None) in args - and len(args) == 2 - and any(inspect.isclass(arg) and issubclass(arg, ModelSchema) for arg in args) + (get_origin(annotation) is Union or get_origin(annotation) is UnionType) + and type(None) in args + and len(args) == 2 + and any(inspect.isclass(arg) and issubclass(arg, ModelSchema) for arg in args) ) @@ -221,7 +219,9 @@ def dict(self) -> dict: non_none_type_annotation = next( arg for arg in get_args(annotation) if arg is not type(None) ) - data[key] = self._get_annotation_objects(value, non_none_type_annotation) + data[key] = self._get_annotation_objects( + value, non_none_type_annotation + ) elif inspect.isclass(annotation) and issubclass(annotation, ModelSchema): data[key] = self._get_annotation_objects(self.get(key), annotation) @@ -232,7 +232,6 @@ def dict(self) -> dict: class ModelSchema(BaseModel, metaclass=ModelSchemaMetaclass): - def __eq__(self, other: Any) -> bool: result = super().__eq__(other) if isinstance(result, bool): diff --git a/tests/test_fields.py b/tests/test_fields.py index f2440fc..82fadf2 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -1,13 +1,22 @@ from typing import Optional import pytest -from pydantic import ConfigDict -from testapp.models import Configuration, Listing, Preference, Record, Searchable, User, NullableChar, NullableFK - +from packaging import version from pydantic import ( + ConfigDict, + ValidationError, ValidationInfo, field_validator, - ValidationError, +) +from testapp.models import ( + Configuration, + Listing, + NullableChar, + NullableFK, + Preference, + Record, + Searchable, + User, ) from djantic import ModelSchema @@ -43,25 +52,21 @@ class UserSchema(ModelSchema): @pytest.mark.django_db def test_context_for_field(): - def get_context(): - return {'check_title': lambda x: x.istitle()} + return {"check_title": lambda x: x.istitle()} class UserSchema(ModelSchema): - model_config = ConfigDict( - model=User, - revalidate_instances='always' - ) + model_config = ConfigDict(model=User, revalidate_instances="always") - @field_validator('first_name', mode="before", check_fields=False) + @field_validator("first_name", mode="before", check_fields=False) @classmethod def validate_first_name(cls, v: str, info: ValidationInfo): if not info.context: return v - check_title = info.context.get('check_title') + check_title = info.context.get("check_title") if check_title and not check_title(v): - raise ValueError('First name needs to be a title') + raise ValueError("First name needs to be a title") return v user = User.objects.create(first_name="hello", email="a@a.com") @@ -533,11 +538,11 @@ class ListingSchema(ModelSchema): @pytest.mark.django_db def test_nullable_fk(): class NullableCharSchema(ModelSchema): - model_config = ConfigDict(model=NullableChar, include='value') + model_config = ConfigDict(model=NullableChar, include="value") class NullableFKSchema(ModelSchema): nullable_char: Optional[NullableCharSchema] = None - model_config = ConfigDict(model=NullableFK, include='nullable_char') + model_config = ConfigDict(model=NullableFK, include="nullable_char") nullable_char = NullableChar(value="test") nullable_char.save() diff --git a/tests/test_files.py b/tests/test_files.py index 39b3d12..dd0dbb6 100644 --- a/tests/test_files.py +++ b/tests/test_files.py @@ -1,9 +1,9 @@ from tempfile import NamedTemporaryFile import pytest +from pydantic import ConfigDict from testapp.models import Attachment -from pydantic import ConfigDict from djantic import ModelSchema diff --git a/tests/test_main.py b/tests/test_main.py index f7d3e14..cf23b38 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,8 +1,8 @@ import pytest +from pydantic import ConfigDict from pydantic.errors import PydanticUserError from testapp.models import User -from pydantic import ConfigDict from djantic import ModelSchema diff --git a/tests/test_queries.py b/tests/test_queries.py index 761cff1..4504efb 100644 --- a/tests/test_queries.py +++ b/tests/test_queries.py @@ -1,9 +1,9 @@ from typing import List import pytest +from pydantic import ConfigDict from testapp.models import Bookmark, Message, Profile, Tagged, Thread, User -from pydantic import ConfigDict from djantic import ModelSchema @@ -28,7 +28,6 @@ class UserSchema(ModelSchema): @pytest.mark.django_db def test_get_instance_with_generic_foreign_key(): - bookmark = Bookmark.objects.create(url="https://www.djangoproject.com/") Tagged.objects.create(content_object=bookmark, slug="django") @@ -36,7 +35,6 @@ class TaggedSchema(ModelSchema): model_config = ConfigDict(model=Tagged) class BookmarkWithTaggedSchema(ModelSchema): - tags: List[TaggedSchema] model_config = ConfigDict(model=Bookmark) @@ -222,7 +220,6 @@ class ThreadWithMessageListSchema(ModelSchema): @pytest.mark.django_db def test_get_queryset_with_generic_foreign_key(): - bookmark = Bookmark.objects.create(url="https://github.com") bookmark.tags.create(slug="tag-1") bookmark.tags.create(slug="tag-2") diff --git a/tests/test_relations.py b/tests/test_relations.py index 15fee4f..38852a0 100644 --- a/tests/test_relations.py +++ b/tests/test_relations.py @@ -2,7 +2,7 @@ from typing import Dict, List, Optional import pytest -from pydantic import Field +from pydantic import ConfigDict, Field from testapp.models import ( Article, Bookmark, @@ -17,7 +17,6 @@ User, ) -from pydantic import ConfigDict from djantic import ModelSchema @@ -656,7 +655,6 @@ class BookmarkSchema(ModelSchema): } class BookmarkWithTaggedSchema(ModelSchema): - tags: List[TaggedSchema] model_config = ConfigDict(model=Bookmark) @@ -724,7 +722,6 @@ class BookmarkWithTaggedSchema(ModelSchema): } class ItemSchema(ModelSchema): - tags: List[TaggedSchema] model_config = ConfigDict(model=Item) diff --git a/tests/test_schemas.py b/tests/test_schemas.py index e0c8692..4869dce 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -2,11 +2,9 @@ from typing import Optional import pytest -from pydantic import BaseModel, Field +from pydantic import AliasGenerator, BaseModel, ConfigDict, Field +from testapp.models import Profile, User -from testapp.models import User, Profile, Configuration - -from pydantic import ConfigDict, AliasGenerator from djantic import ModelSchema diff --git a/tests/testapp/manage.py b/tests/testapp/manage.py index af30516..084e8a2 100755 --- a/tests/testapp/manage.py +++ b/tests/testapp/manage.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Django's command-line utility for administrative tasks.""" + import os import sys diff --git a/tests/testapp/models.py b/tests/testapp/models.py index 9014598..b35fe05 100644 --- a/tests/testapp/models.py +++ b/tests/testapp/models.py @@ -1,15 +1,14 @@ -import uuid import os.path +import uuid from typing import Optional -from django.contrib.contenttypes.fields import GenericForeignKey +from django.contrib.contenttypes.fields import GenericForeignKey, GenericRelation from django.contrib.contenttypes.models import ContentType -from django.contrib.contenttypes.fields import GenericRelation -from django.db import models -from django.utils.text import slugify -from django.contrib.postgres.fields import JSONField, ArrayField +from django.contrib.postgres.fields import ArrayField, JSONField from django.contrib.postgres.indexes import GinIndex from django.contrib.postgres.search import SearchVectorField +from django.db import models +from django.utils.text import slugify from django.utils.translation import gettext_lazy as _ from .fields import ListField, NotNullRestrictedCharField @@ -281,7 +280,9 @@ class Case(ExtendedModel): class Listing(models.Model): items = ArrayField(models.TextField(), size=4) - content_type = models.ForeignKey(ContentType, on_delete=models.PROTECT, blank=True, null=True) + content_type = models.ForeignKey( + ContentType, on_delete=models.PROTECT, blank=True, null=True + ) class NullableChar(models.Model): @@ -289,4 +290,6 @@ class NullableChar(models.Model): class NullableFK(models.Model): - nullable_char = models.ForeignKey(NullableChar, null=True, blank=True, on_delete=models.CASCADE) + nullable_char = models.ForeignKey( + NullableChar, null=True, blank=True, on_delete=models.CASCADE + )