diff --git a/djantic/main.py b/djantic/main.py index b2987d2..2a9a733 100644 --- a/djantic/main.py +++ b/djantic/main.py @@ -1,4 +1,5 @@ import inspect +import sys from enum import Enum from functools import reduce from itertools import chain @@ -15,6 +16,10 @@ from pydantic.errors import PydanticUserError from pydantic._internal._model_construction import ModelMetaclass +if sys.version_info >= (3, 10): + from types import UnionType +else: + from typing import Union as UnionType from .fields import ModelSchemaField @@ -136,6 +141,16 @@ def __new__(mcs, name: str, bases: tuple, namespace: dict, **kwargs): return cls +def _is_optional_field(annotation) -> bool: + args = get_args(annotation) + return ( + (get_origin(annotation) is Union or get_origin(annotation) is UnionType) + and type(None) in args + and len(args) == 2 + and any(inspect.isclass(arg) and issubclass(arg, ModelSchema) for arg in args) + ) + + class ProxyGetterNestedObj: def __init__(self, obj: Any, schema_class): self._obj = obj @@ -199,7 +214,17 @@ def dict(self) -> dict: # Pick the underlying annotation annotation = get_args(annotation)[0] - if inspect.isclass(annotation) and issubclass(annotation, ModelSchema): + if _is_optional_field(annotation): + value = self.get(key) + if value is None: + data[key] = None + else: + non_none_type_annotation = next( + arg for arg in get_args(annotation) if arg is not type(None) + ) + data[key] = self._get_annotation_objects(value, non_none_type_annotation) + + elif inspect.isclass(annotation) and issubclass(annotation, ModelSchema): data[key] = self._get_annotation_objects(self.get(key), annotation) else: key = fieldinfo.alias if fieldinfo.alias else key diff --git a/tests/test_fields.py b/tests/test_fields.py index 8313909..41c17c6 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -1,6 +1,8 @@ +from typing import Optional + import pytest from pydantic import ConfigDict -from testapp.models import Configuration, Listing, Preference, Record, Searchable, User +from testapp.models import Configuration, Listing, Preference, Record, Searchable, User, NullableChar, NullableFK from pydantic import ( ValidationInfo, @@ -408,3 +410,27 @@ class ListingSchema(ModelSchema): "id": None, "items": ["a", "b"], } + + +@pytest.mark.django_db +def test_nullable_fk(): + class NullableCharSchema(ModelSchema): + model_config = ConfigDict(model=NullableChar, include='value') + + class NullableFKSchema(ModelSchema): + nullable_char: Optional[NullableCharSchema] = None + model_config = ConfigDict(model=NullableFK, include='nullable_char') + + nullable_char = NullableChar(value="test") + nullable_char.save() + model = NullableFK(nullable_char=nullable_char) + assert NullableFKSchema.from_django(model).dict() == { + "nullable_char": { + "value": "test" + } + } + + model2 = NullableFK(nullable_char=None) + assert NullableFKSchema.from_django(model2).dict() == { + "nullable_char": None + } diff --git a/tests/testapp/models.py b/tests/testapp/models.py index a37283f..9014598 100644 --- a/tests/testapp/models.py +++ b/tests/testapp/models.py @@ -282,3 +282,11 @@ class Case(ExtendedModel): class Listing(models.Model): items = ArrayField(models.TextField(), size=4) content_type = models.ForeignKey(ContentType, on_delete=models.PROTECT, blank=True, null=True) + + +class NullableChar(models.Model): + value = models.CharField(max_length=256, null=True, blank=True) + + +class NullableFK(models.Model): + nullable_char = models.ForeignKey(NullableChar, null=True, blank=True, on_delete=models.CASCADE)