Skip to content

Commit

Permalink
revert back to str, Enum
Browse files Browse the repository at this point in the history
  • Loading branch information
korikuzma committed Jul 18, 2024
1 parent 51874b9 commit ba7e6c0
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 45 deletions.
35 changes: 18 additions & 17 deletions src/ga4gh/core/domain_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,15 @@
* `import ga4gh.core`, and refer to models using the fully-qualified
module name, e.g., `ga4gh.core.domain_models.Gene`
"""
from enum import Enum
from typing import Literal, Union, List

from pydantic import Field, RootModel

from ga4gh.core.entity_models import DomainEntity


class CommonDomainType:
class CommonDomainType(str, Enum):
"""Define GKS Common Domain Entity types"""

PHENOTYPE = "Phenotype"
Expand All @@ -33,8 +34,8 @@ class Phenotype(DomainEntity):
"""An observable characteristic or trait of an organism."""

type: Literal["Phenotype"] = Field(
CommonDomainType.PHENOTYPE,
description=f'MUST be "{CommonDomainType.PHENOTYPE}".'
CommonDomainType.PHENOTYPE.value,
description=f'MUST be "{CommonDomainType.PHENOTYPE.value}".'
)


Expand All @@ -44,17 +45,17 @@ class Disease(DomainEntity):
"""

type: Literal["Disease"] = Field(
CommonDomainType.DISEASE,
description=f'MUST be "{CommonDomainType.DISEASE}".'
CommonDomainType.DISEASE.value,
description=f'MUST be "{CommonDomainType.DISEASE.value}".'
)


class TraitSet(DomainEntity):
"""A set of phenotype and/or disease concepts that together constitute a condition."""

