Skip to content

Commit

Permalink
Merge branch 'master' of github.com:edenai/edenai-apis
Browse files Browse the repository at this point in the history
  • Loading branch information
KyrianC committed Nov 20, 2023
2 parents df4deb2 + bd8f68d commit 6940eae
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 16 deletions.
6 changes: 4 additions & 2 deletions edenai_apis/apis/google/google_text_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
CodeGenerationDataClass,
GenerationDataClass,
)
from edenai_apis.features.text.chat.chat_dataclass import StreamChat
from edenai_apis.features.text.chat.chat_dataclass import StreamChat, ChatStreamResponse
from edenai_apis.features.text.embeddings.embeddings_dataclass import (
EmbeddingDataClass,
EmbeddingsDataClass,
Expand Down Expand Up @@ -361,7 +361,9 @@ def text__chat(
except Exception as exc:
raise ProviderException(str(exc))

stream = (res.text for res in responses)
stream = (ChatStreamResponse(text=res.text, blocked = res.is_blocked, provider="google")
for res in responses)

return ResponseType[StreamChat](
original_response=None,
standardized_response=StreamChat(stream=stream)
Expand Down
11 changes: 7 additions & 4 deletions edenai_apis/apis/openai/openai_text_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
)
from edenai_apis.features.text.anonymization.category import CategoryType
from edenai_apis.features.text.chat import ChatDataClass, ChatMessageDataClass
from edenai_apis.features.text.chat.chat_dataclass import StreamChat
from edenai_apis.features.text.chat.chat_dataclass import StreamChat, ChatStreamResponse
from edenai_apis.features.text.code_generation.code_generation_dataclass import (
CodeGenerationDataClass,
)
Expand Down Expand Up @@ -736,9 +736,12 @@ def text__chat(
standardized_response=standardized_response,
)
else:
stream = (
chunk["choices"][0]["delta"].get("content", "") for chunk in response
)
stream = (ChatStreamResponse(
text = chunk["choices"][0]["delta"].get("content", ""),
blocked = not chunk["choices"][0].get("finish_reason") in (None, "stop"),
provider = "openai"
) for chunk in response)

return ResponseType[StreamChat](
original_response=None, standardized_response=StreamChat(stream=stream)
)
Expand Down
15 changes: 10 additions & 5 deletions edenai_apis/apis/replicate/replicate_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
ChatDataClass,
ChatMessageDataClass,
)
from edenai_apis.features.text.chat.chat_dataclass import StreamChat
from edenai_apis.features.provider.provider_interface import ProviderInterface
from edenai_apis.features.text.chat.chat_dataclass import StreamChat, ChatStreamResponse
from edenai_apis.loaders.loaders import load_provider, ProviderDataEnum
from edenai_apis.utils.exception import ProviderException
from edenai_apis.utils.types import ResponseType
Expand All @@ -33,7 +34,7 @@ def __init__(self, api_keys: Dict = {}):
"Authorization": f"Token {api_settings['api_key']}",
}
self.base_url = "https://api.replicate.com/v1"

def __get_stream_response(self, url: str) -> Generator:
headers = {**self.headers, "Accept": "text/event-stream"}
response = requests.get(url, headers=headers, stream=True)
Expand All @@ -43,12 +44,16 @@ def __get_stream_response(self, url: str) -> Generator:
response.close()
break
elif last_chunk == b"event: error" and chunk.startswith(b"data: "):
raise ProviderException("ERROR WHILE STREAMING")
yield ChatStreamResponse(text = "[ERROR]",
blocked = True,
provider = self.provider_name)
elif chunk.startswith(b"data: "):
if last_chunk == b"data: " and chunk == b"data: ":
yield "\n"
yield ChatStreamResponse(text = "\n", blocked = False, provider = self.provider_name)
else:
yield chunk.decode("utf-8").replace("data: ", "")
yield ChatStreamResponse(text = chunk.decode("utf-8").replace("data: ", ""),
blocked = False,
provider = self.provider_name)
last_chunk = chunk

@overload
Expand Down
2 changes: 1 addition & 1 deletion edenai_apis/features/text/chat/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .chat_args import chat_arguments
from .chat_dataclass import ChatDataClass, ChatMessageDataClass, StreamChat
from .chat_dataclass import ChatDataClass, ChatMessageDataClass, StreamChat, ChatStreamResponse
8 changes: 6 additions & 2 deletions edenai_apis/features/text/chat/chat_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ class ChatDataClass(BaseModel):
def direct_response(api_response: Dict):
return api_response["generated_text"]


class ChatStreamResponse(BaseModel):
    """A single standardized chunk of a streamed chat reply.

    Emitted by each provider's streaming chat implementation (e.g. google,
    openai, replicate in this commit) and consumed via ``StreamChat.stream``.
    """

    text: str  # the chunk of generated text (may be empty or "\n")
    blocked: bool  # True when the provider blocked/flagged this chunk or the stream errored
    provider: str  # provider name that produced the chunk, e.g. "google" or "openai"

class StreamChat(BaseModel):
stream: Generator[str, None, None]
stream: Generator[ChatStreamResponse, None, None]
4 changes: 2 additions & 2 deletions edenai_apis/tests/features/test_text_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest

from edenai_apis import Text
from edenai_apis.features.text.chat import StreamChat
from edenai_apis.features.text.chat import StreamChat, ChatStreamResponse
from edenai_apis.interface import list_providers
from edenai_apis.loaders.data_loader import FeatureDataEnum
from edenai_apis.loaders.loaders import load_feature
Expand Down Expand Up @@ -38,4 +38,4 @@ def test_stream(self, provider):
assert isinstance(chat_output.standardized_response.stream, Iterator)

for chunk in chat_output.standardized_response.stream:
assert isinstance(chunk, str)
assert isinstance(chunk, ChatStreamResponse)

0 comments on commit 6940eae

Please sign in to comment.