Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

swarm - Refactor GroqModel to use httpx for API requests #782

Merged
merged 1 commit into from
Nov 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 77 additions & 79 deletions pkgs/swarmauri/swarmauri/llms/concrete/GroqModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import json
from pydantic import PrivateAttr
import httpx
import requests
from swarmauri.conversations.concrete.Conversation import Conversation
from typing import List, Optional, Dict, Literal, Any, AsyncGenerator, Generator

Expand Down Expand Up @@ -50,7 +49,26 @@ class GroqModel(LLMBase):
]
name: str = "gemma-7b-it"
type: Literal["GroqModel"] = "GroqModel"
_api_url: str = PrivateAttr("https://api.groq.com/openai/v1/chat/completions")
_client: httpx.Client = PrivateAttr(default=None)
_async_client: httpx.AsyncClient = PrivateAttr(default=None)
_BASE_URL: str = PrivateAttr(default="https://api.groq.com/openai/v1/chat/completions")

def __init__(self, **data):
    """
    Initialize the GroqModel instance and create its HTTP clients.

    Args:
        **data: Arbitrary keyword arguments containing initialization data,
            forwarded unchanged to the parent initializer.
    """
    super().__init__(**data)
    # The sync and async clients take the same bearer-token header and base
    # URL; build the settings once so the two cannot drift apart.
    client_options = {
        "headers": {"Authorization": f"Bearer {self.api_key}"},
        "base_url": self._BASE_URL,
    }
    self._client = httpx.Client(**client_options)
    self._async_client = httpx.AsyncClient(**client_options)

def _format_messages(
self,
Expand Down Expand Up @@ -93,24 +111,6 @@ def _prepare_usage_data(self, usage_data) -> UsageData:
"""
return UsageData.model_validate(usage_data)

def _make_request(self, data: Dict[str, Any]) -> Dict[str, Any]:
    """
    Post a payload to the configured API URL and return the decoded body.

    Args:
        data (dict): Payload sent as the JSON body of the request.

    Returns:
        dict: The response body decoded from JSON.

    Raises:
        requests.HTTPError: If the API answers with an error status code.
    """
    request_headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {self.api_key}",
    }
    api_response = requests.post(
        self._api_url,
        headers=request_headers,
        json=data,
    )
    # Surface HTTP-level failures before attempting to decode the body.
    api_response.raise_for_status()
    return api_response.json()

def predict(
self,
conversation: Conversation,
Expand All @@ -135,7 +135,7 @@ def predict(
Conversation: Updated conversation with the model's response.
"""
formatted_messages = self._format_messages(conversation.history)
data = {
payload = {
"model": self.name,
"messages": formatted_messages,
"temperature": temperature,
Expand All @@ -144,11 +144,16 @@ def predict(
"stop": stop or [],
}
if enable_json:
data["response_format"] = "json_object"
payload["response_format"] = "json_object"

result = self._make_request(data)
message_content = result["choices"][0]["message"]["content"]
usage_data = result.get("usage", {})
response = self._client.post(self._BASE_URL, json=payload)

response.raise_for_status()

response_data = response.json()

message_content = response_data["choices"][0]["message"]["content"]
usage_data = response_data.get("usage", {})

usage = self._prepare_usage_data(usage_data)
conversation.add_message(AgentMessage(content=message_content, usage=usage))
Expand Down Expand Up @@ -178,7 +183,7 @@ async def apredict(
Conversation: Updated conversation with the model's response.
"""
formatted_messages = self._format_messages(conversation.history)
data = {
payload = {
"model": self.name,
"messages": formatted_messages,
"temperature": temperature,
Expand All @@ -187,12 +192,15 @@ async def apredict(
"stop": stop or [],
}
if enable_json:
data["response_format"] = "json_object"
payload["response_format"] = "json_object"

# Use asyncio's to_thread to call synchronous code in an async context
result = await asyncio.to_thread(self._make_request, data)
message_content = result["choices"][0]["message"]["content"]
usage_data = result.get("usage", {})
response = await self._async_client.post(self._BASE_URL, json=payload)
response.raise_for_status()

response_data = response.json()

message_content = response_data["choices"][0]["message"]["content"]
usage_data = response_data.get("usage", {})

usage = self._prepare_usage_data(usage_data)
conversation.add_message(AgentMessage(content=message_content, usage=usage))
Expand Down Expand Up @@ -223,7 +231,7 @@ def stream(
"""

formatted_messages = self._format_messages(conversation.history)
data = {
payload = {
"model": self.name,
"messages": formatted_messages,
"temperature": temperature,
Expand All @@ -233,31 +241,26 @@ def stream(
"stop": stop or [],
}
if enable_json:
data["response_format"] = "json_object"
payload["response_format"] = "json_object"

headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}",
}
response = self._client.post(self._BASE_URL, json=payload)

with requests.post(
self._api_url, headers=headers, json=data, stream=True
) as response:
response.raise_for_status()
message_content = ""
for line in response.iter_lines(decode_unicode=True):
json_str = line.replace('data: ', '')
try:
if json_str:
chunk = json.loads(json_str)
if chunk["choices"][0]["delta"]:
delta = chunk["choices"][0]["delta"]["content"]
message_content += delta
yield delta
except json.JSONDecodeError:
pass

conversation.add_message(AgentMessage(content=message_content))
response.raise_for_status()

message_content = ""
for line in response.iter_lines():
json_str = line.replace('data: ', '')
try:
if json_str:
chunk = json.loads(json_str)
if chunk["choices"][0]["delta"]:
delta = chunk["choices"][0]["delta"]["content"]
message_content += delta
yield delta
except json.JSONDecodeError:
pass

conversation.add_message(AgentMessage(content=message_content))

async def astream(
self,
Expand All @@ -284,7 +287,7 @@ async def astream(
"""

formatted_messages = self._format_messages(conversation.history)
data = {
payload = {
"model": self.name,
"messages": formatted_messages,
"temperature": temperature,
Expand All @@ -294,29 +297,24 @@ async def astream(
"stop": stop or [],
}
if enable_json:
data["response_format"] = "json_object"

headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}",
}

async with httpx.AsyncClient() as client:
response = await client.post(self._api_url, headers=headers, json=data, timeout=None)

response.raise_for_status()
message_content = ""
async for line in response.aiter_lines():
json_str = line.replace('data: ', '')
try:
if json_str:
chunk = json.loads(json_str)
if chunk["choices"][0]["delta"]:
delta = chunk["choices"][0]["delta"]["content"]
message_content += delta
yield delta
except json.JSONDecodeError:
pass
payload["response_format"] = "json_object"

response = await self._async_client.post(self._BASE_URL, json=payload)

response.raise_for_status()
message_content = ""

async for line in response.aiter_lines():
json_str = line.replace('data: ', '')
try:
if json_str:
chunk = json.loads(json_str)
if chunk["choices"][0]["delta"]:
delta = chunk["choices"][0]["delta"]["content"]
message_content += delta
yield delta
except json.JSONDecodeError:
pass

conversation.add_message(AgentMessage(content=message_content))

Expand Down
Loading