Skip to content

Commit

Permalink
[Fix] update sdk clarifai
Browse files Browse the repository at this point in the history
  • Loading branch information
coscialp committed Nov 30, 2023
1 parent 84f897d commit 317b428
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 65 deletions.
65 changes: 13 additions & 52 deletions edenai_apis/apis/clarifai/clarifai_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
LogoBoundingPoly,
LogoVertice,
)
from edenai_apis.features.image.explicit_content.category import CategoryType as CategoryTypeExplicitContent
from edenai_apis.features.image.explicit_content.category import (
CategoryType as CategoryTypeExplicitContent,
)
from edenai_apis.features.image.face_detection.face_detection_dataclass import (
FaceAccessories,
FaceEmotions,
Expand Down Expand Up @@ -72,53 +74,9 @@ def __init__(self, api_keys: Dict = {}) -> None:
def text__generation(
self, text: str, temperature: float, max_tokens: int, model: str
) -> ResponseType[GenerationDataClass]:
text = f"[INST] {text} [/INST]"

channel = ClarifaiChannel.get_grpc_channel()
stub = service_pb2_grpc.V2Stub(channel)

metadata = (("authorization", self.key),)
user_data_object = resources_pb2.UserAppIDSet(
user_id="mistralai", app_id="completion"
)

post_model_outputs_response = stub.PostModelOutputs(
service_pb2.PostModelOutputsRequest(
user_app_id=user_data_object,
model_id=model,
inputs=[
resources_pb2.Input(
data=resources_pb2.Data(text=resources_pb2.Text(raw=text))
)
],
),
metadata=metadata,
)

if post_model_outputs_response.status.code != status_code_pb2.SUCCESS:
raise ProviderException(
post_model_outputs_response.status.description,
code=post_model_outputs_response.status.code,
)

response = MessageToDict(
post_model_outputs_response, preserving_proto_field_name=True
)

output = response.get("outputs", [])
if len(output) == 0:
raise ProviderException(
"Clarifai returned an empty response!",
code=post_model_outputs_response.status.code,
)

original_response = output[0].get("data", {}) or {}

return ResponseType[GenerationDataClass](
original_response=original_response,
standardized_response=GenerationDataClass(
generated_text=(original_response.get("text", {}) or {}).get("raw", "")
),
raise ProviderException(
message="This provider is deprecated. You won't be charged for your call.",
code=500,
)

def text__moderation(
Expand All @@ -132,7 +90,8 @@ def text__moderation(

post_model_outputs_response = stub.PostModelOutputs(
service_pb2.PostModelOutputsRequest(
user_app_id=user_data_object, # The userDataObject is created in the overview and is required when using a PAT
# The userDataObject is created in the overview and is required when using a PAT
user_app_id=user_data_object,
model_id=self.text_moderation_code,
inputs=[
resources_pb2.Input(
Expand Down Expand Up @@ -166,11 +125,11 @@ def text__moderation(
classificator = CategoryType.choose_category_subcategory(concept["name"])
classification.append(
TextModerationItem(
label= concept["name"],
label=concept["name"],
category=classificator["category"],
subcategory=classificator["subcategory"],
likelihood_score=concept["value"],
likelihood=standardized_confidence_score(concept["value"])
likelihood=standardized_confidence_score(concept["value"]),
)
)
standardized_response: ModerationDataClass = ModerationDataClass(
Expand Down Expand Up @@ -299,7 +258,9 @@ def image__explicit_content(
original_response = response.get("outputs", [])[0]["data"]
items = []
for concept in original_response["concepts"]:
classificator = CategoryTypeExplicitContent.choose_category_subcategory(concept["name"])
classificator = CategoryTypeExplicitContent.choose_category_subcategory(
concept["name"]
)
items.append(
ExplicitItem(
label=concept["name"],
Expand Down
2 changes: 1 addition & 1 deletion requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ amazon-textract-response-parser
boto3

#clarifai
clarifai-grpc==9.5.0
clarifai-grpc==9.10.7

#google
google-api-python-client==2.88.0
Expand Down
44 changes: 32 additions & 12 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
#
# This file is autogenerated by pip-compile with Python 3.10
# This file is autogenerated by pip-compile with Python 3.8
# by the following command:
#
# pip-compile --output-file=requirements.txt --resolver=backtracking requirements.in
# pip-compile --output-file=requirements.txt requirements.in
#
affinda==4.7.1
# via -r requirements.in

mypy==1.6.1
mypy-extensions==1.0.0
aiodns==3.0.0
# via aleph-alpha-client
aiohttp==3.8.5
Expand Down Expand Up @@ -87,7 +84,7 @@ charset-normalizer==3.1.0
# via
# aiohttp
# requests
clarifai-grpc==9.5.0
clarifai-grpc==9.10.7
# via -r requirements.in
cloudpickle==2.2.1
# via sagemaker
Expand All @@ -96,7 +93,9 @@ constantly==15.1.0
contextlib2==21.6.0
# via schema
coverage[toml]==7.2.7
# via pytest-cov
# via
# coverage
# pytest-cov
cryptography==41.0.1
# via
# autobahn
Expand Down Expand Up @@ -174,7 +173,7 @@ google-cloud-language==2.10.0
# via -r requirements.in
google-cloud-resource-manager==1.10.4
# via google-cloud-aiplatform
google-cloud-speech==2.21.0
google-cloud-speech==2.20.0
# via -r requirements.in
google-cloud-storage==2.9.0
# via
Expand Down Expand Up @@ -236,7 +235,11 @@ idna==3.4
imagesize==1.4.1
# via sphinx
importlib-metadata==4.13.0
# via sagemaker
# via
# sagemaker
# sphinx
importlib-resources==6.1.1
# via jsonschema
incremental==22.10.0
# via twisted
iniconfig==2.0.0
Expand Down Expand Up @@ -289,6 +292,10 @@ multidict==6.0.4
# yarl
multiprocess==0.70.14
# via pathos
mypy==1.6.1
# via -r requirements.in
mypy-extensions==1.0.0
# via mypy
numpy==1.23.4
# via
# pandas
Expand Down Expand Up @@ -320,6 +327,8 @@ pillow==10.0.0
# -r requirements.in
# aleph-alpha-client
# pdf2image
pkgutil-resolve-name==1.3.10
# via jsonschema
platformdirs==3.8.0
# via
# pylint
Expand Down Expand Up @@ -434,12 +443,16 @@ python-dateutil==2.8.2
# pandas
# python-liquid
# watson-developer-cloud
python-dotenv==1.0.0
# via -r requirements.in
python-liquid==1.10.0
# via aleph-alpha-client
python-magic==0.4.27
# via -r requirements.in
pytz==2023.3
# via pandas
# via
# babel
# pandas
pyyaml==6.0
# via
# huggingface-hub
Expand Down Expand Up @@ -521,6 +534,7 @@ tokenizers==0.14.0
tomli==2.0.1
# via
# coverage
# mypy
# pylint
# pytest
tomlkit==0.11.8
Expand All @@ -537,13 +551,17 @@ txaio==23.1.1
typing-extensions==4.7.1
# via
# aleph-alpha-client
# annotated-types
# astroid
# azure-ai-formrecognizer
# azure-core
# filelock
# huggingface-hub
# mypy
# pydantic
# pydantic-core
# pylint
# pypdf
# python-liquid
# twisted
tzdata==2023.3
Expand All @@ -566,9 +584,11 @@ wrapt==1.15.0
yarl==1.9.2
# via aiohttp
zipp==3.15.0
# via importlib-metadata
# via
# importlib-metadata
# importlib-resources
zope-interface==6.0
# via twisted
python-dotenv

# The following packages are considered to be unsafe in a requirements file:
# setuptools

0 comments on commit 317b428

Please sign in to comment.