diff --git a/src/ga4gh/core/domain_models.py b/src/ga4gh/core/domain_models.py index ee29b568..853da78b 100644 --- a/src/ga4gh/core/domain_models.py +++ b/src/ga4gh/core/domain_models.py @@ -10,6 +10,7 @@ * `import ga4gh.core`, and refer to models using the fully-qualified module name, e.g., `ga4gh.core.domain_models.Gene` """ +from enum import Enum from typing import Literal, Union, List from pydantic import Field, RootModel @@ -17,7 +18,7 @@ from ga4gh.core.entity_models import DomainEntity -class CommonDomainType: +class CommonDomainType(str, Enum): """Define GKS Common Domain Entity types""" PHENOTYPE = "Phenotype" @@ -33,8 +34,8 @@ class Phenotype(DomainEntity): """An observable characteristic or trait of an organism.""" type: Literal["Phenotype"] = Field( - CommonDomainType.PHENOTYPE, - description=f'MUST be "{CommonDomainType.PHENOTYPE}".' + CommonDomainType.PHENOTYPE.value, + description=f'MUST be "{CommonDomainType.PHENOTYPE.value}".' ) @@ -44,8 +45,8 @@ class Disease(DomainEntity): """ type: Literal["Disease"] = Field( - CommonDomainType.DISEASE, - description=f'MUST be "{CommonDomainType.DISEASE}".' + CommonDomainType.DISEASE.value, + description=f'MUST be "{CommonDomainType.DISEASE.value}".' ) @@ -53,8 +54,8 @@ class TraitSet(DomainEntity): """A set of phenotype and/or disease concepts that together constitute a condition.""" type: Literal["TraitSet"] = Field( - CommonDomainType.TRAIT_SET, - description=f'MUST be "{CommonDomainType.TRAIT_SET}".' + CommonDomainType.TRAIT_SET.value, + description=f'MUST be "{CommonDomainType.TRAIT_SET.value}".' ) traits: List[Union[Disease, Phenotype]] = Field( ..., @@ -76,8 +77,8 @@ class TherapeuticAction(DomainEntity): """A therapeutic action taken that is intended to alter or stop a pathologic process.""" type: Literal["TherapeuticAction"] = Field( - CommonDomainType.TR_ACTION, - description=f'MUST be "{CommonDomainType.TR_ACTION}".' + CommonDomainType.TR_ACTION.value, + description=f'MUST be "{CommonDomainType.TR_ACTION.value}".' ) @@ -85,8 +86,8 @@ class TherapeuticAgent(DomainEntity): """An administered therapeutic agent that is intended to alter or stop a pathologic process.""" type: Literal["TherapeuticAgent"] = Field( - CommonDomainType.TR_AGENT, - description=f'MUST be "{CommonDomainType.TR_AGENT}".' + CommonDomainType.TR_AGENT.value, + description=f'MUST be "{CommonDomainType.TR_AGENT.value}".' ) @@ -94,8 +95,8 @@ class TherapeuticSubstituteGroup(DomainEntity): """A group of therapeutic procedures that may be treated as substitutes for one another.""" type: Literal["TherapeuticSubstituteGroup"] = Field( - CommonDomainType.TR_SUB, - description=f'MUST be "{CommonDomainType.TR_SUB}".' + CommonDomainType.TR_SUB.value, + description=f'MUST be "{CommonDomainType.TR_SUB.value}".' ) substitutes: List[Union[TherapeuticAction, TherapeuticAgent]] = Field( ..., @@ -110,8 +111,8 @@ class CombinationTherapy(DomainEntity): """ type: Literal["CombinationTherapy"] = Field( - CommonDomainType.TR_COMB, - description=f'MUST be "{CommonDomainType.TR_COMB}".' + CommonDomainType.TR_COMB.value, + description=f'MUST be "{CommonDomainType.TR_COMB.value}".' ) components: List[Union[TherapeuticSubstituteGroup, TherapeuticAction, TherapeuticAgent]] = Field( ..., @@ -136,6 +137,6 @@ class Gene(DomainEntity): """A basic physical and functional unit of heredity.""" type: Literal["Gene"] = Field( - CommonDomainType.GENE, - description=f'MUST be "{CommonDomainType.GENE}".' + CommonDomainType.GENE.value, + description=f'MUST be "{CommonDomainType.GENE.value}".' ) diff --git a/src/ga4gh/vrs/models.py b/src/ga4gh/vrs/models.py index aa933157..e787694b 100644 --- a/src/ga4gh/vrs/models.py +++ b/src/ga4gh/vrs/models.py @@ -132,7 +132,7 @@ def pydantic_class_refatt_map(): class_keys) -class VrsType: +class VrsType(str, Enum): """Define VRS Types""" LEN_EXPR = "LengthExpression" @@ -370,7 +370,7 @@ class LengthExpression(_ValueObject): """A sequence expressed only by its length.""" type: Literal["LengthExpression"] = Field( - VrsType.LEN_EXPR, description=f'MUST be "{VrsType.LEN_EXPR}"' + VrsType.LEN_EXPR.value, description=f'MUST be "{VrsType.LEN_EXPR.value}"' ) length: Optional[Union[Range, int]] = None @@ -385,7 +385,7 @@ class ReferenceLengthExpression(_ValueObject): """An expression of a length of a sequence from a repeating reference.""" type: Literal["ReferenceLengthExpression"] = Field( - VrsType.REF_LEN_EXPR, description=f'MUST be "{VrsType.REF_LEN_EXPR}"' + VrsType.REF_LEN_EXPR.value, description=f'MUST be "{VrsType.REF_LEN_EXPR.value}"' ) length: Union[Range, int] = Field( ..., description='The number of residues of the expressed sequence.' @@ -409,7 +409,7 @@ class LiteralSequenceExpression(_ValueObject): """An explicit expression of a Sequence.""" type: Literal["LiteralSequenceExpression"] = Field( - VrsType.LIT_SEQ_EXPR, description=f'MUST be "{VrsType.LIT_SEQ_EXPR}"' + VrsType.LIT_SEQ_EXPR.value, description=f'MUST be "{VrsType.LIT_SEQ_EXPR.value}"' ) sequence: SequenceString = Field(..., description='the literal sequence') @@ -430,7 +430,7 @@ class SequenceReference(_ValueObject): model_config = ConfigDict(use_enum_values=True) - type: Literal["SequenceReference"] = Field(VrsType.SEQ_REF, description=f'MUST be "{VrsType.SEQ_REF}"') + type: Literal["SequenceReference"] = Field(VrsType.SEQ_REF.value, description=f'MUST be "{VrsType.SEQ_REF.value}"') refgetAccession: Annotated[str, StringConstraints(pattern=r'^SQ.[0-9A-Za-z_\-]{32}$')] = Field( ..., description='A `GA4GH RefGet ` identifier for the referenced sequence, using the sha512t24u digest.', @@ -448,7 +448,7 @@ class ga4gh(_ValueObject.ga4gh): class SequenceLocation(_Ga4ghIdentifiableObject): """A `Location` defined by an interval on a referenced `Sequence`.""" - type: Literal["SequenceLocation"] = Field(VrsType.SEQ_LOC, description=f'MUST be "{VrsType.SEQ_LOC}"') + type: Literal["SequenceLocation"] = Field(VrsType.SEQ_LOC.value, description=f'MUST be "{VrsType.SEQ_LOC.value}"') sequenceReference: Optional[Union[IRI, SequenceReference]] = Field( None, description='A reference to a `Sequence` on which the location is defined.' ) @@ -529,7 +529,7 @@ class _VariationBase(_Ga4ghIdentifiableObject, ABC): class Allele(_VariationBase): """The state of a molecule at a `Location`.""" - type: Literal["Allele"] = Field(VrsType.ALLELE, description=f'MUST be "{VrsType.ALLELE}"') + type: Literal["Allele"] = Field(VrsType.ALLELE.value, description=f'MUST be "{VrsType.ALLELE.value}"') location: Union[IRI, SequenceLocation] = Field( ..., description='The location of the Allele' ) @@ -572,7 +572,7 @@ class ga4gh(_Ga4ghIdentifiableObject.ga4gh): class CisPhasedBlock(_VariationBase): """An ordered set of co-occurring `Variation` on the same molecule.""" - type: Literal["CisPhasedBlock"] = Field(VrsType.CIS_PHASED_BLOCK, description=f'MUST be "{VrsType.CIS_PHASED_BLOCK}"') + type: Literal["CisPhasedBlock"] = Field(VrsType.CIS_PHASED_BLOCK.value, description=f'MUST be "{VrsType.CIS_PHASED_BLOCK.value}"') members: List[Union[Allele, IRI]] = Field( ..., description='A list of `Alleles` that are found in-cis on a shared molecule.', @@ -605,7 +605,7 @@ class Adjacency(_VariationBase): potentially with an intervening linker sequence. """ - type: Literal["Adjacency"] = Field(VrsType.ADJACENCY, description=f'MUST be "{VrsType.ADJACENCY}"') + type: Literal["Adjacency"] = Field(VrsType.ADJACENCY.value, description=f'MUST be "{VrsType.ADJACENCY.value}"') adjoinedSequences: List[Union[IRI, SequenceLocation]] = Field( ..., description="The terminal sequence or pair of adjoined sequences that defines in the adjacency.", @@ -633,7 +633,7 @@ class SequenceTerminus(_VariationBase): is not allowed and it removes the unnecessary array structure. """ - type: Literal["SequenceTerminus"] = Field(VrsType.SEQ_TERMINUS, description=f'MUST be "{VrsType.SEQ_TERMINUS}"') + type: Literal["SequenceTerminus"] = Field(VrsType.SEQ_TERMINUS.value, description=f'MUST be "{VrsType.SEQ_TERMINUS.value}"') location: Union[IRI, SequenceLocation] = Field(..., description="The location of the terminus.") class ga4gh(_Ga4ghIdentifiableObject.ga4gh): @@ -649,7 +649,7 @@ class DerivativeSequence(_VariationBase): sequence composed from multiple sequence adjacencies. """ - type: Literal["DerivativeSequence"] = Field(VrsType.DERIVATIVE_SEQ, description=f'MUST be "{VrsType.DERIVATIVE_SEQ}"') + type: Literal["DerivativeSequence"] = Field(VrsType.DERIVATIVE_SEQ.value, description=f'MUST be "{VrsType.DERIVATIVE_SEQ.value}"') components: List[Union[IRI, Adjacency, Allele, SequenceTerminus, CisPhasedBlock]] = Field( ..., description="The sequence components that make up the derivative sequence.", @@ -683,7 +683,7 @@ class CopyNumberCount(_CopyNumber): (e.g. genome, cell, etc.). """ - type: Literal["CopyNumberCount"] = Field(VrsType.CN_COUNT, description=f'MUST be "{VrsType.CN_COUNT}"') + type: Literal["CopyNumberCount"] = Field(VrsType.CN_COUNT.value, description=f'MUST be "{VrsType.CN_COUNT.value}"') copies: Union[Range, int] = Field( ..., description='The integral number of copies of the subject in a system' ) @@ -704,7 +704,7 @@ class CopyNumberChange(_CopyNumber): model_config = ConfigDict(use_enum_values=True) - type: Literal["CopyNumberChange"] = Field(VrsType.CN_CHANGE, description=f'MUST be "{VrsType.CN_CHANGE}"') + type: Literal["CopyNumberChange"] = Field(VrsType.CN_CHANGE.value, description=f'MUST be "{VrsType.CN_CHANGE.value}"') copyChange: CopyChange = Field( ..., description='MUST be one of "efo:0030069" (complete genomic loss), "efo:0020073" (high-level loss), "efo:0030068" (low-level loss), "efo:0030067" (loss), "efo:0030064" (regional base ploidy), "efo:0030070" (gain), "efo:0030071" (low-level gain), "efo:0030072" (high-level gain).', diff --git a/tests/validation/test_models.py b/tests/validation/test_models.py index e86ebbe9..f09f5db8 100644 --- a/tests/validation/test_models.py +++ b/tests/validation/test_models.py @@ -101,18 +101,18 @@ def test_prev_vrs_version(): def test_valid_types(): """Ensure that type enums values correct. Values should correspond to class""" - for gks_models, gks_type in [(models, VrsType), (domain_models, CommonDomainType)]: - for attr, value in gks_type.__dict__.items(): - if not attr.startswith("__"): - if hasattr(gks_models, value): - gks_class = getattr(gks_models, value) - try: - assert gks_class(type=value) - except ValidationError as e: - found_type_mismatch = False - for error in e.errors(): - if error["loc"] == ("type",): - found_type_mismatch = True - assert not found_type_mismatch, f"Found mismatch in type literal: {value} vs {error['ctx']['expected']}" - else: - assert False, f"{str(gks_models)} class not found: {value}" + for gks_models, gks_enum in [(models, VrsType), (domain_models, CommonDomainType)]: + for enum_val in gks_enum.__members__.values(): + enum_val = enum_val.value + if hasattr(gks_models, enum_val): + gks_class = getattr(gks_models, enum_val) + try: + assert gks_class(type=enum_val) + except ValidationError as e: + found_type_mismatch = False + for error in e.errors(): + if error["loc"] == ("type",): + found_type_mismatch = True + assert not found_type_mismatch, f"Found mismatch in type literal: {enum_val} vs {error['ctx']['expected']}" + else: + assert False, f"{str(gks_models)} class not found: {enum_val}"