diff --git a/edenai_apis/apis/google/google_text_api.py b/edenai_apis/apis/google/google_text_api.py index f675a52c..85bf1e9c 100644 --- a/edenai_apis/apis/google/google_text_api.py +++ b/edenai_apis/apis/google/google_text_api.py @@ -19,7 +19,7 @@ CodeGenerationDataClass, GenerationDataClass, ) -from edenai_apis.features.text.chat.chat_dataclass import StreamChat +from edenai_apis.features.text.chat.chat_dataclass import StreamChat, ChatStreamResponse from edenai_apis.features.text.embeddings.embeddings_dataclass import ( EmbeddingDataClass, EmbeddingsDataClass, @@ -361,7 +361,9 @@ def text__chat( except Exception as exc: raise ProviderException(str(exc)) - stream = (res.text for res in responses) + stream = (ChatStreamResponse(text=res.text, blocked=res.is_blocked, provider="google") + for res in responses) + return ResponseType[StreamChat]( original_response=None, standardized_response=StreamChat(stream=stream) diff --git a/edenai_apis/apis/openai/openai_text_api.py b/edenai_apis/apis/openai/openai_text_api.py index 41c3c2c8..ec74f23c 100644 --- a/edenai_apis/apis/openai/openai_text_api.py +++ b/edenai_apis/apis/openai/openai_text_api.py @@ -13,7 +13,7 @@ ) from edenai_apis.features.text.anonymization.category import CategoryType from edenai_apis.features.text.chat import ChatDataClass, ChatMessageDataClass -from edenai_apis.features.text.chat.chat_dataclass import StreamChat +from edenai_apis.features.text.chat.chat_dataclass import StreamChat, ChatStreamResponse from edenai_apis.features.text.code_generation.code_generation_dataclass import ( CodeGenerationDataClass, ) @@ -736,9 +736,12 @@ def text__chat( standardized_response=standardized_response, ) else: - stream = ( - chunk["choices"][0]["delta"].get("content", "") for chunk in response - ) + stream = (ChatStreamResponse( + text=chunk["choices"][0]["delta"].get("content", ""), + blocked=chunk["choices"][0].get("finish_reason") not in (None, "stop"), + provider="openai" + ) for chunk in response) + 
return ResponseType[StreamChat]( original_response=None, standardized_response=StreamChat(stream=stream) ) diff --git a/edenai_apis/apis/replicate/replicate_api.py b/edenai_apis/apis/replicate/replicate_api.py index 5bbed0aa..c0e12fe4 100644 --- a/edenai_apis/apis/replicate/replicate_api.py +++ b/edenai_apis/apis/replicate/replicate_api.py @@ -13,7 +13,8 @@ ChatDataClass, ChatMessageDataClass, ) -from edenai_apis.features.text.chat.chat_dataclass import StreamChat +from edenai_apis.features.provider.provider_interface import ProviderInterface +from edenai_apis.features.text.chat.chat_dataclass import StreamChat, ChatStreamResponse from edenai_apis.loaders.loaders import load_provider, ProviderDataEnum from edenai_apis.utils.exception import ProviderException from edenai_apis.utils.types import ResponseType @@ -33,7 +34,7 @@ def __init__(self, api_keys: Dict = {}): "Authorization": f"Token {api_settings['api_key']}", } self.base_url = "https://api.replicate.com/v1" - + def __get_stream_response(self, url: str) -> Generator: headers = {**self.headers, "Accept": "text/event-stream"} response = requests.get(url, headers=headers, stream=True) @@ -43,12 +44,16 @@ def __get_stream_response(self, url: str) -> Generator: response.close() break elif last_chunk == b"event: error" and chunk.startswith(b"data: "): - raise ProviderException("ERROR WHILE STREAMING") + yield ChatStreamResponse(text="[ERROR]", + blocked=True, + provider=self.provider_name) elif chunk.startswith(b"data: "): if last_chunk == b"data: " and chunk == b"data: ": - yield "\n" + yield ChatStreamResponse(text="\n", blocked=False, provider=self.provider_name) else: - yield chunk.decode("utf-8").replace("data: ", "") + yield ChatStreamResponse(text=chunk.decode("utf-8").replace("data: ", ""), + blocked=False, + provider=self.provider_name) last_chunk = chunk @overload diff --git a/edenai_apis/features/text/chat/__init__.py b/edenai_apis/features/text/chat/__init__.py index 210f0f7e..d6e0ef2a
100644 --- a/edenai_apis/features/text/chat/__init__.py +++ b/edenai_apis/features/text/chat/__init__.py @@ -1,2 +1,2 @@ from .chat_args import chat_arguments -from .chat_dataclass import ChatDataClass, ChatMessageDataClass, StreamChat +from .chat_dataclass import ChatDataClass, ChatMessageDataClass, StreamChat, ChatStreamResponse diff --git a/edenai_apis/features/text/chat/chat_dataclass.py b/edenai_apis/features/text/chat/chat_dataclass.py index 5678c36f..010c8b00 100644 --- a/edenai_apis/features/text/chat/chat_dataclass.py +++ b/edenai_apis/features/text/chat/chat_dataclass.py @@ -16,6 +16,10 @@ class ChatDataClass(BaseModel): def direct_response(api_response: Dict): return api_response["generated_text"] - +class ChatStreamResponse(BaseModel): + text: str + blocked: bool + provider: str + class StreamChat(BaseModel): - stream: Generator[str, None, None] + stream: Generator[ChatStreamResponse, None, None] diff --git a/edenai_apis/tests/features/test_text_stream.py b/edenai_apis/tests/features/test_text_stream.py index 5e39bcdb..1f9e4a0f 100644 --- a/edenai_apis/tests/features/test_text_stream.py +++ b/edenai_apis/tests/features/test_text_stream.py @@ -4,7 +4,7 @@ import pytest from edenai_apis import Text -from edenai_apis.features.text.chat import StreamChat +from edenai_apis.features.text.chat import StreamChat, ChatStreamResponse from edenai_apis.interface import list_providers from edenai_apis.loaders.data_loader import FeatureDataEnum from edenai_apis.loaders.loaders import load_feature @@ -38,4 +38,4 @@ def test_stream(self, provider): assert isinstance(chat_output.standardized_response.stream, Iterator) for chunk in chat_output.standardized_response.stream: - assert isinstance(chunk, str) + assert isinstance(chunk, ChatStreamResponse)