From 8b5f9381191ae4916b4b1ef5d4091e6ff452b496 Mon Sep 17 00:00:00 2001 From: michaeldecent2 <111002205+MichaelDecent@users.noreply.github.com> Date: Mon, 11 Nov 2024 12:00:09 +0100 Subject: [PATCH] swarm - Refactor GroqModel to use httpx for API requests --- .../swarmauri/llms/concrete/GroqModel.py | 156 +++++----- .../swarmauri/llms/concrete/GroqToolModel.py | 287 ++++++++---------- .../llms/concrete/GroqVisionModel.py | 156 +++++----- .../unit/llms/GroqVisionModel_unit_test.py | 4 +- 4 files changed, 277 insertions(+), 326 deletions(-) diff --git a/pkgs/swarmauri/swarmauri/llms/concrete/GroqModel.py b/pkgs/swarmauri/swarmauri/llms/concrete/GroqModel.py index f2abbe2ac..785c26c63 100644 --- a/pkgs/swarmauri/swarmauri/llms/concrete/GroqModel.py +++ b/pkgs/swarmauri/swarmauri/llms/concrete/GroqModel.py @@ -2,7 +2,6 @@ import json from pydantic import PrivateAttr import httpx -import requests from swarmauri.conversations.concrete.Conversation import Conversation from typing import List, Optional, Dict, Literal, Any, AsyncGenerator, Generator @@ -50,7 +49,26 @@ class GroqModel(LLMBase): ] name: str = "gemma-7b-it" type: Literal["GroqModel"] = "GroqModel" - _api_url: str = PrivateAttr("https://api.groq.com/openai/v1/chat/completions") + _client: httpx.Client = PrivateAttr(default=None) + _async_client: httpx.AsyncClient = PrivateAttr(default=None) + _BASE_URL: str = PrivateAttr(default="https://api.groq.com/openai/v1/chat/completions") + + def __init__(self, **data): + """ + Initialize the GroqModel class with the provided data. 
+ """ + super().__init__(**data) + self._client = httpx.Client( + headers={"Authorization": f"Bearer {self.api_key}"}, + base_url=self._BASE_URL, + ) + self._async_client = httpx.AsyncClient( + headers={"Authorization": f"Bearer {self.api_key}"}, + base_url=self._BASE_URL, + ) def _format_messages( self, @@ -93,24 +111,6 @@ def _prepare_usage_data(self, usage_data) -> UsageData: """ return UsageData.model_validate(usage_data) - def _make_request(self, data: Dict[str, Any]) -> Dict[str, Any]: - """ - Sends a synchronous HTTP POST request to the API and retrieves the response. - - Args: - data (dict): Payload data to be sent in the API request. - - Returns: - dict: Parsed JSON response from the API. - """ - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {self.api_key}", - } - response = requests.post(self._api_url, headers=headers, json=data) - response.raise_for_status() # Raise an error for HTTP issues - return response.json() - def predict( self, conversation: Conversation, @@ -135,7 +135,7 @@ def predict( Conversation: Updated conversation with the model's response. 
""" formatted_messages = self._format_messages(conversation.history) - data = { + payload = { "model": self.name, "messages": formatted_messages, "temperature": temperature, @@ -144,11 +144,16 @@ def predict( "stop": stop or [], } if enable_json: - data["response_format"] = "json_object" + payload["response_format"] = "json_object" - result = self._make_request(data) - message_content = result["choices"][0]["message"]["content"] - usage_data = result.get("usage", {}) + response = self._client.post(self._BASE_URL, json=payload) + + response.raise_for_status() + + response_data = response.json() + + message_content = response_data["choices"][0]["message"]["content"] + usage_data = response_data.get("usage", {}) usage = self._prepare_usage_data(usage_data) conversation.add_message(AgentMessage(content=message_content, usage=usage)) @@ -178,7 +183,7 @@ async def apredict( Conversation: Updated conversation with the model's response. """ formatted_messages = self._format_messages(conversation.history) - data = { + payload = { "model": self.name, "messages": formatted_messages, "temperature": temperature, @@ -187,12 +192,15 @@ async def apredict( "stop": stop or [], } if enable_json: - data["response_format"] = "json_object" + payload["response_format"] = "json_object" - # Use asyncio's to_thread to call synchronous code in an async context - result = await asyncio.to_thread(self._make_request, data) - message_content = result["choices"][0]["message"]["content"] - usage_data = result.get("usage", {}) + response = await self._async_client.post(self._BASE_URL, json=payload) + response.raise_for_status() + + response_data = response.json() + + message_content = response_data["choices"][0]["message"]["content"] + usage_data = response_data.get("usage", {}) usage = self._prepare_usage_data(usage_data) conversation.add_message(AgentMessage(content=message_content, usage=usage)) @@ -223,7 +231,7 @@ def stream( """ formatted_messages = self._format_messages(conversation.history) 
- data = { + payload = { "model": self.name, "messages": formatted_messages, "temperature": temperature, @@ -233,31 +241,26 @@ def stream( "stop": stop or [], } if enable_json: - data["response_format"] = "json_object" + payload["response_format"] = "json_object" - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {self.api_key}", - } + response = self._client.post(self._BASE_URL, json=payload) - with requests.post( - self._api_url, headers=headers, json=data, stream=True - ) as response: - response.raise_for_status() - message_content = "" - for line in response.iter_lines(decode_unicode=True): - json_str = line.replace('data: ', '') - try: - if json_str: - chunk = json.loads(json_str) - if chunk["choices"][0]["delta"]: - delta = chunk["choices"][0]["delta"]["content"] - message_content += delta - yield delta - except json.JSONDecodeError: - pass - - conversation.add_message(AgentMessage(content=message_content)) + response.raise_for_status() + + message_content = "" + for line in response.iter_lines(): + json_str = line.replace('data: ', '') + try: + if json_str: + chunk = json.loads(json_str) + if chunk["choices"][0]["delta"]: + delta = chunk["choices"][0]["delta"]["content"] + message_content += delta + yield delta + except json.JSONDecodeError: + pass + + conversation.add_message(AgentMessage(content=message_content)) async def astream( self, @@ -284,7 +287,7 @@ async def astream( """ formatted_messages = self._format_messages(conversation.history) - data = { + payload = { "model": self.name, "messages": formatted_messages, "temperature": temperature, @@ -294,29 +297,24 @@ async def astream( "stop": stop or [], } if enable_json: - data["response_format"] = "json_object" - - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {self.api_key}", - } - - async with httpx.AsyncClient() as client: - response = await client.post(self._api_url, headers=headers, json=data, timeout=None) - - 
response.raise_for_status() - message_content = "" - async for line in response.aiter_lines(): - json_str = line.replace('data: ', '') - try: - if json_str: - chunk = json.loads(json_str) - if chunk["choices"][0]["delta"]: - delta = chunk["choices"][0]["delta"]["content"] - message_content += delta - yield delta - except json.JSONDecodeError: - pass + payload["response_format"] = "json_object" + + response = await self._async_client.post(self._BASE_URL, json=payload) + + response.raise_for_status() + message_content = "" + + async for line in response.aiter_lines(): + json_str = line.replace('data: ', '') + try: + if json_str: + chunk = json.loads(json_str) + if chunk["choices"][0]["delta"]: + delta = chunk["choices"][0]["delta"]["content"] + message_content += delta + yield delta + except json.JSONDecodeError: + pass conversation.add_message(AgentMessage(content=message_content)) diff --git a/pkgs/swarmauri/swarmauri/llms/concrete/GroqToolModel.py b/pkgs/swarmauri/swarmauri/llms/concrete/GroqToolModel.py index d6c5b0a64..f49a22736 100644 --- a/pkgs/swarmauri/swarmauri/llms/concrete/GroqToolModel.py +++ b/pkgs/swarmauri/swarmauri/llms/concrete/GroqToolModel.py @@ -2,18 +2,15 @@ import json from typing import AsyncIterator, Iterator, List, Literal, Dict, Any -import logging import httpx from pydantic import PrivateAttr -import requests from swarmauri.conversations.concrete import Conversation from swarmauri_core.typing import SubclassUnion from swarmauri.messages.base.MessageBase import MessageBase from swarmauri.messages.concrete.AgentMessage import AgentMessage -from swarmauri.messages.concrete.FunctionMessage import FunctionMessage from swarmauri.llms.base.LLMBase import LLMBase from swarmauri.schema_converters.concrete.GroqSchemaConverter import ( GroqSchemaConverter, @@ -52,23 +49,28 @@ class GroqToolModel(LLMBase): ] name: str = "llama3-groq-70b-8192-tool-use-preview" type: Literal["GroqToolModel"] = "GroqToolModel" - _headers: Dict[str, str] = 
PrivateAttr(default=None) - _api_url: str = PrivateAttr( + _client: httpx.Client = PrivateAttr(default=None) + _async_client: httpx.AsyncClient = PrivateAttr(default=None) + _BASE_URL: str = PrivateAttr( default="https://api.groq.com/openai/v1/chat/completions" ) - def __init__(self, **data) -> None: + def __init__(self, **data): """ - Initializes the GroqToolModel instance, setting up headers for API requests. + Initialize the GroqToolModel class with the provided data. - Parameters: - **data: Arbitrary keyword arguments for initialization. + Args: + **data: Arbitrary keyword arguments containing initialization data. """ super().__init__(**data) - self._headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {self.api_key}", - } + self._client = httpx.Client( + headers={"Authorization": f"Bearer {self.api_key}"}, + base_url=self._BASE_URL, + ) + self._async_client = httpx.AsyncClient( + headers={"Authorization": f"Bearer {self.api_key}"}, + base_url=self._BASE_URL, + ) def _schema_convert_tools(self, tools) -> List[Dict[str, Any]]: """ @@ -82,6 +84,25 @@ def _schema_convert_tools(self, tools) -> List[Dict[str, Any]]: """ return [GroqSchemaConverter().convert(tools[tool]) for tool in tools] + def _process_tool_calls(self, tool_calls, toolkit, messages) -> List[Dict[str, Any]]: + if tool_calls: + for tool_call in tool_calls: + func_name = tool_call["function"]["name"] + + func_call = toolkit.get_tool_by_name(func_name) + func_args = json.loads(tool_call["function"]["arguments"]) + func_result = func_call(**func_args) + + messages.append( + { + "tool_call_id": tool_call["id"], + "role": "tool", + "name": func_name, + "content": json.dumps(func_result), + } + ) + return messages + def _format_messages( self, messages: List[SubclassUnion[MessageBase]] ) -> List[Dict[str, str]]: @@ -136,33 +157,29 @@ def predict( "tool_choice": tool_choice, } - response = requests.post(self._api_url, headers=self._headers, json=payload) + response = 
self._client.post(self._BASE_URL, json=payload) response.raise_for_status() tool_response = response.json() - if "content" in tool_response["choices"][0]["message"]: - agent_message = AgentMessage( - content=tool_response["choices"][0]["message"]["content"] - ) - conversation.add_message(agent_message) - + messages = [formatted_messages[-1], tool_response["choices"][0]["message"]] tool_calls = tool_response["choices"][0]["message"].get("tool_calls", []) - if tool_calls: - for tool_call in tool_calls: - func_name = tool_call["function"]["name"] - func_call = toolkit.get_tool_by_name(func_name) - func_args = json.loads(tool_call["function"]["arguments"]) - func_result = func_call(**func_args) + messages = self._process_tool_calls(tool_calls, toolkit, messages) - func_message = FunctionMessage( - content=json.dumps(func_result), - name=func_name, - tool_call_id=tool_call["id"], - ) - conversation.add_message(func_message) + payload["messages"] = messages + payload.pop("tools", None) + payload.pop("tool_choice", None) + + response = self._client.post(self._BASE_URL, json=payload) + response.raise_for_status() + + agent_response = response.json() + agent_message = AgentMessage( + content=agent_response["choices"][0]["message"]["content"] + ) + conversation.add_message(agent_message) return conversation async def apredict( @@ -200,37 +217,29 @@ async def apredict( "tool_choice": tool_choice, } - async with httpx.AsyncClient() as client: - response = await client.post( - self._api_url, headers=self._headers, json=payload - ) - response.raise_for_status() + response = await self._async_client.post(self._BASE_URL, json=payload) + response.raise_for_status() tool_response = response.json() - if "content" in tool_response["choices"][0]["message"]: - agent_message = AgentMessage( - content=tool_response["choices"][0]["message"]["content"] - ) - conversation.add_message(agent_message) - + messages = [formatted_messages[-1], tool_response["choices"][0]["message"]] tool_calls = 
tool_response["choices"][0]["message"].get("tool_calls", []) - if tool_calls: - for tool_call in tool_calls: - func_name = tool_call["function"]["name"] - func_call = toolkit.get_tool_by_name(func_name) - func_args = json.loads(tool_call["function"]["arguments"]) - func_result = func_call(**func_args) + messages = self._process_tool_calls(tool_calls, toolkit, messages) - func_message = FunctionMessage( - content=json.dumps(func_result), - name=func_name, - tool_call_id=tool_call["id"], - ) - conversation.add_message(func_message) + payload["messages"] = messages + payload.pop("tools", None) + payload.pop("tool_choice", None) + + response = await self._async_client.post(self._BASE_URL, json=payload) + response.raise_for_status() - logging.info(conversation.history) + agent_response = response.json() + + agent_message = AgentMessage( + content=agent_response["choices"][0]["message"]["content"] + ) + conversation.add_message(agent_message) return conversation def stream( @@ -257,7 +266,7 @@ def stream( formatted_messages = self._format_messages(conversation.history) - request_data = { + payload = { "model": self.name, "messages": formatted_messages, "temperature": temperature, @@ -266,65 +275,38 @@ def stream( "tool_choice": tool_choice or "auto", } - # Initial tool response - response = requests.post( - self._api_url, - headers=self._headers, - json=request_data, - ) + response = self._client.post(self._BASE_URL, json=payload) response.raise_for_status() - tool_response = response.json() - logging.info(tool_response) - if "content" in tool_response["choices"][0]["message"]: - agent_message = AgentMessage( - content=tool_response["choices"][0]["message"]["content"] - ) - conversation.add_message(agent_message) + tool_response = response.json() + messages = [formatted_messages[-1], tool_response["choices"][0]["message"]] tool_calls = tool_response["choices"][0]["message"].get("tool_calls", []) - if tool_calls: - for tool_call in tool_calls: - func_name = 
tool_call["function"]["name"] - func_call = toolkit.get_tool_by_name(func_name) - func_args = json.loads(tool_call["function"]["arguments"]) - func_result = func_call(**func_args) + messages = self._process_tool_calls(tool_calls, toolkit, messages) - func_message = FunctionMessage( - content=json.dumps(func_result), - name=func_name, - tool_call_id=tool_call["id"], - ) - conversation.add_message(func_message) + payload["messages"] = messages + payload["stream"] = True + payload.pop("tools", None) + payload.pop("tool_choice", None) - formatted_messages = self._format_messages(conversation.history) - request_data["messages"] = formatted_messages - request_data["stream"] = True - request_data.pop("tools", None) - request_data.pop("tool_choice", None) - - with requests.post( - self._api_url, - headers=self._headers, - json=request_data, - ) as response: - response.raise_for_status() - message_content = "" - - for line in response.iter_lines(decode_unicode=True): - json_str = line.replace('data: ', '') - try: - if json_str: - chunk = json.loads(json_str) - if chunk["choices"][0]["delta"]: - delta = chunk["choices"][0]["delta"]["content"] - message_content += delta - yield delta - except json.JSONDecodeError: - pass - - conversation.add_message(AgentMessage(content=message_content)) + response = self._client.post(self._BASE_URL, json=payload) + response.raise_for_status() + message_content = "" + + for line in response.iter_lines(): + json_str = line.replace("data: ", "") + try: + if json_str: + chunk = json.loads(json_str) + if chunk["choices"][0]["delta"]: + delta = chunk["choices"][0]["delta"]["content"] + message_content += delta + yield delta + except json.JSONDecodeError: + pass + + conversation.add_message(AgentMessage(content=message_content)) async def astream( self, @@ -348,8 +330,8 @@ async def astream( AsyncIterator[str]: Streamed response content. 
""" formatted_messages = self._format_messages(conversation.history) - - request_data = { + + payload = { "model": self.name, "messages": formatted_messages, "temperature": temperature, @@ -358,61 +340,36 @@ async def astream( "tool_choice": tool_choice or "auto", } - async with httpx.AsyncClient() as client: - response = await client.post( - self._api_url, - headers=self._headers, - json=request_data - ) - response.raise_for_status() - tool_response = response.json() - logging.info(tool_response) + response = self._client.post(self._BASE_URL, json=payload) + response.raise_for_status() - if "content" in tool_response["choices"][0]["message"]: - agent_message = AgentMessage( - content=tool_response["choices"][0]["message"]["content"] - ) - conversation.add_message(agent_message) - - tool_calls = tool_response["choices"][0]["message"].get("tool_calls", []) - if tool_calls: - for tool_call in tool_calls: - func_name = tool_call["function"]["name"] - - func_call = toolkit.get_tool_by_name(func_name) - func_args = json.loads(tool_call["function"]["arguments"]) - func_result = func_call(**func_args) - - func_message = FunctionMessage( - content=json.dumps(func_result), - name=func_name, - tool_call_id=tool_call["id"], - ) - conversation.add_message(func_message) - - formatted_messages = self._format_messages(conversation.history) - request_data["messages"] = formatted_messages - request_data["stream"] = True - request_data.pop("tools", None) - request_data.pop("tool_choice", None) - - async with httpx.AsyncClient() as client: - response = await client.post(self._api_url, headers=self._headers, json=request_data) - - response.raise_for_status() - message_content = "" - - async for line in response.aiter_lines(): - json_str = line.replace('data: ', '') - try: - if json_str: - chunk = json.loads(json_str) - if chunk["choices"][0]["delta"]: - delta = chunk["choices"][0]["delta"]["content"] - message_content += delta - yield delta - except json.JSONDecodeError: - pass + 
tool_response = response.json() + + messages = [formatted_messages[-1], tool_response["choices"][0]["message"]] + tool_calls = tool_response["choices"][0]["message"].get("tool_calls", []) + + messages = self._process_tool_calls(tool_calls, toolkit, messages) + + payload["messages"] = messages + payload["stream"] = True + payload.pop("tools", None) + payload.pop("tool_choice", None) + + response = self._client.post(self._BASE_URL, json=payload) + response.raise_for_status() + message_content = "" + + async for line in response.aiter_lines(): + json_str = line.replace("data: ", "") + try: + if json_str: + chunk = json.loads(json_str) + if chunk["choices"][0]["delta"]: + delta = chunk["choices"][0]["delta"]["content"] + message_content += delta + yield delta + except json.JSONDecodeError: + pass conversation.add_message(AgentMessage(content=message_content)) def batch( diff --git a/pkgs/swarmauri/swarmauri/llms/concrete/GroqVisionModel.py b/pkgs/swarmauri/swarmauri/llms/concrete/GroqVisionModel.py index 390029153..58892d7e5 100644 --- a/pkgs/swarmauri/swarmauri/llms/concrete/GroqVisionModel.py +++ b/pkgs/swarmauri/swarmauri/llms/concrete/GroqVisionModel.py @@ -2,7 +2,6 @@ import json from pydantic import PrivateAttr import httpx -import requests from swarmauri.conversations.concrete.Conversation import Conversation from typing import List, Optional, Dict, Literal, Any, AsyncGenerator, Generator @@ -36,7 +35,26 @@ class GroqVisionModel(LLMBase): ] name: str = "llama-3.2-11b-vision-preview" type: Literal["GroqVisionModel"] = "GroqVisionModel" - _api_url: str = PrivateAttr("https://api.groq.com/openai/v1/chat/completions") + _client: httpx.Client = PrivateAttr(default=None) + _async_client: httpx.AsyncClient = PrivateAttr(default=None) + _BASE_URL: str = PrivateAttr(default="https://api.groq.com/openai/v1/chat/completions") + + def __init__(self, **data): + """ + Initialize the GroqAIAudio class with the provided data. 
+ + Args: + **data: Arbitrary keyword arguments containing initialization data. + """ + super().__init__(**data) + self._client = httpx.Client( + headers={"Authorization": f"Bearer {self.api_key}"}, + base_url=self._BASE_URL, + ) + self._async_client = httpx.AsyncClient( + headers={"Authorization": f"Bearer {self.api_key}"}, + base_url=self._BASE_URL, + ) def _format_messages( self, @@ -79,24 +97,6 @@ def _prepare_usage_data(self, usage_data) -> UsageData: """ return UsageData.model_validate(usage_data) - def _make_request(self, data: Dict[str, Any]) -> Dict[str, Any]: - """ - Sends a synchronous HTTP POST request to the API and retrieves the response. - - Args: - data (dict): Payload data to be sent in the API request. - - Returns: - dict: Parsed JSON response from the API. - """ - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {self.api_key}", - } - response = requests.post(self._api_url, headers=headers, json=data) - response.raise_for_status() # Raise an error for HTTP issues - return response.json() - def predict( self, conversation: Conversation, @@ -121,7 +121,7 @@ def predict( Conversation: Updated conversation with the model's response. 
""" formatted_messages = self._format_messages(conversation.history) - data = { + payload = { "model": self.name, "messages": formatted_messages, "temperature": temperature, @@ -130,11 +130,16 @@ def predict( "stop": stop or [], } if enable_json: - data["response_format"] = "json_object" + payload["response_format"] = "json_object" + + response = self._client.post(self._BASE_URL, json=payload) + + response.raise_for_status() - result = self._make_request(data) - message_content = result["choices"][0]["message"]["content"] - usage_data = result.get("usage", {}) + response_data = response.json() + + message_content = response_data["choices"][0]["message"]["content"] + usage_data = response_data.get("usage", {}) usage = self._prepare_usage_data(usage_data) conversation.add_message(AgentMessage(content=message_content, usage=usage)) @@ -164,7 +169,7 @@ async def apredict( Conversation: Updated conversation with the model's response. """ formatted_messages = self._format_messages(conversation.history) - data = { + payload = { "model": self.name, "messages": formatted_messages, "temperature": temperature, @@ -173,12 +178,15 @@ async def apredict( "stop": stop or [], } if enable_json: - data["response_format"] = "json_object" + payload["response_format"] = "json_object" + + response = await self._async_client.post(self._BASE_URL, json=payload) + response.raise_for_status() + + response_data = response.json() - # Use asyncio's to_thread to call synchronous code in an async context - result = await asyncio.to_thread(self._make_request, data) - message_content = result["choices"][0]["message"]["content"] - usage_data = result.get("usage", {}) + message_content = response_data["choices"][0]["message"]["content"] + usage_data = response_data.get("usage", {}) usage = self._prepare_usage_data(usage_data) conversation.add_message(AgentMessage(content=message_content, usage=usage)) @@ -209,7 +217,7 @@ def stream( """ formatted_messages = self._format_messages(conversation.history) 
- data = { + payload = { "model": self.name, "messages": formatted_messages, "temperature": temperature, @@ -219,31 +227,25 @@ def stream( "stop": stop or [], } if enable_json: - data["response_format"] = "json_object" + payload["response_format"] = "json_object" + + response = self._client.post(self._BASE_URL, json=payload) + + response.raise_for_status() + message_content = "" + for line in response.iter_lines(): + json_str = line.replace('data: ', '') + try: + if json_str: + chunk = json.loads(json_str) + if chunk["choices"][0]["delta"]: + delta = chunk["choices"][0]["delta"]["content"] + message_content += delta + yield delta + except json.JSONDecodeError: + pass - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {self.api_key}", - } - - with requests.post( - self._api_url, headers=headers, json=data, stream=True - ) as response: - response.raise_for_status() - message_content = "" - for line in response.iter_lines(decode_unicode=True): - json_str = line.replace('data: ', '') - try: - if json_str: - chunk = json.loads(json_str) - if chunk["choices"][0]["delta"]: - delta = chunk["choices"][0]["delta"]["content"] - message_content += delta - yield delta - except json.JSONDecodeError: - pass - - conversation.add_message(AgentMessage(content=message_content)) + conversation.add_message(AgentMessage(content=message_content)) async def astream( self, @@ -270,7 +272,7 @@ async def astream( """ formatted_messages = self._format_messages(conversation.history) - data = { + payload = { "model": self.name, "messages": formatted_messages, "temperature": temperature, @@ -280,29 +282,23 @@ async def astream( "stop": stop or [], } if enable_json: - data["response_format"] = "json_object" - - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {self.api_key}", - } - - async with httpx.AsyncClient() as client: - response = await client.post(self._api_url, headers=headers, json=data, timeout=None) - - 
response.raise_for_status() - message_content = "" - async for line in response.aiter_lines(): - json_str = line.replace('data: ', '') - try: - if json_str: - chunk = json.loads(json_str) - if chunk["choices"][0]["delta"]: - delta = chunk["choices"][0]["delta"]["content"] - message_content += delta - yield delta - except json.JSONDecodeError: - pass + payload["response_format"] = "json_object" + + response = await self._async_client.post(self._BASE_URL, json=payload) + + response.raise_for_status() + message_content = "" + async for line in response.aiter_lines(): + json_str = line.replace('data: ', '') + try: + if json_str: + chunk = json.loads(json_str) + if chunk["choices"][0]["delta"]: + delta = chunk["choices"][0]["delta"]["content"] + message_content += delta + yield delta + except json.JSONDecodeError: + pass conversation.add_message(AgentMessage(content=message_content)) diff --git a/pkgs/swarmauri/tests/unit/llms/GroqVisionModel_unit_test.py b/pkgs/swarmauri/tests/unit/llms/GroqVisionModel_unit_test.py index fe3331b1d..67f9c4694 100644 --- a/pkgs/swarmauri/tests/unit/llms/GroqVisionModel_unit_test.py +++ b/pkgs/swarmauri/tests/unit/llms/GroqVisionModel_unit_test.py @@ -134,7 +134,7 @@ def test_batch(groq_model, model_name, input_data): @pytest.mark.parametrize("model_name", get_allowed_models()) @pytest.mark.asyncio(loop_scope="session") @pytest.mark.unit -async def test_apredict(groq_model, model_name): +async def test_apredict(groq_model, model_name, input_data): model = groq_model model.name = model_name conversation = Conversation() @@ -151,7 +151,7 @@ async def test_apredict(groq_model, model_name): @pytest.mark.parametrize("model_name", get_allowed_models()) @pytest.mark.asyncio(loop_scope="session") @pytest.mark.unit -async def test_astream(groq_model, model_name): +async def test_astream(groq_model, model_name, input_data): model = groq_model model.name = model_name conversation = Conversation()