Skip to content

Commit

Permalink
fix: do not allow SequenceLocation to have negative start/end v…
Browse files Browse the repository at this point in the history
…alues

close #442

* Add model_validators to `Range` and `SequenceLocation`
  • Loading branch information
korikuzma committed Aug 28, 2024
1 parent 4082981 commit 2f46958
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 4 deletions.
68 changes: 64 additions & 4 deletions src/ga4gh/vrs/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
module name, e.g., `ga4gh.vrs.models.Allele`
"""
from abc import ABC
from typing import List, Literal, Optional, Union, Dict, Annotated
from typing import List, Literal, Optional, Self, Union, Dict, Annotated
from collections import OrderedDict
from enum import Enum
import inspect
Expand All @@ -27,7 +27,7 @@
from ga4gh.core.pydantic import get_pydantic_root

from canonicaljson import encode_canonical_json
from pydantic import BaseModel, Field, RootModel, StringConstraints, ConfigDict
from pydantic import BaseModel, Field, RootModel, StringConstraints, ConfigDict, model_validator

from ga4gh.core.pydantic import (
getattr_in
Expand Down Expand Up @@ -331,6 +331,26 @@ class Range(RootModel):
min_length=2,
)

@model_validator(mode="after")
def validate_range(self) -> Self:
"""Validate range values
:raises ValueError: If ``root`` does not include at least one integer or if
the first element in ``root`` is greater than the second element in ``root``
"""
if self.root.count(None) == 2:
err_msg = "Must provide at least one integer."
raise ValueError(err_msg)

if all((
self.root[0] is not None,
self.root[1] is not None,
self.root[0] > self.root[1]
)):
err_msg = "The first integer must be less than or equal to the second integer."
raise ValueError(err_msg)

return self

class Residue(RootModel):
"""A character representing a specific residue (i.e., molecular species) or
Expand Down Expand Up @@ -454,15 +474,55 @@ class SequenceLocation(_Ga4ghIdentifiableObject):
)
start: Optional[Union[Range, int]] = Field(
None,
description='The start coordinate or range of the SequenceLocation. The minimum value of this coordinate or range is 0. MUST represent a coordinate or range less than the value of `end`.',
description='The start coordinate or range of the SequenceLocation. The minimum value of this coordinate or range is 0. MUST represent a coordinate or range less than or equal to the value of `end`.',
)
end: Optional[Union[Range, int]] = Field(
None,
description='The end coordinate or range of the SequenceLocation. The minimum value of this coordinate or range is 0. MUST represent a coordinate or range greater than the value of `start`.',
description='The end coordinate or range of the SequenceLocation. The minimum value of this coordinate or range is 0. MUST represent a coordinate or range greater than or equal to the value of `start`.',

)
sequence: Optional[SequenceString] = Field(None, description="The literal sequence encoded by the `sequenceReference` at these coordinates.")

@model_validator(mode="after")
def validate_start_end(self) -> Self:
"""Validate ``start`` and ``end`` fields
:raises ValueError: If ``start`` or ``end`` has a value less than 0 or if
``start`` is greater than ``end``
:return: Sequence Location
"""
def _get_int_values(start_or_end: int | Range | None) -> list[int]:
"""Get list of integers from ``start`` or ``end`` fields
:param start_or_end: ``start`` or ``end`` field
:raises ValueError: If ``start_or_end`` has a value less than 0
:return: List of integer values
"""
int_values = []

if start_or_end is not None:
if isinstance(start_or_end, int):
int_values = [start_or_end]
else:
int_values = [val for val in start_or_end.root if val is not None]

if any(int_val < 0 for int_val in int_values):
err_msg = "The minimum value of `start` or `end` is 0."
raise ValueError(err_msg)

return int_values

start_values = _get_int_values(self.start)
end_values = _get_int_values(self.end)

if start_values and end_values:
for start_val in start_values:
if any(start_val > end_val for end_val in end_values):
err_msg = "`start` must be less than or equal to `end`."
raise ValueError(err_msg)

return self

def ga4gh_serialize_as_version(self, as_version: PrevVrsVersion):
"""This method will return a serialized string following the conventions for
SequenceLocation serialization as defined in the VRS version specified by
Expand Down
18 changes: 18 additions & 0 deletions tests/test_vrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,24 @@
cpb_431012 = models.CisPhasedBlock(**cpb_431012_dict)


@pytest.mark.parametrize(
"vrs_model, expected_err_msg",
[
(lambda: models.Range(root=[None, None]), "Must provide at least one integer."),
(lambda: models.Range(root=[2, 1]), "The first integer must be less than or equal to the second integer."),
(lambda: models.SequenceLocation(sequenceReference=allele_280320.location.sequenceReference, start=-1), "The minimum value of `start` or `end` is 0."),
(lambda: models.SequenceLocation(sequenceReference=allele_280320.location.sequenceReference, end=[-1, 0]), "The minimum value of `start` or `end` is 0."),
(lambda: models.SequenceLocation(sequenceReference=allele_280320.location.sequenceReference, start=1, end=0), "`start` must be less than or equal to `end`."),
(lambda: models.SequenceLocation(sequenceReference=allele_280320.location.sequenceReference, start=[3,4], end=[1,2]), "`start` must be less than or equal to `end`.")
]
)
def test_model_validation_errors(vrs_model, expected_err_msg):
"""Test that invalid VRS models raise errors"""
with pytest.raises(ValueError) as e:
vrs_model()
assert str(e.value.errors()[0]["ctx"]["error"]) == expected_err_msg


def test_vr():
assert a.model_dump(exclude_none=True) == allele_dict
assert is_pydantic_instance(a)
Expand Down

0 comments on commit 2f46958

Please sign in to comment.