From f6efce43704c7cd4068c9d66e9cbf9a062dce00c Mon Sep 17 00:00:00 2001 From: Kori Kuzma Date: Wed, 17 Jul 2024 13:02:31 -0400 Subject: [PATCH] fix: revert Pydantic custom serializer * The custom serializer resulted in `model_dump` to only include minimal fields for ga4gh serialization. This caused issues in downstream apps that leverage FastAPI --- pyproject.toml | 1 + src/ga4gh/core/entity_models.py | 3 +-- src/ga4gh/core/identifiers.py | 4 ++-- src/ga4gh/vrs/models.py | 9 ++++----- tests/validation/test_models.py | 4 +++- 5 files changed, 11 insertions(+), 10 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c3077e4e..5b29541f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ dependencies = [ "pydantic~=2.1", "bioutils", "requests", + "canonicaljson", ] [project.optional-dependencies] diff --git a/src/ga4gh/core/entity_models.py b/src/ga4gh/core/entity_models.py index bf91f597..196af6bb 100644 --- a/src/ga4gh/core/entity_models.py +++ b/src/ga4gh/core/entity_models.py @@ -13,7 +13,7 @@ from typing import Any, Dict, Annotated, Optional, Union, List from enum import Enum -from pydantic import BaseModel, Field, RootModel, StringConstraints, model_serializer +from pydantic import BaseModel, Field, RootModel, StringConstraints from ga4gh.core import GA4GH_IR_REGEXP @@ -77,7 +77,6 @@ class IRI(RootModel): def __hash__(self): return self.root.__hash__() - @model_serializer(when_used='json') def ga4gh_serialize(self): m = GA4GH_IR_REGEXP.match(self.root) if m is not None: diff --git a/src/ga4gh/core/identifiers.py b/src/ga4gh/core/identifiers.py index 85ea214a..7b026243 100644 --- a/src/ga4gh/core/identifiers.py +++ b/src/ga4gh/core/identifiers.py @@ -14,7 +14,7 @@ For that reason, they are implemented here in one file. """ - +from canonicaljson import encode_canonical_json import contextvars import re from contextlib import ContextDecorator @@ -194,6 +194,6 @@ def ga4gh_serialize(obj: BaseModel, as_version: PrevVrsVersion | None = None) -> PrevVrsVersion.validate(as_version) if as_version is None: - return obj.model_dump_json().encode("utf-8") + return encode_canonical_json(obj.ga4gh_serialize()) else: return obj.ga4gh_serialize_as_version(as_version) diff --git a/src/ga4gh/vrs/models.py b/src/ga4gh/vrs/models.py index da0b3911..3cbe8246 100644 --- a/src/ga4gh/vrs/models.py +++ b/src/ga4gh/vrs/models.py @@ -25,7 +25,8 @@ ) from ga4gh.core.pydantic import get_pydantic_root -from pydantic import BaseModel, Field, RootModel, StringConstraints, model_serializer +from canonicaljson import encode_canonical_json +from pydantic import BaseModel, Field, RootModel, StringConstraints from ga4gh.core.pydantic import ( getattr_in @@ -177,7 +178,7 @@ def _recurse_ga4gh_serialize(obj): elif isinstance(obj, _ValueObject): return obj.ga4gh_serialize() elif isinstance(obj, RootModel): - return _recurse_ga4gh_serialize(obj.model_dump(mode='json')) + return _recurse_ga4gh_serialize(obj.model_dump()) elif isinstance(obj, str): return obj elif isinstance(obj, list): @@ -194,7 +195,6 @@ class _ValueObject(_DomainEntity): def __hash__(self): return self.model_dump_json().__hash__() - @model_serializer(when_used='json') def ga4gh_serialize(self) -> Dict: out = OrderedDict() for k in self.ga4gh.keys: @@ -241,7 +241,7 @@ def compute_digest(self, store=True, as_version: PrevVrsVersion | None = None) - returned following the conventions of the VRS version indicated by ``as_version_``. """ if as_version is None: - digest = sha512t24u(self.model_dump_json().encode("utf-8")) + digest = sha512t24u(encode_canonical_json(self.ga4gh_serialize())) if store: self.digest = digest else: @@ -577,7 +577,6 @@ class CisPhasedBlock(_VariationBase): ) sequenceReference: Optional[SequenceReference] = Field(None, description="An optional Sequence Reference on which all of the in-cis Alleles are found. When defined, this may be used to implicitly define the `sequenceReference` attribute for each of the CisPhasedBlock member Alleles.") - @model_serializer(when_used="json") def ga4gh_serialize(self) -> Dict: out = _ValueObject.ga4gh_serialize(self) out["members"] = sorted(out["members"]) diff --git a/tests/validation/test_models.py b/tests/validation/test_models.py index 50c1ec63..e0737f2f 100644 --- a/tests/validation/test_models.py +++ b/tests/validation/test_models.py @@ -23,7 +23,7 @@ def ga4gh_1_3_serialize(*args, **kwargs): return ga4gh_serialize(*args, **kwargs) fxs = { - "ga4gh_serialize": lambda o: ga4gh_serialize(o).decode() if ga4gh_serialize(o) else None, + "ga4gh_serialize": ga4gh_serialize, "ga4gh_digest": ga4gh_digest, "ga4gh_identify": ga4gh_identify, "ga4gh_1_3_digest": ga4gh_1_3_digest, @@ -59,6 +59,8 @@ def flatten_tests(vts): def test_validation(cls, data, fn, exp): o = getattr(models, cls)(**data) fx = fxs[fn] + if fn == "ga4gh_serialize": + exp = exp.encode("utf-8") assert fx(o) == exp