type: Literal["TraitSet"] = Field(
CommonDomainType.TRAIT_SET,
description=f'MUST be "{CommonDomainType.TRAIT_SET}".'
CommonDomainType.TRAIT_SET.value,
description=f'MUST be "{CommonDomainType.TRAIT_SET.value}".'
)
traits: List[Union[Disease, Phenotype]] = Field(
...,
Expand All @@ -76,26 +77,26 @@ class TherapeuticAction(DomainEntity):
"""A therapeutic action taken that is intended to alter or stop a pathologic process."""

type: Literal["TherapeuticAction"] = Field(
CommonDomainType.TR_ACTION,
description=f'MUST be "{CommonDomainType.TR_ACTION}".'
CommonDomainType.TR_ACTION.value,
description=f'MUST be "{CommonDomainType.TR_ACTION.value}".'
)


class TherapeuticAgent(DomainEntity):
"""An administered therapeutic agent that is intended to alter or stop a pathologic process."""

type: Literal["TherapeuticAgent"] = Field(
CommonDomainType.TR_AGENT,
description=f'MUST be "{CommonDomainType.TR_AGENT}".'
CommonDomainType.TR_AGENT.value,
description=f'MUST be "{CommonDomainType.TR_AGENT.value}".'
)


class TherapeuticSubstituteGroup(DomainEntity):
"""A group of therapeutic procedures that may be treated as substitutes for one another."""

type: Literal["TherapeuticSubstituteGroup"] = Field(
CommonDomainType.TR_SUB,
description=f'MUST be "{CommonDomainType.TR_SUB}".'
CommonDomainType.TR_SUB.value,
description=f'MUST be "{CommonDomainType.TR_SUB.value}".'
)
substitutes: List[Union[TherapeuticAction, TherapeuticAgent]] = Field(
...,
Expand All @@ -110,8 +111,8 @@ class CombinationTherapy(DomainEntity):
"""

type: Literal["CombinationTherapy"] = Field(
CommonDomainType.TR_COMB,
description=f'MUST be "{CommonDomainType.TR_COMB}".'
CommonDomainType.TR_COMB.value,
description=f'MUST be "{CommonDomainType.TR_COMB.value}".'
)
components: List[Union[TherapeuticSubstituteGroup, TherapeuticAction, TherapeuticAgent]] = Field(
...,
Expand All @@ -136,6 +137,6 @@ class Gene(DomainEntity):
"""A basic physical and functional unit of heredity."""

type: Literal["Gene"] = Field(
CommonDomainType.GENE,
description=f'MUST be "{CommonDomainType.GENE}".'
CommonDomainType.GENE.value,
description=f'MUST be "{CommonDomainType.GENE.value}".'
)
26 changes: 13 additions & 13 deletions src/ga4gh/vrs/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def pydantic_class_refatt_map():
class_keys)


class VrsType:
class VrsType(str, Enum):
"""Define VRS Types"""

LEN_EXPR = "LengthExpression"
Expand Down Expand Up @@ -370,7 +370,7 @@ class LengthExpression(_ValueObject):
"""A sequence expressed only by its length."""

type: Literal["LengthExpression"] = Field(
VrsType.LEN_EXPR, description=f'MUST be "{VrsType.LEN_EXPR}"'
VrsType.LEN_EXPR.value, description=f'MUST be "{VrsType.LEN_EXPR.value}"'
)
length: Optional[Union[Range, int]] = None

Expand All @@ -385,7 +385,7 @@ class ReferenceLengthExpression(_ValueObject):
"""An expression of a length of a sequence from a repeating reference."""

type: Literal["ReferenceLengthExpression"] = Field(
VrsType.REF_LEN_EXPR, description=f'MUST be "{VrsType.REF_LEN_EXPR}"'
VrsType.REF_LEN_EXPR.value, description=f'MUST be "{VrsType.REF_LEN_EXPR.value}"'
)
length: Union[Range, int] = Field(
..., description='The number of residues of the expressed sequence.'
Expand All @@ -409,7 +409,7 @@ class LiteralSequenceExpression(_ValueObject):
"""An explicit expression of a Sequence."""

type: Literal["LiteralSequenceExpression"] = Field(
VrsType.LIT_SEQ_EXPR, description=f'MUST be "{VrsType.LIT_SEQ_EXPR}"'
VrsType.LIT_SEQ_EXPR.value, description=f'MUST be "{VrsType.LIT_SEQ_EXPR.value}"'
)
sequence: SequenceString = Field(..., description='the literal sequence')

Expand All @@ -430,7 +430,7 @@ class SequenceReference(_ValueObject):

model_config = ConfigDict(use_enum_values=True)

type: Literal["SequenceReference"] = Field(VrsType.SEQ_REF, description=f'MUST be "{VrsType.SEQ_REF}"')
type: Literal["SequenceReference"] = Field(VrsType.SEQ_REF.value, description=f'MUST be "{VrsType.SEQ_REF.value}"')
refgetAccession: Annotated[str, StringConstraints(pattern=r'^SQ.[0-9A-Za-z_\-]{32}$')] = Field(
...,
description='A `GA4GH RefGet <http://samtools.github.io/hts-specs/refget.html>` identifier for the referenced sequence, using the sha512t24u digest.',
Expand All @@ -448,7 +448,7 @@ class ga4gh(_ValueObject.ga4gh):
class SequenceLocation(_Ga4ghIdentifiableObject):
"""A `Location` defined by an interval on a referenced `Sequence`."""

type: Literal["SequenceLocation"] = Field(VrsType.SEQ_LOC, description=f'MUST be "{VrsType.SEQ_LOC}"')
type: Literal["SequenceLocation"] = Field(VrsType.SEQ_LOC.value, description=f'MUST be "{VrsType.SEQ_LOC.value}"')
sequenceReference: Optional[Union[IRI, SequenceReference]] = Field(
None, description='A reference to a `Sequence` on which the location is defined.'
)
Expand Down Expand Up @@ -529,7 +529,7 @@ class _VariationBase(_Ga4ghIdentifiableObject, ABC):
class Allele(_VariationBase):
"""The state of a molecule at a `Location`."""

type: Literal["Allele"] = Field(VrsType.ALLELE, description=f'MUST be "{VrsType.ALLELE}"')
type: Literal["Allele"] = Field(VrsType.ALLELE.value, description=f'MUST be "{VrsType.ALLELE.value}"')
location: Union[IRI, SequenceLocation] = Field(
..., description='The location of the Allele'
)
Expand Down Expand Up @@ -572,7 +572,7 @@ class ga4gh(_Ga4ghIdentifiableObject.ga4gh):
class CisPhasedBlock(_VariationBase):
"""An ordered set of co-occurring `Variation` on the same molecule."""

type: Literal["CisPhasedBlock"] = Field(VrsType.CIS_PHASED_BLOCK, description=f'MUST be "{VrsType.CIS_PHASED_BLOCK}"')
type: Literal["CisPhasedBlock"] = Field(VrsType.CIS_PHASED_BLOCK.value, description=f'MUST be "{VrsType.CIS_PHASED_BLOCK.value}"')
members: List[Union[Allele, IRI]] = Field(
...,
description='A list of `Alleles` that are found in-cis on a shared molecule.',
Expand Down Expand Up @@ -605,7 +605,7 @@ class Adjacency(_VariationBase):
potentially with an intervening linker sequence.
"""

type: Literal["Adjacency"] = Field(VrsType.ADJACENCY, description=f'MUST be "{VrsType.ADJACENCY}"')
type: Literal["Adjacency"] = Field(VrsType.ADJACENCY.value, description=f'MUST be "{VrsType.ADJACENCY.value}"')
adjoinedSequences: List[Union[IRI, SequenceLocation]] = Field(
...,
description="The terminal sequence or pair of adjoined sequences that defines in the adjacency.",
Expand Down Expand Up @@ -633,7 +633,7 @@ class SequenceTerminus(_VariationBase):
is not allowed and it removes the unnecessary array structure.
"""

type: Literal["SequenceTerminus"] = Field(VrsType.SEQ_TERMINUS, description=f'MUST be "{VrsType.SEQ_TERMINUS}"')
type: Literal["SequenceTerminus"] = Field(VrsType.SEQ_TERMINUS.value, description=f'MUST be "{VrsType.SEQ_TERMINUS.value}"')
location: Union[IRI, SequenceLocation] = Field(..., description="The location of the terminus.")

class ga4gh(_Ga4ghIdentifiableObject.ga4gh):
Expand All @@ -649,7 +649,7 @@ class DerivativeSequence(_VariationBase):
sequence composed from multiple sequence adjacencies.
"""

type: Literal["DerivativeSequence"] = Field(VrsType.DERIVATIVE_SEQ, description=f'MUST be "{VrsType.DERIVATIVE_SEQ}"')
type: Literal["DerivativeSequence"] = Field(VrsType.DERIVATIVE_SEQ.value, description=f'MUST be "{VrsType.DERIVATIVE_SEQ.value}"')
components: List[Union[IRI, Adjacency, Allele, SequenceTerminus, CisPhasedBlock]] = Field(
...,
description="The sequence components that make up the derivative sequence.",
Expand Down Expand Up @@ -683,7 +683,7 @@ class CopyNumberCount(_CopyNumber):
(e.g. genome, cell, etc.).
"""

type: Literal["CopyNumberCount"] = Field(VrsType.CN_COUNT, description=f'MUST be "{VrsType.CN_COUNT}"')
type: Literal["CopyNumberCount"] = Field(VrsType.CN_COUNT.value, description=f'MUST be "{VrsType.CN_COUNT.value}"')
copies: Union[Range, int] = Field(
..., description='The integral number of copies of the subject in a system'
)
Expand All @@ -704,7 +704,7 @@ class CopyNumberChange(_CopyNumber):

model_config = ConfigDict(use_enum_values=True)

type: Literal["CopyNumberChange"] = Field(VrsType.CN_CHANGE, description=f'MUST be "{VrsType.CN_CHANGE}"')
type: Literal["CopyNumberChange"] = Field(VrsType.CN_CHANGE.value, description=f'MUST be "{VrsType.CN_CHANGE.value}"')
copyChange: CopyChange = Field(
...,
description='MUST be one of "efo:0030069" (complete genomic loss), "efo:0020073" (high-level loss), "efo:0030068" (low-level loss), "efo:0030067" (loss), "efo:0030064" (regional base ploidy), "efo:0030070" (gain), "efo:0030071" (low-level gain), "efo:0030072" (high-level gain).',
Expand Down
30 changes: 15 additions & 15 deletions tests/validation/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,18 +101,18 @@ def test_prev_vrs_version():

def test_valid_types():
"""Ensure that type enums values correct. Values should correspond to class"""
for gks_models, gks_type in [(models, VrsType), (domain_models, CommonDomainType)]:
for attr, value in gks_type.__dict__.items():
if not attr.startswith("__"):
if hasattr(gks_models, value):
gks_class = getattr(gks_models, value)
try:
assert gks_class(type=value)
except ValidationError as e:
found_type_mismatch = False
for error in e.errors():
if error["loc"] == ("type",):
found_type_mismatch = True
assert not found_type_mismatch, f"Found mismatch in type literal: {value} vs {error['ctx']['expected']}"
else:
assert False, f"{str(gks_models)} class not found: {value}"
for gks_models, gks_enum in [(models, VrsType), (domain_models, CommonDomainType)]:
for enum_val in gks_enum.__members__.values():
enum_val = enum_val.value
if hasattr(gks_models, enum_val):
gks_class = getattr(gks_models, enum_val)
try:
assert gks_class(type=enum_val)
except ValidationError as e:
found_type_mismatch = False
for error in e.errors():
if error["loc"] == ("type",):
found_type_mismatch = True
assert not found_type_mismatch, f"Found mismatch in type literal: {enum_val} vs {error['ctx']['expected']}"
else:
assert False, f"{str(gks_models)} class not found: {enum_val}"

0 comments on commit ba7e6c0

Please sign in to comment.