From 9dc2d4856637d6f24f937f6859592bf8fe9fa06d Mon Sep 17 00:00:00 2001 From: Jasper Hoving Date: Wed, 15 Jan 2025 17:23:30 +0100 Subject: [PATCH] Fix enum types pydantic 2.7 --- .../image/explicit_content/category.py | 4 +-- .../explicit_content_dataclass.py | 7 ++-- .../anonymization/anonymization_dataclass.py | 32 +++++++++++++++---- .../features/text/anonymization/category.py | 8 ++--- .../features/text/moderation/category.py | 8 ++--- .../text/moderation/moderation_dataclass.py | 28 +++++++++++----- .../features/text/moderation/subcategory.py | 11 ++++++- .../test_text_moderation_dataclass.py | 26 +++++++++++++-- edenai_apis/tests/test_providers_errors.py | 2 +- edenai_apis/utils/combine_enums.py | 8 +++++ 10 files changed, 99 insertions(+), 35 deletions(-) create mode 100644 edenai_apis/utils/combine_enums.py diff --git a/edenai_apis/features/image/explicit_content/category.py b/edenai_apis/features/image/explicit_content/category.py index 077c741e..5673ef3e 100644 --- a/edenai_apis/features/image/explicit_content/category.py +++ b/edenai_apis/features/image/explicit_content/category.py @@ -19,7 +19,7 @@ class ACategoryType(str): pass -class CategoryType(ACategoryType, Enum): +class CategoryType(str, Enum): """This enum are used to categorize the explicit content extracted from the text""" Toxic = "Toxic" @@ -37,7 +37,7 @@ def list_available_type(cls): return [category for category in cls] @classmethod - def list_choices(cls) -> Dict["ACategoryType", SubCategoryBase]: + def list_choices(cls) -> Dict[str, SubCategoryBase]: return { cls.Toxic: ToxicSubCategoryType, cls.Content: ContentSubCategoryType, diff --git a/edenai_apis/features/image/explicit_content/explicit_content_dataclass.py b/edenai_apis/features/image/explicit_content/explicit_content_dataclass.py index 21e0111f..e72af33b 100644 --- a/edenai_apis/features/image/explicit_content/explicit_content_dataclass.py +++ b/edenai_apis/features/image/explicit_content/explicit_content_dataclass.py @@ -22,8 +22,11 @@ SafeSubCategoryType, OtherSubCategoryType, ) +from edenai_apis.utils.combine_enums import combine_enums -SubCategoryType = Union[ + +SubCategoryType = combine_enums( + "SubCategoryType", ToxicSubCategoryType, ContentSubCategoryType, SexualSubCategoryType, @@ -33,7 +36,7 @@ HateAndExtremismSubCategoryType, SafeSubCategoryType, OtherSubCategoryType, -] +) class ExplicitItem(BaseModel): diff --git a/edenai_apis/features/text/anonymization/anonymization_dataclass.py b/edenai_apis/features/text/anonymization/anonymization_dataclass.py index 1b313c6f..81c3e77e 100644 --- a/edenai_apis/features/text/anonymization/anonymization_dataclass.py +++ b/edenai_apis/features/text/anonymization/anonymization_dataclass.py @@ -1,5 +1,13 @@ from typing import Dict, Optional, Sequence, Union -from pydantic import BaseModel, ConfigDict, Field, FieldSerializationInfo, field_serializer, field_validator, model_validator +from pydantic import ( + BaseModel, + ConfigDict, + Field, + FieldSerializationInfo, + field_serializer, + field_validator, + model_validator, +) from edenai_apis.features.text.anonymization.category import ( CategoryType, @@ -16,8 +24,15 @@ ) from typing import Dict, Optional, Sequence, Union -from pydantic import BaseModel, ConfigDict, Field, FieldSerializationInfo, field_serializer, field_validator, \ - model_validator +from pydantic import ( + BaseModel, + ConfigDict, + Field, + FieldSerializationInfo, + field_serializer, + field_validator, + model_validator, +) from edenai_apis.features.text.anonymization.category import ( CategoryType, @@ -32,8 +47,11 @@ LocationInformationSubCategoryType, OtherSubCategoryType, ) +from edenai_apis.utils.combine_enums import combine_enums -SubCategoryType = Union[ + +SubCategoryType = combine_enums( + "SubCategoryType", FinancialInformationSubCategoryType, PersonalInformationSubCategoryType, IdentificationNumbersSubCategoryType, @@ -42,7 +60,7 @@ DateAndTimeSubCategoryType, LocationInformationSubCategoryType, OtherSubCategoryType, -] +) class AnonymizationEntity(BaseModel): @@ -94,9 +112,9 @@ def round_confidence_score(cls, v): return round(v, 3) return v - @field_serializer('subcategory', mode="plain", when_used="always") + @field_serializer("subcategory", mode="plain", when_used="always") def serialize_subcategory(self, value: SubCategoryType, _: FieldSerializationInfo): - return getattr(value, 'value', None) + return getattr(value, "value", None) class AnonymizationDataClass(BaseModel): diff --git a/edenai_apis/features/text/anonymization/category.py b/edenai_apis/features/text/anonymization/category.py index 6894b778..a6969e06 100644 --- a/edenai_apis/features/text/anonymization/category.py +++ b/edenai_apis/features/text/anonymization/category.py @@ -14,11 +14,7 @@ ) -class ACategoryType(str): - pass - - -class CategoryType(ACategoryType, Enum): +class CategoryType(str, Enum): """This enum are used to categorize the entities extracted from the text.""" PersonalInformation = "PersonalInformation" @@ -35,7 +31,7 @@ def list_available_type(cls): return [category for category in cls] @classmethod - def list_choices(cls) -> Dict["ACategoryType", SubCategoryBase]: + def list_choices(cls) -> Dict[str, SubCategoryBase]: return { cls.PersonalInformation: PersonalInformationSubCategoryType, cls.FinancialInformation: FinancialInformationSubCategoryType, diff --git a/edenai_apis/features/text/moderation/category.py b/edenai_apis/features/text/moderation/category.py index 5c0a77df..feb8345c 100644 --- a/edenai_apis/features/text/moderation/category.py +++ b/edenai_apis/features/text/moderation/category.py @@ -15,11 +15,7 @@ ) -class ACategoryType(str): - pass - - -class CategoryType(ACategoryType, Enum): +class CategoryType(str, Enum): """This enum are used to categorize the explicit content extracted from the text""" Toxic = "Toxic" @@ -37,7 +33,7 @@ def list_available_type(cls): return [category for category in cls] @classmethod - def list_choices(cls) -> Dict["ACategoryType", SubCategoryBase]: + def list_choices(cls) -> Dict[str, SubCategoryBase]: return { cls.Toxic: ToxicSubCategoryType, cls.Content: ContentSubCategoryType, diff --git a/edenai_apis/features/text/moderation/moderation_dataclass.py b/edenai_apis/features/text/moderation/moderation_dataclass.py index e70150e5..7a468d7c 100644 --- a/edenai_apis/features/text/moderation/moderation_dataclass.py +++ b/edenai_apis/features/text/moderation/moderation_dataclass.py @@ -1,7 +1,15 @@ from enum import Enum -from typing import Sequence, Union +from typing import Sequence, Union, Type -from pydantic import BaseModel, ConfigDict, Field, StrictStr, field_validator, FieldSerializationInfo, field_serializer +from pydantic import ( + BaseModel, + ConfigDict, + Field, + StrictStr, + field_validator, + FieldSerializationInfo, + field_serializer, +) from edenai_apis.features.text.moderation.category import ( CategoryType, @@ -17,8 +25,10 @@ SafeSubCategoryType, OtherSubCategoryType, ) +from edenai_apis.utils.combine_enums import combine_enums -SubCategoryType = Union[ +SubCategoryType = combine_enums( + "SubCategoryType", ToxicSubCategoryType, ContentSubCategoryType, SexualSubCategoryType, @@ -28,7 +38,8 @@ HateAndExtremismSubCategoryType, SafeSubCategoryType, OtherSubCategoryType, -] +) + class TextModerationCategoriesMicrosoftEnum(Enum): Category1 = "sexually explicit" Category2 = "sexually suggestive" @@ -43,14 +54,16 @@ class TextModerationItem(BaseModel): likelihood_score: float model_config = ConfigDict(use_enum_values=True) - @field_serializer('subcategory', mode="plain", when_used="always") + + @field_serializer("subcategory", mode="plain", when_used="always") def serialize_subcategory(self, value: SubCategoryType, _: FieldSerializationInfo): - return getattr(value, 'value', None) + return getattr(value, "value", None) + class ModerationDataClass(BaseModel): nsfw_likelihood: int items: Sequence[TextModerationItem] = Field(default_factory=list) - nsfw_likelihood_score : float + nsfw_likelihood_score: float @field_validator("nsfw_likelihood") @classmethod @@ -59,7 +72,6 @@ def check_min_max(cls, value): raise ValueError("Likelihood walue should be between 0 and 5") return value - @staticmethod def calculate_nsfw_likelihood(items: Sequence[TextModerationItem]): if len(items) == 0: diff --git a/edenai_apis/features/text/moderation/subcategory.py b/edenai_apis/features/text/moderation/subcategory.py index c8862df4..168e5a1a 100644 --- a/edenai_apis/features/text/moderation/subcategory.py +++ b/edenai_apis/features/text/moderation/subcategory.py @@ -38,6 +38,7 @@ def choose_label(cls, label: str) -> "SubCategoryBase": f"Unknown label {label}. Only {cls.list_choices().values()} are allowed." ) + class ToxicSubCategoryType(SubCategoryBase, Enum): Insult = "Insult" Obscene = "Obscene" @@ -57,6 +58,7 @@ def list_choices(cls) -> Dict["SubCategoryBase", List[str]]: cls.Toxic: SubCategoryPattern.Toxic.TOXIC, } + class ContentSubCategoryType(SubCategoryBase, Enum): MiddleFinger = "MiddleFinger" PublicSafety = "PublicSafety" @@ -80,6 +82,7 @@ def list_choices(cls) -> Dict["SubCategoryBase", List[str]]: cls.Legal: SubCategoryPattern.Content.LEGAL, } + class SexualSubCategoryType(SubCategoryBase, Enum): SexualActivity = "SexualActivity" SexualSituations = "SexualSituations" @@ -103,6 +106,7 @@ def list_choices(cls) -> Dict["SubCategoryBase", List[str]]: cls.Sexual: SubCategoryPattern.Sexual.SEXUAL, } + class ViolenceSubCategoryType(SubCategoryBase, Enum): GraphicViolenceOrGore = "GraphicViolenceOrGore" PhysicalViolence = "PhysicalViolence" @@ -118,6 +122,7 @@ def list_choices(cls) -> Dict["SubCategoryBase", List[str]]: cls.Violence: SubCategoryPattern.Violence.VIOLENCE, } + class DrugAndAlcoholSubCategoryType(SubCategoryBase, Enum): DrugProducts = "DrugProducts" DrugUse = "DrugUse" @@ -137,6 +142,7 @@ def list_choices(cls) -> Dict["SubCategoryBase", List[str]]: cls.Drinking: SubCategoryPattern.DrugAndAlcohol.DRINKING, } + class FinanceSubCategoryType(SubCategoryBase, Enum): Gambling = "Gambling" Finance = "Finance" @@ -150,6 +156,7 @@ def list_choices(cls) -> Dict["SubCategoryBase", List[str]]: cls.MoneyContent: SubCategoryPattern.Finance.MONEY_CONTENT, } + class HateAndExtremismSubCategoryType(SubCategoryBase, Enum): Hate = "Hate" Harassment = "Harassment" @@ -167,6 +174,7 @@ def list_choices(cls) -> Dict["SubCategoryBase", List[str]]: cls.Racy: SubCategoryPattern.HateAndExtremism.RACY, } + class SafeSubCategoryType(SubCategoryBase, Enum): Safe = "Safe" NotSafe = "NotSafe" @@ -178,6 +186,7 @@ def list_choices(cls) -> Dict["SubCategoryBase", List[str]]: cls.NotSafe: SubCategoryPattern.Safe.NOT_SAFE, } + class OtherSubCategoryType(SubCategoryBase, Enum): Spoof = "Spoof" Religion = "Religion" @@ -191,4 +200,4 @@ def list_choices(cls) -> Dict["SubCategoryBase", List[str]]: cls.Religion: SubCategoryPattern.Other.RELIGION, cls.Offensive: SubCategoryPattern.Other.OFFENSIVE, cls.Other: SubCategoryPattern.Other.OTHER, - } \ No newline at end of file + } diff --git a/edenai_apis/tests/features/dataclasses/test_text_moderation_dataclass.py b/edenai_apis/tests/features/dataclasses/test_text_moderation_dataclass.py index ebf39370..1796c1d4 100644 --- a/edenai_apis/tests/features/dataclasses/test_text_moderation_dataclass.py +++ b/edenai_apis/tests/features/dataclasses/test_text_moderation_dataclass.py @@ -43,8 +43,30 @@ class TestTextModeration: ), ], ) - def test_valid_value_check_min_mac_nsfw(self, nsfw_likelihood, nsfw_likelihood_score): + def test_valid_value_check_min_mac_nsfw( + self, nsfw_likelihood, nsfw_likelihood_score + ): try: - ModerationDataClass(nsfw_likelihood=nsfw_likelihood, items=[], nsfw_likelihood_score=nsfw_likelihood_score) + ModerationDataClass( + nsfw_likelihood=nsfw_likelihood, + items=[], + nsfw_likelihood_score=nsfw_likelihood_score, + ) except ValueError: pytest.fail(f"{nsfw_likelihood} value doesn't raises a ValueError") + + def test_text_moderation_items(self): + + ModerationDataClass( + nsfw_likelihood=0, + items=[ + { + "label": "hi", + "category": "Toxic", + "likelihood": 1, + "subcategory": "Toxic", + "likelihood_score": 0.5, + } + ], + nsfw_likelihood_score=0.0, + ) diff --git a/edenai_apis/tests/test_providers_errors.py b/edenai_apis/tests/test_providers_errors.py index adcbba5a..e7bf58c7 100644 --- a/edenai_apis/tests/test_providers_errors.py +++ b/edenai_apis/tests/test_providers_errors.py @@ -6,7 +6,7 @@ import pytest from apis.amazon.errors import ERRORS as amazon_errors from apis.google.errors import ERRORS as google_errors -from apis.ibm.errors import ERRORS as ibm_errors + from apis.microsoft.errors import ERRORS as microsoft_errors from features.audio.speech_to_text_async.speech_to_text_async_args import ( data_path as audio_data_path, diff --git a/edenai_apis/utils/combine_enums.py b/edenai_apis/utils/combine_enums.py new file mode 100644 index 00000000..b20724d9 --- /dev/null +++ b/edenai_apis/utils/combine_enums.py @@ -0,0 +1,8 @@ +from typing import Type +from enum import Enum + +def combine_enums(name: str, *enum_classes: Type[Enum]) -> Enum: + combined = {} + for enum_class in enum_classes: + combined.update(enum_class.__members__) + return Enum(name, combined)