Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: revert Pydantic custom serializer #435

Merged
6 commits merged on
Jul 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ dependencies = [
"pydantic~=2.1",
"bioutils",
"requests",
"canonicaljson",
]

[project.optional-dependencies]
Expand Down
3 changes: 1 addition & 2 deletions src/ga4gh/core/entity_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from typing import Any, Dict, Annotated, Optional, Union, List
from enum import Enum

from pydantic import BaseModel, Field, RootModel, StringConstraints, model_serializer, ConfigDict
from pydantic import BaseModel, Field, RootModel, StringConstraints, ConfigDict

from ga4gh.core import GA4GH_IR_REGEXP

Expand Down Expand Up @@ -78,7 +78,6 @@ class IRI(RootModel):
def __hash__(self):
return self.root.__hash__()

@model_serializer(when_used='json')
def ga4gh_serialize(self):
m = GA4GH_IR_REGEXP.match(self.root)
if m is not None:
Expand Down
4 changes: 2 additions & 2 deletions src/ga4gh/core/identifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
For that reason, they are implemented here in one file.

"""

from canonicaljson import encode_canonical_json
import contextvars
import re
from contextlib import ContextDecorator
Expand Down Expand Up @@ -194,6 +194,6 @@ def ga4gh_serialize(obj: BaseModel, as_version: PrevVrsVersion | None = None) ->
PrevVrsVersion.validate(as_version)

if as_version is None:
return obj.model_dump_json().encode("utf-8")
return encode_canonical_json(obj.ga4gh_serialize())
else:
return obj.ga4gh_serialize_as_version(as_version)
11 changes: 5 additions & 6 deletions src/ga4gh/vrs/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
)
from ga4gh.core.pydantic import get_pydantic_root

from pydantic import BaseModel, Field, RootModel, StringConstraints, model_serializer, ConfigDict
from canonicaljson import encode_canonical_json
from pydantic import BaseModel, Field, RootModel, StringConstraints, ConfigDict

from ga4gh.core.pydantic import (
getattr_in
Expand Down Expand Up @@ -178,7 +179,7 @@ def _recurse_ga4gh_serialize(obj):
elif isinstance(obj, _ValueObject):
return obj.ga4gh_serialize()
elif isinstance(obj, RootModel):
return _recurse_ga4gh_serialize(obj.model_dump(mode='json'))
return _recurse_ga4gh_serialize(obj.model_dump())
elif isinstance(obj, str):
return obj
elif isinstance(obj, list):
Expand All @@ -193,9 +194,8 @@ class _ValueObject(DomainEntity, ABC):
"""

def __hash__(self):
return self.model_dump_json().__hash__()
return encode_canonical_json(self.ga4gh_serialize()).decode("utf-8").__hash__()

@model_serializer(when_used='json')
def ga4gh_serialize(self) -> Dict:
out = OrderedDict()
for k in self.ga4gh.keys:
Expand Down Expand Up @@ -242,7 +242,7 @@ def compute_digest(self, store=True, as_version: PrevVrsVersion | None = None) -
returned following the conventions of the VRS version indicated by ``as_version``.
"""
if as_version is None:
digest = sha512t24u(self.model_dump_json().encode("utf-8"))
digest = sha512t24u(encode_canonical_json(self.ga4gh_serialize()))
if store:
self.digest = digest
else:
Expand Down Expand Up @@ -580,7 +580,6 @@ class CisPhasedBlock(_VariationBase):
)
sequenceReference: Optional[SequenceReference] = Field(None, description="An optional Sequence Reference on which all of the in-cis Alleles are found. When defined, this may be used to implicitly define the `sequenceReference` attribute for each of the CisPhasedBlock member Alleles.")

@model_serializer(when_used="json")
def ga4gh_serialize(self) -> Dict:
out = _ValueObject.ga4gh_serialize(self)
out["members"] = sorted(out["members"])
Expand Down
4 changes: 3 additions & 1 deletion tests/validation/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def ga4gh_1_3_serialize(*args, **kwargs):
return ga4gh_serialize(*args, **kwargs)

fxs = {
"ga4gh_serialize": lambda o: ga4gh_serialize(o).decode() if ga4gh_serialize(o) else None,
"ga4gh_serialize": ga4gh_serialize,
"ga4gh_digest": ga4gh_digest,
"ga4gh_identify": ga4gh_identify,
"ga4gh_1_3_digest": ga4gh_1_3_digest,
Expand Down Expand Up @@ -60,6 +60,8 @@ def flatten_tests(vts):
def test_validation(cls, data, fn, exp):
o = getattr(models, cls)(**data)
fx = fxs[fn]
if fn == "ga4gh_serialize":
exp = exp.encode("utf-8")
assert fx(o) == exp


Expand Down