diff --git a/edenai_apis/features/image/explicit_content/category.py b/edenai_apis/features/image/explicit_content/category.py
index 077c741e..5673ef3e 100644
--- a/edenai_apis/features/image/explicit_content/category.py
+++ b/edenai_apis/features/image/explicit_content/category.py
@@ -19,7 +19,7 @@ class ACategoryType(str):
     pass
 
 
-class CategoryType(ACategoryType, Enum):
+class CategoryType(str, Enum):
     """This enum are used to categorize the explicit content extracted from the text"""
 
     Toxic = "Toxic"
@@ -37,7 +37,7 @@ def list_available_type(cls):
         return [category for category in cls]
 
     @classmethod
-    def list_choices(cls) -> Dict["ACategoryType", SubCategoryBase]:
+    def list_choices(cls) -> Dict[str, SubCategoryBase]:
         return {
             cls.Toxic: ToxicSubCategoryType,
             cls.Content: ContentSubCategoryType,
diff --git a/edenai_apis/features/image/explicit_content/explicit_content_dataclass.py b/edenai_apis/features/image/explicit_content/explicit_content_dataclass.py
index 21e0111f..e72af33b 100644
--- a/edenai_apis/features/image/explicit_content/explicit_content_dataclass.py
+++ b/edenai_apis/features/image/explicit_content/explicit_content_dataclass.py
@@ -22,8 +22,11 @@
     SafeSubCategoryType,
     OtherSubCategoryType,
 )
+from edenai_apis.utils.combine_enums import combine_enums
 
-SubCategoryType = Union[
+
+SubCategoryType = combine_enums(
+    "SubCategoryType",
     ToxicSubCategoryType,
     ContentSubCategoryType,
     SexualSubCategoryType,
@@ -33,7 +36,7 @@
     HateAndExtremismSubCategoryType,
     SafeSubCategoryType,
     OtherSubCategoryType,
-]
+)
 
 
 class ExplicitItem(BaseModel):
diff --git a/edenai_apis/features/text/anonymization/anonymization_dataclass.py b/edenai_apis/features/text/anonymization/anonymization_dataclass.py
index 1b313c6f..81c3e77e 100644
--- a/edenai_apis/features/text/anonymization/anonymization_dataclass.py
+++ b/edenai_apis/features/text/anonymization/anonymization_dataclass.py
@@ -1,5 +1,13 @@
 from typing import Dict, Optional, Sequence, Union
-from pydantic import BaseModel, ConfigDict, Field, FieldSerializationInfo, field_serializer, field_validator, model_validator
+from pydantic import (
+    BaseModel,
+    ConfigDict,
+    Field,
+    FieldSerializationInfo,
+    field_serializer,
+    field_validator,
+    model_validator,
+)
 
 from edenai_apis.features.text.anonymization.category import (
     CategoryType,
@@ -16,8 +24,15 @@
 )
 
 from typing import Dict, Optional, Sequence, Union
-from pydantic import BaseModel, ConfigDict, Field, FieldSerializationInfo, field_serializer, field_validator, \
-    model_validator
+from pydantic import (
+    BaseModel,
+    ConfigDict,
+    Field,
+    FieldSerializationInfo,
+    field_serializer,
+    field_validator,
+    model_validator,
+)
 
 from edenai_apis.features.text.anonymization.category import (
     CategoryType,
@@ -32,8 +47,11 @@
     LocationInformationSubCategoryType,
     OtherSubCategoryType,
 )
+from edenai_apis.utils.combine_enums import combine_enums
 
-SubCategoryType = Union[
+
+SubCategoryType = combine_enums(
+    "SubCategoryType",
     FinancialInformationSubCategoryType,
     PersonalInformationSubCategoryType,
     IdentificationNumbersSubCategoryType,
@@ -42,7 +60,7 @@
     DateAndTimeSubCategoryType,
     LocationInformationSubCategoryType,
     OtherSubCategoryType,
-]
+)
 
 
 class AnonymizationEntity(BaseModel):
@@ -94,9 +112,9 @@ def round_confidence_score(cls, v):
             return round(v, 3)
         return v
 
-    @field_serializer('subcategory', mode="plain", when_used="always")
+    @field_serializer("subcategory", mode="plain", when_used="always")
     def serialize_subcategory(self, value: SubCategoryType, _: FieldSerializationInfo):
-        return getattr(value, 'value', None)
+        return getattr(value, "value", None)
 
 
 class AnonymizationDataClass(BaseModel):
diff --git a/edenai_apis/features/text/anonymization/category.py b/edenai_apis/features/text/anonymization/category.py
index 6894b778..a6969e06 100644
--- a/edenai_apis/features/text/anonymization/category.py
+++ b/edenai_apis/features/text/anonymization/category.py
@@ -14,11 +14,7 @@
 )
 
 
-class ACategoryType(str):
-    pass
-
-
-class CategoryType(ACategoryType, Enum):
+class CategoryType(str, Enum):
     """This enum are used to categorize the entities extracted from the text."""
 
     PersonalInformation = "PersonalInformation"
@@ -35,7 +31,7 @@ def list_available_type(cls):
         return [category for category in cls]
 
     @classmethod
-    def list_choices(cls) -> Dict["ACategoryType", SubCategoryBase]:
+    def list_choices(cls) -> Dict[str, SubCategoryBase]:
         return {
             cls.PersonalInformation: PersonalInformationSubCategoryType,
             cls.FinancialInformation: FinancialInformationSubCategoryType,
diff --git a/edenai_apis/features/text/moderation/category.py b/edenai_apis/features/text/moderation/category.py
index 5c0a77df..feb8345c 100644
--- a/edenai_apis/features/text/moderation/category.py
+++ b/edenai_apis/features/text/moderation/category.py
@@ -15,11 +15,7 @@
 )
 
 
-class ACategoryType(str):
-    pass
-
-
-class CategoryType(ACategoryType, Enum):
+class CategoryType(str, Enum):
     """This enum are used to categorize the explicit content extracted from the text"""
 
     Toxic = "Toxic"
@@ -37,7 +33,7 @@ def list_available_type(cls):
         return [category for category in cls]
 
     @classmethod
-    def list_choices(cls) -> Dict["ACategoryType", SubCategoryBase]:
+    def list_choices(cls) -> Dict[str, SubCategoryBase]:
         return {
             cls.Toxic: ToxicSubCategoryType,
             cls.Content: ContentSubCategoryType,
diff --git a/edenai_apis/features/text/moderation/moderation_dataclass.py b/edenai_apis/features/text/moderation/moderation_dataclass.py
index e70150e5..7a468d7c 100644
--- a/edenai_apis/features/text/moderation/moderation_dataclass.py
+++ b/edenai_apis/features/text/moderation/moderation_dataclass.py
@@ -1,7 +1,15 @@
 from enum import Enum
-from typing import Sequence, Union
+from typing import Sequence, Union, Type
 
-from pydantic import BaseModel, ConfigDict, Field, StrictStr, field_validator, FieldSerializationInfo, field_serializer
+from pydantic import (
+    BaseModel,
+    ConfigDict,
+    Field,
+    StrictStr,
+    field_validator,
+    FieldSerializationInfo,
+    field_serializer,
+)
 
 from edenai_apis.features.text.moderation.category import (
     CategoryType,
@@ -17,8 +25,10 @@
     SafeSubCategoryType,
     OtherSubCategoryType,
 )
+from edenai_apis.utils.combine_enums import combine_enums
 
-SubCategoryType = Union[
+SubCategoryType = combine_enums(
+    "SubCategoryType",
     ToxicSubCategoryType,
     ContentSubCategoryType,
     SexualSubCategoryType,
@@ -28,7 +38,8 @@
     HateAndExtremismSubCategoryType,
     SafeSubCategoryType,
     OtherSubCategoryType,
-]
+)
+
 class TextModerationCategoriesMicrosoftEnum(Enum):
     Category1 = "sexually explicit"
     Category2 = "sexually suggestive"
@@ -43,14 +54,16 @@ class TextModerationItem(BaseModel):
     likelihood_score: float
     model_config = ConfigDict(use_enum_values=True)
 
-    @field_serializer('subcategory', mode="plain", when_used="always")
+
+    @field_serializer("subcategory", mode="plain", when_used="always")
     def serialize_subcategory(self, value: SubCategoryType, _: FieldSerializationInfo):
-        return getattr(value, 'value', None)
+        return getattr(value, "value", None)
+
 
 class ModerationDataClass(BaseModel):
     nsfw_likelihood: int
     items: Sequence[TextModerationItem] = Field(default_factory=list)
-    nsfw_likelihood_score : float
+    nsfw_likelihood_score: float
 
     @field_validator("nsfw_likelihood")
     @classmethod
@@ -59,7 +72,6 @@ def check_min_max(cls, value):
             raise ValueError("Likelihood walue should be between 0 and 5")
         return value
 
-
     @staticmethod
     def calculate_nsfw_likelihood(items: Sequence[TextModerationItem]):
         if len(items) == 0:
diff --git a/edenai_apis/features/text/moderation/subcategory.py b/edenai_apis/features/text/moderation/subcategory.py
index c8862df4..168e5a1a 100644
--- a/edenai_apis/features/text/moderation/subcategory.py
+++ b/edenai_apis/features/text/moderation/subcategory.py
@@ -38,6 +38,7 @@ def choose_label(cls, label: str) -> "SubCategoryBase":
             f"Unknown label {label}. Only {cls.list_choices().values()} are allowed."
         )
 
+
 class ToxicSubCategoryType(SubCategoryBase, Enum):
     Insult = "Insult"
     Obscene = "Obscene"
@@ -57,6 +58,7 @@ def list_choices(cls) -> Dict["SubCategoryBase", List[str]]:
             cls.Toxic: SubCategoryPattern.Toxic.TOXIC,
         }
 
+
 class ContentSubCategoryType(SubCategoryBase, Enum):
     MiddleFinger = "MiddleFinger"
     PublicSafety = "PublicSafety"
@@ -80,6 +82,7 @@ def list_choices(cls) -> Dict["SubCategoryBase", List[str]]:
             cls.Legal: SubCategoryPattern.Content.LEGAL,
         }
 
+
 class SexualSubCategoryType(SubCategoryBase, Enum):
     SexualActivity = "SexualActivity"
     SexualSituations = "SexualSituations"
@@ -103,6 +106,7 @@ def list_choices(cls) -> Dict["SubCategoryBase", List[str]]:
             cls.Sexual: SubCategoryPattern.Sexual.SEXUAL,
         }
 
+
 class ViolenceSubCategoryType(SubCategoryBase, Enum):
     GraphicViolenceOrGore = "GraphicViolenceOrGore"
     PhysicalViolence = "PhysicalViolence"
@@ -118,6 +122,7 @@ def list_choices(cls) -> Dict["SubCategoryBase", List[str]]:
             cls.Violence: SubCategoryPattern.Violence.VIOLENCE,
         }
 
+
 class DrugAndAlcoholSubCategoryType(SubCategoryBase, Enum):
     DrugProducts = "DrugProducts"
     DrugUse = "DrugUse"
@@ -137,6 +142,7 @@ def list_choices(cls) -> Dict["SubCategoryBase", List[str]]:
             cls.Drinking: SubCategoryPattern.DrugAndAlcohol.DRINKING,
         }
 
+
 class FinanceSubCategoryType(SubCategoryBase, Enum):
     Gambling = "Gambling"
     Finance = "Finance"
@@ -150,6 +156,7 @@ def list_choices(cls) -> Dict["SubCategoryBase", List[str]]:
             cls.MoneyContent: SubCategoryPattern.Finance.MONEY_CONTENT,
         }
 
+
 class HateAndExtremismSubCategoryType(SubCategoryBase, Enum):
     Hate = "Hate"
     Harassment = "Harassment"
@@ -167,6 +174,7 @@ def list_choices(cls) -> Dict["SubCategoryBase", List[str]]:
             cls.Racy: SubCategoryPattern.HateAndExtremism.RACY,
         }
 
+
 class SafeSubCategoryType(SubCategoryBase, Enum):
     Safe = "Safe"
     NotSafe = "NotSafe"
@@ -178,6 +186,7 @@ def list_choices(cls) -> Dict["SubCategoryBase", List[str]]:
             cls.NotSafe: SubCategoryPattern.Safe.NOT_SAFE,
         }
 
+
 class OtherSubCategoryType(SubCategoryBase, Enum):
     Spoof = "Spoof"
     Religion = "Religion"
@@ -191,4 +200,4 @@ def list_choices(cls) -> Dict["SubCategoryBase", List[str]]:
             cls.Religion: SubCategoryPattern.Other.RELIGION,
             cls.Offensive: SubCategoryPattern.Other.OFFENSIVE,
             cls.Other: SubCategoryPattern.Other.OTHER,
-        }
\ No newline at end of file
+        }
diff --git a/edenai_apis/tests/features/dataclasses/test_text_moderation_dataclass.py b/edenai_apis/tests/features/dataclasses/test_text_moderation_dataclass.py
index ebf39370..1796c1d4 100644
--- a/edenai_apis/tests/features/dataclasses/test_text_moderation_dataclass.py
+++ b/edenai_apis/tests/features/dataclasses/test_text_moderation_dataclass.py
@@ -43,8 +43,30 @@ class TestTextModeration:
             ),
         ],
     )
-    def test_valid_value_check_min_mac_nsfw(self, nsfw_likelihood, nsfw_likelihood_score):
+    def test_valid_value_check_min_mac_nsfw(
+        self, nsfw_likelihood, nsfw_likelihood_score
+    ):
         try:
-            ModerationDataClass(nsfw_likelihood=nsfw_likelihood, items=[], nsfw_likelihood_score=nsfw_likelihood_score)
+            ModerationDataClass(
+                nsfw_likelihood=nsfw_likelihood,
+                items=[],
+                nsfw_likelihood_score=nsfw_likelihood_score,
+            )
         except ValueError:
             pytest.fail(f"{nsfw_likelihood} value doesn't raises a ValueError")
+
+    def test_text_moderation_items(self):
+
+        ModerationDataClass(
+            nsfw_likelihood=0,
+            items=[
+                {
+                    "label": "hi",
+                    "category": "Toxic",
+                    "likelihood": 1,
+                    "subcategory": "Toxic",
+                    "likelihood_score": 0.5,
+                }
+            ],
+            nsfw_likelihood_score=0.0,
+        )
diff --git a/edenai_apis/tests/test_providers_errors.py b/edenai_apis/tests/test_providers_errors.py
index adcbba5a..e7bf58c7 100644
--- a/edenai_apis/tests/test_providers_errors.py
+++ b/edenai_apis/tests/test_providers_errors.py
@@ -6,7 +6,7 @@
 import pytest
 from apis.amazon.errors import ERRORS as amazon_errors
 from apis.google.errors import ERRORS as google_errors
-from apis.ibm.errors import ERRORS as ibm_errors
+
 from apis.microsoft.errors import ERRORS as microsoft_errors
 from features.audio.speech_to_text_async.speech_to_text_async_args import (
     data_path as audio_data_path,
diff --git a/edenai_apis/utils/combine_enums.py b/edenai_apis/utils/combine_enums.py
new file mode 100644
index 00000000..b20724d9
--- /dev/null
+++ b/edenai_apis/utils/combine_enums.py
@@ -0,0 +1,8 @@
+from typing import Type
+from enum import Enum
+
+def combine_enums(name: str, *enum_classes: Type[Enum]) -> Enum:
+    combined = {}
+    for enum_class in enum_classes:
+        combined.update(enum_class.__members__)
+    return Enum(name, combined)
diff --git a/requirements.in b/requirements.in
index d0754094..44efecae 100644
--- a/requirements.in
+++ b/requirements.in
@@ -5,7 +5,7 @@ pandas
 pdf2image
 pillow
 pycountry
-pydantic
+pydantic==2.7.4
 pylint
 mypy
 pydub
@@ -63,7 +63,7 @@ azure-core
 aleph_alpha_client==6.0.0
 
 # openai
-openai==1.46.0
+openai==1.55.3
 
 python-dotenv
 
diff --git a/requirements.txt b/requirements.txt
index 35dcab9f..695956f8 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -38,17 +38,10 @@ astroid==3.1.0
 attrs==23.2.0
     # via
     #   aiohttp
-    #   automat
     #   jsonlines
     #   jsonschema
    #   referencing
     #   sagemaker
-    #   service-identity
-    #   twisted
-autobahn==23.6.2
-    # via watson-developer-cloud
-automat==22.10.0
-    # via twisted
 azure-ai-formrecognizer==3.3.0
     # via -r requirements.in
 azure-cognitiveservices-speech==1.36.0
@@ -97,20 +90,15 @@ clarifai-grpc==9.10.7
     # via -r requirements.in
 cloudpickle==2.2.1
     # via sagemaker
-constantly==23.10.4
-    # via twisted
 contextlib2==21.6.0
     # via schema
 coverage[toml]==7.4.4
     # via pytest-cov
 cryptography==42.0.5
     # via
-    #   autobahn
     #   azure-identity
     #   msal
     #   pyjwt
-    #   pyopenssl
-    #   service-identity
 dill==0.3.8
     # via
     #   multiprocess
@@ -242,15 +230,10 @@ httpx==0.27.0
     #   openai
 huggingface-hub==0.22.2
     # via tokenizers
-hyperlink==21.0.0
-    # via
-    #   autobahn
-    #   twisted
 idna==3.6
     # via
     #   anyio
     #   httpx
-    #   hyperlink
     #   requests
     #   yarl
 imagesize==1.4.1
@@ -259,8 +242,6 @@ importlib-metadata==6.11.0
     # via sagemaker
 importlib-resources==6.4.0
     # via python-liquid
-incremental==22.10.0
-    # via twisted
 iniconfig==2.0.0
     # via pytest
 isodate==0.6.1
@@ -324,7 +305,7 @@ numpy==1.26.4
     #   shapely
 oauthlib==3.2.2
     # via requests-oauthlib
-openai==1.46.0
+openai==1.55.3
     # via -r requirements.in
 packaging==24.0
     # via
@@ -403,23 +384,20 @@ pyasn1==0.6.0
     # via
     #   pyasn1-modules
     #   rsa
-    #   service-identity
 pyasn1-modules==0.4.0
-    # via
-    #   google-auth
-    #   service-identity
+    # via google-auth
 pycares==4.4.0
     # via aiodns
 pycountry==23.12.11
     # via -r requirements.in
 pycparser==2.22
     # via cffi
-pydantic==2.6.4
+pydantic==2.7.4
     # via
     #   -r requirements.in
     #   anthropic-bedrock
     #   openai
-pydantic-core==2.16.3
+pydantic-core==2.18.4
     # via pydantic
 pydub==0.25.1
     # via -r requirements.in
@@ -429,17 +407,14 @@ pygments==2.17.2
     #   sphinx
 pyjwt[crypto]==2.8.0
     # via
-    #   ibm-cloud-sdk-core
-    #   modernmt
     #   msal
+    #   pyjwt
 pylint==3.1.0
     # via -r requirements.in
 pymupdf==1.24.8
     # via -r requirements.in
 pymupdfb==1.24.8
     # via pymupdf
-pyopenssl==24.1.0
-    # via watson-developer-cloud
 pyparsing==3.1.2
     # via httplib2
 pypdf==4.2.0
@@ -463,11 +438,8 @@ python-dateutil==2.9.0.post0
     # via
     #   botocore
     #   google-cloud-bigquery
-    #   ibm-cloud-sdk-core
-    #   ibm-watson
     #   pandas
     #   python-liquid
-    #   watson-developer-cloud
 python-dotenv==1.0.1
     # via -r requirements.in
 python-liquid==1.12.1
@@ -496,17 +468,13 @@ requests==2.31.0
     #   google-cloud-bigquery
     #   google-cloud-storage
     #   huggingface-hub
-    #   ibm-cloud-sdk-core
-    #   ibm-watson
     #   lettria
-    #   modernmt
     #   msal
     #   msrest
     #   requests-oauthlib
     #   responses
     #   sagemaker
     #   sphinx
-    #   watson-developer-cloud
 requests-oauthlib==2.0.0
     # via msrest
 responses==0.24.1
@@ -523,13 +491,10 @@ sagemaker==2.214.3
     # via -r requirements.in
 schema==0.7.5
     # via sagemaker
-service-identity==24.1.0
-    # via watson-developer-cloud
 shapely==2.0.3
     # via google-cloud-aiplatform
 six==1.16.0
     # via
-    #   automat
     #   azure-core
     #   google-pasta
     #   isodate
@@ -579,10 +544,6 @@ tqdm==4.66.2
     #   huggingface-hub
     #   openai
     #   sagemaker
-twisted==24.3.0
-    # via watson-developer-cloud
-txaio==23.1.1
-    # via autobahn
 typing-extensions==4.11.0
     # via
     #   aleph-alpha-client
@@ -595,7 +556,6 @@ typing-extensions==4.11.0
     #   pydantic
     #   pydantic-core
     #   python-liquid
-    #   twisted
 tzdata==2024.1
     # via pandas
 uritemplate==4.1.1
@@ -605,20 +565,13 @@ urllib3==2.2.1
     #   aleph-alpha-client
     #   botocore
     #   docker
-    #   ibm-cloud-sdk-core
     #   requests
     #   responses
     #   sagemaker
-watson-developer-cloud==2.0.0
-    # via -r requirements.in
-websocket-client==1.7.0
-    # via ibm-watson
 yarl==1.9.4
     # via aiohttp
 zipp==3.18.1
     # via importlib-metadata
-zope-interface==6.2
-    # via twisted
 
 # The following packages are considered to be unsafe in a requirements file:
 # setuptools
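
Note on the new helper (editorial addition, not part of the patch): the standalone Python sketch below illustrates how the combine_enums utility introduced above behaves. ColorSubCategory and MoodSubCategory are invented example enums; only the combine_enums body is copied from edenai_apis/utils/combine_enums.py. Because the helper merges Enum.__members__, each combined member keeps its name while its .value is the original source member, which appears to be why the @field_serializer hooks in the dataclasses unwrap the stored value with getattr(value, "value", None) rather than assuming a plain string.

# Standalone sketch, not part of the patch. ColorSubCategory and MoodSubCategory
# are made-up example enums; combine_enums is copied from the new utils module.
from enum import Enum
from typing import Type


def combine_enums(name: str, *enum_classes: Type[Enum]) -> Enum:
    combined = {}
    for enum_class in enum_classes:
        combined.update(enum_class.__members__)
    return Enum(name, combined)


class ColorSubCategory(str, Enum):
    Red = "Red"
    Blue = "Blue"


class MoodSubCategory(str, Enum):
    Happy = "Happy"


Combined = combine_enums("Combined", ColorSubCategory, MoodSubCategory)

# Member names from every source enum are carried over, in definition order.
assert [member.name for member in Combined] == ["Red", "Blue", "Happy"]

# __members__ maps names to the original members, so the combined member's
# .value is the source member and .value.value is the underlying string.
assert Combined.Red.value is ColorSubCategory.Red
assert Combined.Red.value.value == "Red"

# With str-backed source enums whose names and values coincide, lookup by the
# raw string also works, since the stored value compares equal to "Happy".
assert Combined("Happy") is Combined.Happy

This also suggests why replacing the previous Union[...] aliases with a single combined SubCategoryType enum keeps existing string inputs (for example "subcategory": "Toxic" in the new test) validating as before.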