Skip to content

Commit

Permalink
[Fix] change alephapha_api to repair the unbound loccal variable
Browse files Browse the repository at this point in the history
  • Loading branch information
coscialp committed Nov 6, 2023
1 parent f14cab5 commit f878a6d
Showing 1 changed file with 50 additions and 35 deletions.
85 changes: 50 additions & 35 deletions edenai_apis/apis/alephalpha/alephalpha_api.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,28 @@
from typing import Dict, Sequence, Optional

import requests
from aleph_alpha_client import (
Client,
Prompt,
SemanticEmbeddingRequest,
Image,
SemanticRepresentation,
CompletionRequest,
Text,
)

from edenai_apis.features import ProviderInterface, TextInterface, ImageInterface
from edenai_apis.features.image.embeddings import EmbeddingsDataClass, EmbeddingDataClass
from edenai_apis.features.image.embeddings import (
EmbeddingsDataClass,
EmbeddingDataClass,
)
from edenai_apis.features.image.question_answer import QuestionAnswerDataClass
from edenai_apis.features.text import SummarizeDataClass
from edenai_apis.loaders.data_loader import ProviderDataEnum
from edenai_apis.loaders.loaders import load_provider
from edenai_apis.utils.exception import ProviderException
from edenai_apis.utils.types import ResponseType

from aleph_alpha_client import Client, Prompt, SemanticEmbeddingRequest, QaRequest, Image, SemanticRepresentation, \
Document, CompletionRequest, Text


class AlephAlphaApi(ProviderInterface, TextInterface, ImageInterface):
provider_name = "alephalpha"
Expand All @@ -27,23 +36,18 @@ def __init__(self, api_keys: Dict = {}):
self.url_summarise = "https://api.aleph-alpha.com/summarize"

def text__summarize(
self,
text: str,
output_sentences: int,
language: str,
model: str,
self,
text: str,
output_sentences: int,
language: str,
model: str,
) -> ResponseType[SummarizeDataClass]:
headers = {
"Content-Type": "application/json",
"Accept": "application/json",
"Authorization": f"Bearer {self.api_key}"
}
payload = {
"model": model,
"document": {
"text": text
}
"Authorization": f"Bearer {self.api_key}",
}
payload = {"model": model, "document": {"text": text}}
response = requests.post(url=self.url_summarise, headers=headers, json=payload)
if response.status_code != 200:
raise ProviderException(response.text, code=response.status_code)
Expand All @@ -57,7 +61,11 @@ def text__summarize(
)

def image__embeddings(
self, file: str, model: str, representation: str, file_url: str = "",
self,
file: str,
model: str,
representation: str,
file_url: str = "",
) -> ResponseType[EmbeddingsDataClass]:
if representation == "symmetric":
representation_client = SemanticRepresentation.Symmetric
Expand All @@ -67,47 +75,54 @@ def image__embeddings(
representation_client = SemanticRepresentation.Query
client = Client(self.api_key)
prompt = Prompt.from_image(Image.from_file(file))
request = SemanticEmbeddingRequest(prompt=prompt, representation=representation_client)
request = SemanticEmbeddingRequest(
prompt=prompt, representation=representation_client
)
try:
response = client.semantic_embed(request=request, model=model)
except:
raise ProviderException(response.message)
if response.message:
raise ProviderException(response.message)
except Exception as exc:
raise ProviderException(message=str(exc)) from exc

original_response = response._asdict()
items: Sequence[EmbeddingDataClass] = [EmbeddingDataClass(embedding=response.embedding)]
items: Sequence[EmbeddingDataClass] = [
EmbeddingDataClass(embedding=response.embedding)
]
standardized_response = EmbeddingsDataClass(items=items)
return ResponseType[EmbeddingsDataClass](
original_response=original_response,
standardized_response=standardized_response
standardized_response=standardized_response,
)

def image__question_answer(
self,
file: str,
temperature: float,
max_tokens: int,
file_url: str = "",
model: Optional[str] = None,
question: Optional[str] = None
self,
file: str,
temperature: float,
max_tokens: int,
file_url: str = "",
model: Optional[str] = None,
question: Optional[str] = None,
) -> ResponseType[QuestionAnswerDataClass]:
client = Client(self.api_key)
if question:
prompts = Prompt([Text.from_text(question), Image.from_file(file)])
else:
prompts = Prompt([Image.from_file(file)])
request = CompletionRequest(prompt=prompts, maximum_tokens=max_tokens, temperature=temperature, tokens=True)
request = CompletionRequest(
prompt=prompts,
maximum_tokens=max_tokens,
temperature=temperature,
tokens=True,
)
try:
response = client.complete(request=request, model=model)
except Exception as error:
raise ProviderException(str(error))
raise ProviderException(str(error)) from error
original_response = response._asdict()
answers = []
for answer in response.completions:
answers.append(answer.completion)
standardized_response = QuestionAnswerDataClass(answers=answers)
return ResponseType[QuestionAnswerDataClass](
original_response=original_response,
standardized_response=standardized_response
standardized_response=standardized_response,
)

0 comments on commit f878a6d

Please sign in to comment.