From defd780fc70534cea715fdcc0c708f1535f89e0e Mon Sep 17 00:00:00 2001 From: "Lance.Moe" Date: Mon, 18 Nov 2024 16:34:17 +0900 Subject: [PATCH] Make JSONField support type annotation and OpenAPI document generation (#1763) * Make JSONField support type annotation and OpenAPI document generation fix: codacy * fix: pass test and format code * test: add test case for json field pydantic type * fix: input * test: add testcase for init with pydantic type * chore(docs): add changelog * fix: codacy issue * chore(docs): add docs to code * refactor: use model_dump instead of dict method * fix: pass lint * chore: make style * fix: test python version is before 3.10 * fix: pytest warning * fix: codacy --- CHANGELOG.rst | 12 + CONTRIBUTORS.rst | 1 + docs/query.rst | 4 +- examples/postgres.py | 2 +- tests/contrib/test_pydantic.py | 181 +++++---------- tests/test_queryset_reuse.py | 5 +- tests/testmodels.py | 38 ++- tests/utils/test_describe_model.py | 42 ++++ tortoise/contrib/pydantic/creator.py | 270 +++++++++++----------- tortoise/contrib/pydantic/descriptions.py | 45 ++-- tortoise/fields/data.py | 65 ++++-- 11 files changed, 356 insertions(+), 309 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index bcbe2c4ed..2255f1ab5 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -8,6 +8,18 @@ Changelog 0.21 ==== +0.21.8 +------ +Fixed +^^^^^ +- TODO + +Added +^^^^^ +- JSONField adds optional generic support, and supports OpenAPI document generation by specifying `field_type` as a pydantic BaseModel (#1763) + + + 0.21.7 ------ Fixed diff --git a/CONTRIBUTORS.rst b/CONTRIBUTORS.rst index 7fcc91fe0..04a7d4988 100644 --- a/CONTRIBUTORS.rst +++ b/CONTRIBUTORS.rst @@ -61,6 +61,7 @@ Contributors * Andrea Magistà ``@vlakius`` * Daniel Szucs ``@Quasar6X`` * Rui Catarino ``@ruitcatarino`` +* Lance Moe ``@lancemoe`` Special Thanks ============== diff --git a/docs/query.rst b/docs/query.rst index b753bb4ce..6457ba593 100644 --- a/docs/query.rst +++ b/docs/query.rst @@ -243,7 +243,7 @@ In PostgreSQL and MYSQL, you can use the ``contains``, ``contained_by`` and ``fi .. code-block:: python3 class JSONModel: - data = fields.JSONField() + data = fields.JSONField[list]() await JSONModel.create(data=["text", 3, {"msg": "msg2"}]) obj = await JSONModel.filter(data__contains=[{"msg": "msg2"}]).first() @@ -257,7 +257,7 @@ In PostgreSQL and MYSQL, you can use the ``contains``, ``contained_by`` and ``fi ..
code-block:: python3 class JSONModel: - data = fields.JSONField() + data = fields.JSONField[dict]() await JSONModel.create(data={"breed": "labrador", "owner": { diff --git a/examples/postgres.py b/examples/postgres.py index c98b3fd42..690df786a 100644 --- a/examples/postgres.py +++ b/examples/postgres.py @@ -8,7 +8,7 @@ class Report(Model): id = fields.IntField(primary_key=True) - content = fields.JSONField() + content = fields.JSONField[dict]() def __str__(self): return str(self.id) diff --git a/tests/contrib/test_pydantic.py b/tests/contrib/test_pydantic.py index feeec4f04..391b751fa 100644 --- a/tests/contrib/test_pydantic.py +++ b/tests/contrib/test_pydantic.py @@ -14,6 +14,7 @@ Team, Tournament, User, + json_pydantic_default, ) from tortoise.contrib import test from tortoise.contrib.pydantic import ( @@ -172,18 +173,14 @@ def test_event_schema(self): }, "reporter": { "anyOf": [ - { - "$ref": "#/$defs/Reporter_fgnv33_leaf" - }, + {"$ref": "#/$defs/Reporter_fgnv33_leaf"}, {"type": "null"}, ], "nullable": True, "title": "Reporter", }, "participants": { - "items": { - "$ref": "#/$defs/Team_ip4pg6_leaf" - }, + "items": {"$ref": "#/$defs/Team_ip4pg6_leaf"}, "title": "Participants", "type": "array", }, @@ -205,9 +202,7 @@ def test_event_schema(self): }, "address": { "anyOf": [ - { - "$ref": "#/$defs/Address_coqnj7_leaf" - }, + {"$ref": "#/$defs/Address_coqnj7_leaf"}, {"type": "null"}, ], "nullable": True, @@ -251,18 +246,14 @@ def test_eventlist_schema(self): }, "reporter": { "anyOf": [ - { - "$ref": "#/$defs/Reporter_fgnv33_leaf" - }, + {"$ref": "#/$defs/Reporter_fgnv33_leaf"}, {"type": "null"}, ], "nullable": True, "title": "Reporter", }, "participants": { - "items": { - "$ref": "#/$defs/Team_ip4pg6_leaf" - }, + "items": {"$ref": "#/$defs/Team_ip4pg6_leaf"}, "title": "Participants", "type": "array", }, @@ -291,9 +282,7 @@ def test_eventlist_schema(self): }, "address": { "anyOf": [ - { - "$ref": "#/$defs/Address_coqnj7_leaf" - }, + {"$ref": "#/$defs/Address_coqnj7_leaf"}, {"type": "null"}, ], "nullable": True, @@ -431,18 +420,14 @@ def test_address_schema(self): }, "reporter": { "anyOf": [ - { - "$ref": "#/$defs/Reporter_fgnv33_leaf" - }, + {"$ref": "#/$defs/Reporter_fgnv33_leaf"}, {"type": "null"}, ], "nullable": True, "title": "Reporter", }, "participants": { - "items": { - "$ref": "#/$defs/Team_ip4pg6_leaf" - }, + "items": {"$ref": "#/$defs/Team_ip4pg6_leaf"}, "title": "Participants", "type": "array", }, @@ -591,18 +576,14 @@ def test_tournament_schema(self): "name": {"description": "The name", "title": "Name", "type": "string"}, "reporter": { "anyOf": [ - { - "$ref": "#/$defs/Reporter_fgnv33_leaf" - }, + {"$ref": "#/$defs/Reporter_fgnv33_leaf"}, {"type": "null"}, ], "nullable": True, "title": "Reporter", }, "participants": { - "items": { - "$ref": "#/$defs/Team_ip4pg6_leaf" - }, + "items": {"$ref": "#/$defs/Team_ip4pg6_leaf"}, "title": "Participants", "type": "array", }, @@ -631,9 +612,7 @@ def test_tournament_schema(self): }, "address": { "anyOf": [ - { - "$ref": "#/$defs/Address_coqnj7_leaf" - }, + {"$ref": "#/$defs/Address_coqnj7_leaf"}, {"type": "null"}, ], "nullable": True, @@ -765,9 +744,7 @@ def test_team_schema(self): }, "reporter": { "anyOf": [ - { - "$ref": "#/$defs/Reporter_fgnv33_leaf" - }, + {"$ref": "#/$defs/Reporter_fgnv33_leaf"}, {"type": "null"}, ], "nullable": True, @@ -798,9 +775,7 @@ def test_team_schema(self): }, "address": { "anyOf": [ - { - "$ref": "#/$defs/Address_coqnj7_leaf" - }, + {"$ref": "#/$defs/Address_coqnj7_leaf"}, {"type": "null"}, ], "nullable": True, 
@@ -1211,6 +1186,7 @@ async def test_json_field(self): "data_null": None, "data_default": {"a": 1}, "data_validate": None, + "data_pydantic": json_pydantic_default.model_dump(), }, ) ret1 = creator.model_validate(json_field_1_get).model_dump() @@ -1222,6 +1198,7 @@ async def test_json_field(self): "data_null": None, "data_default": {"a": 1}, "data_validate": None, + "data_pydantic": json_pydantic_default.model_dump(), }, ) @@ -1495,7 +1472,9 @@ async def asyncSetUp(self) -> None: self.maxDiff = None async def test_computed_field(self): - employee_pyd = await self.Employee_Pydantic.from_tortoise_orm(await Employee.get(name="Some Employee")) + employee_pyd = await self.Employee_Pydantic.from_tortoise_orm( + await Employee.get(name="Some Employee") + ) employee_serialised = employee_pyd.model_dump() self.assertEqual(employee_serialised.get("name_length"), self.employee.name_length()) @@ -1511,49 +1490,38 @@ async def test_computed_field_schema(self): "maximum": 2147483647, "minimum": -2147483648, "title": "Id", - "type": "integer" - }, - "name": { - "maxLength": 50, - "title": "Name", - "type": "string" + "type": "integer", }, + "name": {"maxLength": 50, "title": "Name", "type": "string"}, "manager_id": { "anyOf": [ { "maximum": 2147483647, "minimum": -2147483648, - "type": "integer" + "type": "integer", }, - { - "type": "null" - } + {"type": "null"}, ], "default": None, "nullable": True, - "title": "Manager Id" + "title": "Manager Id", }, "name_length": { "description": "", "readOnly": True, "title": "Name Length", - "type": "integer" + "type": "integer", }, "team_size": { "description": "Computes team size.

Note that this function needs to be annotated with a return type so that pydantic can
generate a valid schema.

Note that the pydantic serializer can't call async methods, but the tortoise helpers
pre-fetch relational data, so that it is available before serialization. So we don't
need to await the relation. We do however have to protect against the case where no
prefetching was done, hence catching and handling the
``tortoise.exceptions.NoValuesFetched`` exception.", "readOnly": True, "title": "Team Size", - "type": "integer" - } + "type": "integer", + }, }, - "required": [ - "id", - "name", - "name_length", - "team_size" - ], + "required": ["id", "name", "name_length", "team_size"], "title": "Employee", - "type": "object" + "type": "object", }, "Employee_6tkbjb_leaf": { "additionalProperties": False, @@ -1562,54 +1530,44 @@ async def test_computed_field_schema(self): "maximum": 2147483647, "minimum": -2147483648, "title": "Id", - "type": "integer" - }, - "name": { - "maxLength": 50, - "title": "Name", - "type": "string" + "type": "integer", }, + "name": {"maxLength": 50, "title": "Name", "type": "string"}, "talks_to": { - "items": { - "$ref": "#/$defs/Employee_fj2ly4_leaf" - }, + "items": {"$ref": "#/$defs/Employee_fj2ly4_leaf"}, "title": "Talks To", - "type": "array" + "type": "array", }, "manager_id": { "anyOf": [ { "maximum": 2147483647, "minimum": -2147483648, - "type": "integer" + "type": "integer", }, - { - "type": "null" - } + {"type": "null"}, ], "default": None, "nullable": True, - "title": "Manager Id" + "title": "Manager Id", }, "team_members": { - "items": { - "$ref": "#/$defs/Employee_fj2ly4_leaf" - }, + "items": {"$ref": "#/$defs/Employee_fj2ly4_leaf"}, "title": "Team Members", - "type": "array" + "type": "array", }, "name_length": { "description": "", "readOnly": True, "title": "Name Length", - "type": "integer" + "type": "integer", }, "team_size": { "description": "Computes team size.

Note that this function needs to be annotated with a return type so that pydantic can
generate a valid schema.

Note that the pydantic serializer can't call async methods, but the tortoise helpers
pre-fetch relational data, so that it is available before serialization. So we don't
need to await the relation. We do however have to protect against the case where no
prefetching was done, hence catching and handling the
``tortoise.exceptions.NoValuesFetched`` exception.", "readOnly": True, "title": "Team Size", - "type": "integer" - } + "type": "integer", + }, }, "required": [ "id", @@ -1617,11 +1575,11 @@ async def test_computed_field_schema(self): "talks_to", "team_members", "name_length", - "team_size" + "team_size", ], "title": "Employee", - "type": "object" - } + "type": "object", + }, }, "additionalProperties": False, "properties": { @@ -1629,66 +1587,45 @@ async def test_computed_field_schema(self): "maximum": 2147483647, "minimum": -2147483648, "title": "Id", - "type": "integer" - }, - "name": { - "maxLength": 50, - "title": "Name", - "type": "string" + "type": "integer", }, + "name": {"maxLength": 50, "title": "Name", "type": "string"}, "talks_to": { - "items": { - "$ref": "#/$defs/Employee_6tkbjb_leaf" - }, + "items": {"$ref": "#/$defs/Employee_6tkbjb_leaf"}, "title": "Talks To", - "type": "array" + "type": "array", }, "manager_id": { "anyOf": [ - { - "maximum": 2147483647, - "minimum": -2147483648, - "type": "integer" - }, - { - "type": "null" - } + {"maximum": 2147483647, "minimum": -2147483648, "type": "integer"}, + {"type": "null"}, ], "default": None, "nullable": True, - "title": "Manager Id" + "title": "Manager Id", }, "team_members": { - "items": { - "$ref": "#/$defs/Employee_6tkbjb_leaf" - }, + "items": {"$ref": "#/$defs/Employee_6tkbjb_leaf"}, "title": "Team Members", - "type": "array" + "type": "array", }, "name_length": { "description": "", "readOnly": True, "title": "Name Length", - "type": "integer" + "type": "integer", }, "team_size": { "description": "Computes team size.

Note that this function needs to be annotated with a return type so that pydantic can
generate a valid schema.

Note that the pydantic serializer can't call async methods, but the tortoise helpers
pre-fetch relational data, so that it is available before serialization. So we don't
need to await the relation. We do however have to protect against the case where no
prefetching was done, hence catching and handling the
``tortoise.exceptions.NoValuesFetched`` exception.", "readOnly": True, "title": "Team Size", - "type": "integer" - } + "type": "integer", + }, }, - "required": [ - "id", - "name", - "talks_to", - "team_members", - "name_length", - "team_size" - ], + "required": ["id", "name", "talks_to", "team_members", "name_length", "team_size"], "title": "Employee", - "type": "object" - } + "type": "object", + }, ) diff --git a/tests/test_queryset_reuse.py b/tests/test_queryset_reuse.py index 1e043be80..7d7adf95e 100644 --- a/tests/test_queryset_reuse.py +++ b/tests/test_queryset_reuse.py @@ -1,7 +1,4 @@ -from tests.testmodels import ( - Event, - Tournament, -) +from tests.testmodels import Event, Tournament from tortoise.contrib import test from tortoise.contrib.test.condition import NotEQ from tortoise.expressions import F diff --git a/tests/testmodels.py b/tests/testmodels.py index cc809c29a..6386e5e50 100644 --- a/tests/testmodels.py +++ b/tests/testmodels.py @@ -12,7 +12,7 @@ from typing import List, Union import pytz -from pydantic import ConfigDict +from pydantic import BaseModel, ConfigDict from tortoise import fields from tortoise.exceptions import ValidationError @@ -34,6 +34,15 @@ def generate_token(): return binascii.hexlify(os.urandom(16)).decode("ascii") +class TestSchemaForJSONField(BaseModel): + foo: int + bar: str + __test__ = False + + +json_pydantic_default = TestSchemaForJSONField(foo=1, bar="baz") + + class Author(Model): name = fields.CharField(max_length=255) @@ -286,21 +295,30 @@ class FloatFields(Model): floatnum_null = fields.FloatField(null=True) +def raise_if_not_dict_or_list(value: Union[dict, list]): + if not isinstance(value, (dict, list)): + raise ValidationError("Value must be a dict or list.") + + class JSONFields(Model): """ This model contains many JSON blobs """ - @staticmethod - def dict_or_list(value: Union[dict, list]): - if not isinstance(value, (dict, list)): - raise ValidationError("Value must be a dict or list.") - id = fields.IntField(primary_key=True) - data = fields.JSONField() - data_null = fields.JSONField(null=True) - data_default = fields.JSONField(default={"a": 1}) - data_validate = fields.JSONField(null=True, validators=[lambda v: JSONFields.dict_or_list(v)]) + data = fields.JSONField() # type: ignore # Test cases where generics are not provided + data_null = fields.JSONField[Union[dict, list]](null=True) + data_default = fields.JSONField[dict](default={"a": 1}) + + # From Python 3.10 onwards, validator can be defined with staticmethod + data_validate = fields.JSONField[Union[dict, list]]( + null=True, validators=[raise_if_not_dict_or_list] + ) + + # Test cases where generics are provided and the type is a pydantic base model + data_pydantic = fields.JSONField[TestSchemaForJSONField]( + default=json_pydantic_default, field_type=TestSchemaForJSONField + ) class UUIDFields(Model): diff --git a/tests/utils/test_describe_model.py b/tests/utils/test_describe_model.py index 3b9b1a78d..998d6e794 100644 --- a/tests/utils/test_describe_model.py +++ b/tests/utils/test_describe_model.py @@ -9,11 +9,13 @@ SourceFields, StraightFields, Team, + TestSchemaForJSONField, Tournament, UUIDFkRelatedModel, UUIDFkRelatedNullModel, UUIDM2MRelatedModel, UUIDPkModel, + json_pydantic_default, ) from tortoise import Tortoise, fields from tortoise.contrib import test @@ -1392,6 +1394,26 @@ def test_describe_model_json(self): "docstring": None, "constraints": {}, }, + { + "name": "data_pydantic", + "field_type": "JSONField", + "db_column": "data_pydantic", + "db_field_types": { + 
"": "JSON", + "mssql": "NVARCHAR(MAX)", + "oracle": "NCLOB", + "postgres": "JSONB", + }, + "python_type": "tests.testmodels.TestSchemaForJSONField", + "generated": False, + "nullable": False, + "unique": False, + "indexed": False, + "default": "foo=1 bar='baz'", + "description": None, + "docstring": None, + "constraints": {}, + }, ], "fk_fields": [], "backward_fk_fields": [], @@ -1511,6 +1533,26 @@ def test_describe_model_json_native(self): "docstring": None, "constraints": {}, }, + { + "name": "data_pydantic", + "field_type": fields.JSONField, + "db_column": "data_pydantic", + "db_field_types": { + "": "JSON", + "mssql": "NVARCHAR(MAX)", + "oracle": "NCLOB", + "postgres": "JSONB", + }, + "python_type": TestSchemaForJSONField, + "generated": False, + "nullable": False, + "unique": False, + "indexed": False, + "default": json_pydantic_default, + "description": None, + "docstring": None, + "constraints": {}, + }, ], "fk_fields": [], "backward_fk_fields": [], diff --git a/tortoise/contrib/pydantic/creator.py b/tortoise/contrib/pydantic/creator.py index 8d0e3c828..ca8335be6 100644 --- a/tortoise/contrib/pydantic/creator.py +++ b/tortoise/contrib/pydantic/creator.py @@ -1,20 +1,38 @@ import inspect from base64 import b32encode from copy import copy -from typing import MutableMapping - from hashlib import sha3_224 -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union - -from pydantic import ConfigDict, computed_field, create_model +from typing import ( + TYPE_CHECKING, + Any, + Dict, + List, + MutableMapping, + Optional, + Tuple, + Type, + Union, +) + +from pydantic import ConfigDict from pydantic import Field as PydanticField - -from tortoise import ForeignKeyFieldInstance, BackwardFKRelation, ManyToManyFieldInstance, OneToOneFieldInstance, \ - BackwardOneToOneRelation +from pydantic import computed_field, create_model + +from tortoise import ( + BackwardFKRelation, + BackwardOneToOneRelation, + ForeignKeyFieldInstance, + ManyToManyFieldInstance, + OneToOneFieldInstance, +) from tortoise.contrib.pydantic.base import PydanticListModel, PydanticModel +from tortoise.contrib.pydantic.descriptions import ( + ComputedFieldDescription, + ModelDescription, + PydanticMetaData, +) from tortoise.contrib.pydantic.utils import get_annotations -from tortoise.fields import JSONField, Field -from tortoise.contrib.pydantic.descriptions import ModelDescription, PydanticMetaData, ComputedFieldDescription +from tortoise.fields import Field, JSONField if TYPE_CHECKING: # pragma: nocoverage from tortoise.models import Model @@ -41,15 +59,15 @@ def _cleandoc(obj: Any) -> str: def _pydantic_recursion_protector( - cls: "Type[Model]", - *, - stack: Tuple, - exclude: Tuple[str, ...] = (), - include: Tuple[str, ...] = (), - computed: Tuple[str, ...] = (), - name=None, - allow_cycles: bool = False, - sort_alphabetically: Optional[bool] = None, + cls: "Type[Model]", + *, + stack: Tuple, + exclude: Tuple[str, ...] = (), + include: Tuple[str, ...] = (), + computed: Tuple[str, ...] 
= (), + name=None, + allow_cycles: bool = False, + sort_alphabetically: Optional[bool] = None, ) -> Optional[Type[PydanticModel]]: """ It is an inner function to protect pydantic model creator against cyclic recursion @@ -116,7 +134,9 @@ def sort_alphabetically(self) -> None: def sort_definition_order(self, cls: "Type[Model]", computed: Tuple[str, ...]) -> None: self._field_map = { - k: self._field_map[k] for k in tuple(cls._meta.fields_map.keys()) + computed if k in self._field_map + k: self._field_map[k] + for k in tuple(cls._meta.fields_map.keys()) + computed + if k in self._field_map } def field_map_update(self, fields: List[Field], meta: PydanticMetaData) -> None: @@ -128,7 +148,11 @@ def field_map_update(self, fields: List[Field], meta: PydanticMetaData) -> None: # Remove raw fields if isinstance(field, ForeignKeyFieldInstance): raw_field = field.source_field - if raw_field is not None and meta.exclude_raw_fields and raw_field != self.pk_raw_field: + if ( + raw_field is not None + and meta.exclude_raw_fields + and raw_field != self.pk_raw_field + ): self.pop(raw_field, None) self[name] = field @@ -146,14 +170,14 @@ def computed_field_map_update(self, computed: Tuple[str, ...], cls: "Type[Model] def pydantic_queryset_creator( - cls: "Type[Model]", - *, - name=None, - exclude: Tuple[str, ...] = (), - include: Tuple[str, ...] = (), - computed: Tuple[str, ...] = (), - allow_cycles: Optional[bool] = None, - sort_alphabetically: Optional[bool] = None, + cls: "Type[Model]", + *, + name=None, + exclude: Tuple[str, ...] = (), + include: Tuple[str, ...] = (), + computed: Tuple[str, ...] = (), + allow_cycles: Optional[bool] = None, + sort_alphabetically: Optional[bool] = None, ) -> Type[PydanticListModel]: """ Function to build a `Pydantic Model `__ list off Tortoise Model. @@ -206,34 +230,36 @@ def pydantic_queryset_creator( class PydanticModelCreator: def __init__( - self, - cls: "Type[Model]", - name: Optional[str] = None, - exclude: Optional[Tuple[str, ...]] = None, - include: Optional[Tuple[str, ...]] = None, - computed: Optional[Tuple[str, ...]] = None, - optional: Optional[Tuple[str, ...]] = None, - allow_cycles: Optional[bool] = None, - sort_alphabetically: Optional[bool] = None, - exclude_readonly: bool = False, - meta_override: Optional[Type] = None, - model_config: Optional[ConfigDict] = None, - validators: Optional[Dict[str, Any]] = None, - module: str = __name__, - _stack: tuple = (), - _as_submodel: bool = False + self, + cls: "Type[Model]", + name: Optional[str] = None, + exclude: Optional[Tuple[str, ...]] = None, + include: Optional[Tuple[str, ...]] = None, + computed: Optional[Tuple[str, ...]] = None, + optional: Optional[Tuple[str, ...]] = None, + allow_cycles: Optional[bool] = None, + sort_alphabetically: Optional[bool] = None, + exclude_readonly: bool = False, + meta_override: Optional[Type] = None, + model_config: Optional[ConfigDict] = None, + validators: Optional[Dict[str, Any]] = None, + module: str = __name__, + _stack: tuple = (), + _as_submodel: bool = False, ) -> None: self._cls: "Type[Model]" = cls - self._stack: Tuple[Tuple["Type[Model]", str, int], ...] = _stack # ((Type[Model], field_name, max_recursion),) + self._stack: Tuple[Tuple["Type[Model]", str, int], ...] 
= ( + _stack # ((Type[Model], field_name, max_recursion),) + ) self._is_default: bool = ( - exclude is None - and include is None - and computed is None - and optional is None - and sort_alphabetically is None - and allow_cycles is None - and meta_override is None - and not exclude_readonly + exclude is None + and include is None + and computed is None + and optional is None + and sort_alphabetically is None + and allow_cycles is None + and meta_override is None + and not exclude_readonly ) if exclude is None: exclude = () @@ -295,7 +321,9 @@ def _hash(self): f"{self._fqname};{self._properties.keys()};{self._relational_fields_index};{self._optional};" f"{self.meta.allow_cycles}" ) - self.__hash = b32encode(sha3_224(hashval.encode("utf-8")).digest()).decode("utf-8").lower()[:6] + self.__hash = ( + b32encode(sha3_224(hashval.encode("utf-8")).digest()).decode("utf-8").lower()[:6] + ) return self.__hash def get_name(self) -> Tuple[str, str]: @@ -305,16 +333,8 @@ def get_name(self) -> Tuple[str, str]: # When called later, include is explicitly set, so fence passes. if self.given_name is not None: return self.given_name, self.given_name - name = ( - f"{self._fqname}:{self._hash}" - if not self._is_default - else self._fqname - ) - name = ( - f"{name}:leaf" - if self._as_submodel - else name - ) + name = f"{self._fqname}:{self._hash}" if not self._is_default else self._fqname + name = f"{name}:leaf" if self._as_submodel else name return name, self._cls.__name__ def _initialize_pconfig(self) -> ConfigDict: @@ -324,7 +344,7 @@ def _initialize_pconfig(self) -> ConfigDict: if "title" not in pconfig: pconfig["title"] = self._title if "extra" not in pconfig: - pconfig["extra"] = 'forbid' + pconfig["extra"] = "forbid" return pconfig def _initialize_field_map(self) -> FieldMap: @@ -338,15 +358,15 @@ def _construct_field_map(self) -> None: self._field_map.field_map_update(fields=self._model_description.data_fields, meta=self.meta) if not self._exclude_read_only: for fields in ( - self._model_description.fk_fields, - self._model_description.o2o_fields, - self._model_description.m2m_fields + self._model_description.fk_fields, + self._model_description.o2o_fields, + self._model_description.m2m_fields, ): self._field_map.field_map_update(fields, self.meta) if self.meta.backward_relations: for fields in ( - self._model_description.backward_fk_fields, - self._model_description.backward_o2o_fields + self._model_description.backward_fk_fields, + self._model_description.backward_o2o_fields, ): self._field_map.field_map_update(fields, self.meta) self._field_map.computed_field_map_update(self.meta.computed, self._cls) @@ -386,9 +406,9 @@ def create_pydantic_model(self) -> Type[PydanticModel]: return model def _process_field( - self, - field_name: str, - field: Union[Field, ComputedFieldDescription], + self, + field_name: str, + field: Union[Field, ComputedFieldDescription], ) -> None: json_schema_extra: Dict[str, Any] = {} fconfig: Dict[str, Any] = { @@ -405,18 +425,16 @@ def _process_field( description = _br_it(field.docstring or field.description or "") if description: fconfig["description"] = description - if ( - field_name in self._optional - or (field.default is not None and not callable(field.default)) + if field_name in self._optional or ( + field.default is not None and not callable(field.default) ): - self._properties[field_name] = (field_property, PydanticField(default=field.default, **fconfig)) + self._properties[field_name] = ( + field_property, + PydanticField(default=field.default, **fconfig), + ) 
else: - if ( - ( - json_schema_extra.get("nullable") - and not is_to_one_relation - ) - or (self._exclude_read_only and json_schema_extra.get("readOnly")) + if (json_schema_extra.get("nullable") and not is_to_one_relation) or ( + self._exclude_read_only and json_schema_extra.get("readOnly") ): # see: https://docs.pydantic.dev/latest/migration/#required-optional-and-nullable-fields fconfig["default"] = None @@ -432,19 +450,14 @@ def _process_field( self._properties[field_name] = field_property def _process_normal_field( - self, - field_name: str, - field: Field, - json_schema_extra: Dict[str, Any], - fconfig: Dict[str, Any], + self, + field_name: str, + field: Field, + json_schema_extra: Dict[str, Any], + fconfig: Dict[str, Any], ) -> Tuple[Optional[Any], bool]: if isinstance( - field, - ( - ForeignKeyFieldInstance, - OneToOneFieldInstance, - BackwardOneToOneRelation - ) + field, (ForeignKeyFieldInstance, OneToOneFieldInstance, BackwardOneToOneRelation) ): return self._process_single_field_relation(field_name, field, json_schema_extra), True elif isinstance(field, (BackwardFKRelation, ManyToManyFieldInstance)): @@ -454,14 +467,10 @@ def _process_normal_field( return self._process_data_field(field_name, field, json_schema_extra, fconfig), False def _process_single_field_relation( - self, - field_name: str, - field: Union[ - ForeignKeyFieldInstance, - OneToOneFieldInstance, - BackwardOneToOneRelation - ], - json_schema_extra: Dict[str, Any], + self, + field_name: str, + field: Union[ForeignKeyFieldInstance, OneToOneFieldInstance, BackwardOneToOneRelation], + json_schema_extra: Dict[str, Any], ) -> Optional[Type[PydanticModel]]: python_type = getattr(field, "related_model", field.field_type) model: Optional[Type[PydanticModel]] = self._get_submodel(python_type, field_name) @@ -476,9 +485,9 @@ def _process_single_field_relation( return None def _process_many_field_relation( - self, - field_name: str, - field: Union[BackwardFKRelation, ManyToManyFieldInstance], + self, + field_name: str, + field: Union[BackwardFKRelation, ManyToManyFieldInstance], ) -> Optional[Type[List[Type[PydanticModel]]]]: python_type = field.related_model model = self._get_submodel(python_type, field_name) @@ -488,11 +497,11 @@ def _process_many_field_relation( return None def _process_data_field( - self, - field_name: str, - field: Field, - json_schema_extra: Dict[str, Any], - fconfig: Dict[str, Any], + self, + field_name: str, + field: Field, + json_schema_extra: Dict[str, Any], + fconfig: Dict[str, Any], ) -> Optional[Any]: annotation = self._annotations.get(field_name, None) constraints = copy(field.constraints) @@ -511,8 +520,8 @@ def _process_data_field( return None def _process_computed_field( - self, - field: ComputedFieldDescription, + self, + field: ComputedFieldDescription, ) -> Optional[Any]: func = field.function annotation = get_annotations(self._cls, func).get("return", None) @@ -523,7 +532,9 @@ def _process_computed_field( return ret return None - def _get_submodel(self, _model: Optional["Type[Model]"], field_name: str) -> Optional[Type[PydanticModel]]: + def _get_submodel( + self, _model: Optional["Type[Model]"], field_name: str + ) -> Optional[Type[PydanticModel]]: """Get Pydantic model for the submodel""" if _model: @@ -536,6 +547,7 @@ def get_fields_to_carry_on(field_tuple: Tuple[str, ...]) -> Tuple[str, ...]: return tuple( str(v[prefix_len:]) for v in field_tuple if v.startswith(field_name + ".") ) + pmodel = _pydantic_recursion_protector( _model, exclude=get_fields_to_carry_on(self.meta.exclude), @@ 
-556,20 +568,20 @@ def get_fields_to_carry_on(field_tuple: Tuple[str, ...]) -> Tuple[str, ...]: def pydantic_model_creator( - cls: "Type[Model]", - *, - name=None, - exclude: Optional[Tuple[str, ...]] = None, - include: Optional[Tuple[str, ...]] = None, - computed: Optional[Tuple[str, ...]] = None, - optional: Optional[Tuple[str, ...]] = None, - allow_cycles: Optional[bool] = None, - sort_alphabetically: Optional[bool] = None, - exclude_readonly: bool = False, - meta_override: Optional[Type] = None, - model_config: Optional[ConfigDict] = None, - validators: Optional[Dict[str, Any]] = None, - module: str = __name__, + cls: "Type[Model]", + *, + name=None, + exclude: Optional[Tuple[str, ...]] = None, + include: Optional[Tuple[str, ...]] = None, + computed: Optional[Tuple[str, ...]] = None, + optional: Optional[Tuple[str, ...]] = None, + allow_cycles: Optional[bool] = None, + sort_alphabetically: Optional[bool] = None, + exclude_readonly: bool = False, + meta_override: Optional[Type] = None, + model_config: Optional[ConfigDict] = None, + validators: Optional[Dict[str, Any]] = None, + module: str = __name__, ) -> Type[PydanticModel]: """ Function to build `Pydantic Model `__ off Tortoise Model. @@ -615,6 +627,6 @@ def pydantic_model_creator( meta_override=meta_override, model_config=model_config, validators=validators, - module=module + module=module, ) return pmc.create_pydantic_model() diff --git a/tortoise/contrib/pydantic/descriptions.py b/tortoise/contrib/pydantic/descriptions.py index 8d3770b8a..6675a94e6 100644 --- a/tortoise/contrib/pydantic/descriptions.py +++ b/tortoise/contrib/pydantic/descriptions.py @@ -1,6 +1,6 @@ import dataclasses import sys -from typing import Type, Optional, Any, TYPE_CHECKING, List, Tuple, Callable +from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Type if sys.version_info >= (3, 11): from typing import Self @@ -32,7 +32,8 @@ def from_model(cls, model: Type["Model"]) -> Self: data_fields=[ field for name, field in model._meta.fields_map.items() - if name != model._meta.pk_attr and name in (model._meta.fields - model._meta.fetch_fields) + if name != model._meta.pk_attr + and name in (model._meta.fields - model._meta.fetch_fields) ], fk_fields=[ field @@ -107,13 +108,16 @@ def from_pydantic_meta(cls, old_pydantic_meta: Any) -> Self: def get_param_from_pydantic_meta(attr: str, default: Any) -> Any: return getattr(old_pydantic_meta, attr, default) + include = tuple(get_param_from_pydantic_meta("include", default_meta.include)) exclude = tuple(get_param_from_pydantic_meta("exclude", default_meta.exclude)) computed = tuple(get_param_from_pydantic_meta("computed", default_meta.computed)) backward_relations = bool( get_param_from_pydantic_meta("backward_relations_raw", default_meta.backward_relations) ) - max_recursion = int(get_param_from_pydantic_meta("max_recursion", default_meta.max_recursion)) + max_recursion = int( + get_param_from_pydantic_meta("max_recursion", default_meta.max_recursion) + ) allow_cycles = bool(get_param_from_pydantic_meta("allow_cycles", default_meta.allow_cycles)) exclude_raw_fields = bool( get_param_from_pydantic_meta("exclude_raw_fields", default_meta.exclude_raw_fields) @@ -131,14 +135,11 @@ def get_param_from_pydantic_meta(attr: str, default: Any) -> Any: allow_cycles=allow_cycles, exclude_raw_fields=exclude_raw_fields, sort_alphabetically=sort_alphabetically, - model_config=model_config + model_config=model_config, ) return pmd - def construct_pydantic_meta( - self, - meta_override: Type - ) -> 
"PydanticMetaData": + def construct_pydantic_meta(self, meta_override: Type) -> "PydanticMetaData": def get_param_from_meta_override(attr: str) -> Any: return getattr(meta_override, attr, getattr(self, attr)) @@ -163,29 +164,23 @@ def get_param_from_meta_override(attr: str) -> Any: max_recursion=max_recursion, exclude_raw_fields=exclude_raw_fields, sort_alphabetically=sort_alphabetically, - allow_cycles=allow_cycles + allow_cycles=allow_cycles, ) return pmd def finalize_meta( - self, - exclude: Tuple[str, ...] = (), - include: Tuple[str, ...] = (), - computed: Tuple[str, ...] = (), - allow_cycles: Optional[bool] = None, - sort_alphabetically: Optional[bool] = None, - model_config: Optional[ConfigDict] = None, + self, + exclude: Tuple[str, ...] = (), + include: Tuple[str, ...] = (), + computed: Tuple[str, ...] = (), + allow_cycles: Optional[bool] = None, + sort_alphabetically: Optional[bool] = None, + model_config: Optional[ConfigDict] = None, ) -> "PydanticMetaData": _sort_fields: bool = ( - self.sort_alphabetically - if sort_alphabetically is None - else sort_alphabetically - ) - _allow_cycles: bool = ( - self.allow_cycles - if allow_cycles is None - else allow_cycles + self.sort_alphabetically if sort_alphabetically is None else sort_alphabetically ) + _allow_cycles: bool = self.allow_cycles if allow_cycles is None else allow_cycles include = tuple(include) + self.include exclude = tuple(exclude) + self.exclude @@ -206,5 +201,5 @@ def finalize_meta( exclude_raw_fields=self.exclude_raw_fields, sort_alphabetically=_sort_fields, allow_cycles=_allow_cycles, - model_config=_model_config + model_config=_model_config, ) diff --git a/tortoise/fields/data.py b/tortoise/fields/data.py index 10846fb29..aa0302b6b 100644 --- a/tortoise/fields/data.py +++ b/tortoise/fields/data.py @@ -46,6 +46,8 @@ "UUIDField", ) +T = TypeVar("T") + # Doing this we can replace json dumps/loads with different implementations JsonDumpsFunc = Callable[[Any], str] JsonLoadsFunc = Callable[[Union[str, bytes]], Any] @@ -517,12 +519,14 @@ class _db_mysql: SQL_TYPE = "DOUBLE" -class JSONField(Field[Union[dict, list]], dict, list): # type: ignore +class JSONField(Field[T], dict, list): # type: ignore """ JSON field. This field can store dictionaries or lists of any JSON-compliant structure. + You can use generics to make static checking more friendly. Example: ``JSONField[dict[str, str]]`` + You can specify your own custom JSON encoder/decoder, leaving at the default should work well. If you have ``orjson`` installed, we default to using that, else the default ``json`` module will be used. @@ -532,6 +536,11 @@ class JSONField(Field[Union[dict, list]], dict, list): # type: ignore ``decoder``: The custom JSON decoder. + If you want to use Pydantic model as the field type for generating a better OpenAPI documentation, you can use ``field_type`` to specify the type of the field. + + ``field_type``: + The Pydantic model class. 
+ """ SQL_TYPE = "JSON" @@ -555,29 +564,53 @@ def __init__( super().__init__(**kwargs) self.encoder = encoder self.decoder = decoder + if field_type := kwargs.get("field_type", None): + self.field_type = field_type def to_db_value( - self, value: Optional[Union[dict, list, str, bytes]], instance: "Union[Type[Model], Model]" + self, + value: Optional[Union[T, dict, list, str, bytes]], + instance: "Union[Type[Model], Model]", ) -> Optional[str]: self.validate(value) - if value is not None: - if isinstance(value, (str, bytes)): - try: - self.decoder(value) - except Exception: - raise FieldError(f"Value {value!r} is invalid json value.") - if isinstance(value, bytes): - value = value.decode() - else: - value = self.encoder(value) - return value + if value is None: + return None + + if isinstance(value, (str, bytes)): + try: + self.decoder(value) + except Exception: + raise FieldError(f"Value {value!r} is invalid json value.") + if isinstance(value, bytes): + return value.decode() + return value + + try: + from pydantic import BaseModel + + if isinstance(value, BaseModel): + value = value.model_dump() + except ImportError: + pass + + return self.encoder(value) def to_python_value( - self, value: Optional[Union[str, bytes, dict, list]] - ) -> Optional[Union[dict, list]]: + self, value: Optional[Union[T, str, bytes, dict, list]] + ) -> Optional[Union[T, dict, list]]: if isinstance(value, (str, bytes)): try: - return self.decoder(value) + data = self.decoder(value) + + try: + from pydantic._internal._model_construction import ModelMetaclass + + if isinstance(self.field_type, ModelMetaclass) and not isinstance(data, list): + return self.field_type(**data) + except ImportError: + pass + + return data except Exception: raise FieldError( f"Value {value if isinstance(value, str) else value.decode()} is invalid json value."