Commit: Merge branch 'mono/dev' into comm_pyproject
Showing 2 changed files with 539 additions and 0 deletions.
pkgs/swarmauri/swarmauri/llms/concrete/HyperbolicVisionModel.py (381 additions, 0 deletions)
@@ -0,0 +1,381 @@
import json
from pydantic import PrivateAttr
import httpx
from typing import List, Optional, Dict, Literal, Any, AsyncGenerator, Generator
import asyncio

from swarmauri_core.typing import SubclassUnion
from swarmauri.conversations.concrete.Conversation import Conversation
from swarmauri.messages.base.MessageBase import MessageBase
from swarmauri.messages.concrete.AgentMessage import AgentMessage, UsageData
from swarmauri.llms.base.LLMBase import LLMBase
from swarmauri.utils.retry_decorator import retry_on_status_codes
from swarmauri.utils.file_path_to_base64 import file_path_to_base64


class HyperbolicVisionModel(LLMBase):
    """
    HyperbolicVisionModel class for interacting with the Hyperbolic vision-language models API.

    This class provides synchronous and asynchronous methods to send conversation data to the
    model, receive predictions, and stream responses.

    Attributes:
        api_key (str): API key for authenticating requests to the Hyperbolic API.
        allowed_models (List[str]): List of allowed model names that can be used.
        name (str): The default model name to use for predictions.
        type (Literal["HyperbolicVisionModel"]): The type identifier for this class.

    Link to allowed models: https://app.hyperbolic.xyz/models
    Link to API keys: https://app.hyperbolic.xyz/settings
    """

    api_key: str
    allowed_models: List[str] = [
        "Qwen/Qwen2-VL-72B-Instruct",
        "mistralai/Pixtral-12B-2409",
        "Qwen/Qwen2-VL-7B-Instruct",
    ]
    name: str = "Qwen/Qwen2-VL-72B-Instruct"
    type: Literal["HyperbolicVisionModel"] = "HyperbolicVisionModel"
    _headers: Dict[str, str] = PrivateAttr(default=None)
    _client: httpx.Client = PrivateAttr(default=None)
    _BASE_URL: str = PrivateAttr(
        default="https://api.hyperbolic.xyz/v1/chat/completions"
    )

    def __init__(self, **data):
        """
        Initialize the HyperbolicVisionModel class with the provided data.

        Args:
            **data: Arbitrary keyword arguments containing initialization data.
        """
        super().__init__(**data)
        self._headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }
        self._client = httpx.Client(
            headers=self._headers,
            base_url=self._BASE_URL,
        )

    def _format_messages(
        self,
        messages: List[SubclassUnion[MessageBase]],
    ) -> List[Dict[str, Any]]:
        """
        Formats conversation messages into the structure expected by the API.

        Args:
            messages (List[MessageBase]): List of message objects from the conversation history.

        Returns:
            List[Dict[str, Any]]: List of formatted message dictionaries.
        """
        formatted_messages = []
        for message in messages:
            formatted_message = message.model_dump(
                include=["content", "role", "name"], exclude_none=True
            )

            if isinstance(formatted_message["content"], list):
                formatted_content = []
                for item in formatted_message["content"]:
                    if item["type"] == "image_url" and "file_path" in item:
                        # Convert a local file path to an inline base64 data URL
                        base64_img = file_path_to_base64(item["file_path"])
                        formatted_content.append(
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": f"data:image/jpeg;base64,{base64_img}"
                                },
                            }
                        )
                    else:
                        formatted_content.append(item)
                formatted_message["content"] = formatted_content

            formatted_messages.append(formatted_message)
        return formatted_messages
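
    # Illustrative note (added comment, not part of the original commit): after formatting,
    # an image item ends up in an OpenAI-style content layout, roughly:
    #   {"role": "user", "content": [
    #       {"type": "text", "text": "Describe this image."},
    #       {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,..."}},
    #   ]}
    # The text-item shape is an assumption; only the image_url conversion above is
    # guaranteed by this file.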

    def _prepare_usage_data(self, usage_data) -> UsageData:
        """
        Prepares and validates usage data received from the API response.

        Args:
            usage_data (dict): Raw usage data from the API response.

        Returns:
            UsageData: Validated usage data instance.
        """
        return UsageData.model_validate(usage_data)

    @retry_on_status_codes((429, 529), max_retries=1)
    def predict(
        self,
        conversation: Conversation,
        temperature: float = 0.7,
        max_tokens: int = 2048,
        top_p: float = 0.9,
        stop: Optional[List[str]] = None,
    ) -> Conversation:
        """
        Generates a response from the model based on the given conversation.

        Args:
            conversation (Conversation): Conversation object with message history.
            temperature (float): Sampling temperature for response diversity.
            max_tokens (int): Maximum tokens for the model's response.
            top_p (float): Cumulative probability for nucleus sampling.
            stop (Optional[List[str]]): List of stop sequences for response termination.

        Returns:
            Conversation: Updated conversation with the model's response.
        """
        formatted_messages = self._format_messages(conversation.history)
        payload = {
            "model": self.name,
            "messages": formatted_messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "top_p": top_p,
            "stop": stop or [],
        }

        response = self._client.post(self._BASE_URL, json=payload)
        response.raise_for_status()

        response_data = response.json()

        message_content = response_data["choices"][0]["message"]["content"]
        usage_data = response_data.get("usage", {})

        usage = self._prepare_usage_data(usage_data)
        conversation.add_message(AgentMessage(content=message_content, usage=usage))
        return conversation
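
    # Hypothetical usage sketch (added comment, not part of the original commit),
    # assuming a valid Hyperbolic API key and that swarmauri's HumanMessage accepts
    # the list-style content handled by _format_messages:
    #   from swarmauri.messages.concrete.HumanMessage import HumanMessage
    #   model = HyperbolicVisionModel(api_key="...")
    #   conversation = Conversation()
    #   conversation.add_message(HumanMessage(content=[
    #       {"type": "text", "text": "What is in this picture?"},
    #       {"type": "image_url", "file_path": "photo.jpg"},
    #   ]))
    #   conversation = model.predict(conversation)
    #   print(conversation.history[-1].content)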

    @retry_on_status_codes((429, 529), max_retries=1)
    async def apredict(
        self,
        conversation: Conversation,
        temperature: float = 0.7,
        max_tokens: int = 2048,
        top_p: float = 0.9,
        stop: Optional[List[str]] = None,
    ) -> Conversation:
        """
        Async method to generate a response from the model based on the given conversation.

        Args:
            conversation (Conversation): Conversation object with message history.
            temperature (float): Sampling temperature for response diversity.
            max_tokens (int): Maximum tokens for the model's response.
            top_p (float): Cumulative probability for nucleus sampling.
            stop (Optional[List[str]]): List of stop sequences for response termination.

        Returns:
            Conversation: Updated conversation with the model's response.
        """
        formatted_messages = self._format_messages(conversation.history)
        payload = {
            "model": self.name,
            "messages": formatted_messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "top_p": top_p,
            "stop": stop or [],
        }

        async with httpx.AsyncClient() as async_client:
            response = await async_client.post(
                self._BASE_URL, json=payload, headers=self._headers
            )
            response.raise_for_status()

        response_data = response.json()

        message_content = response_data["choices"][0]["message"]["content"]
        usage_data = response_data.get("usage", {})

        usage = self._prepare_usage_data(usage_data)
        conversation.add_message(AgentMessage(content=message_content, usage=usage))
        return conversation

    @retry_on_status_codes((429, 529), max_retries=1)
    def stream(
        self,
        conversation: Conversation,
        temperature: float = 0.7,
        max_tokens: int = 2048,
        top_p: float = 0.9,
        stop: Optional[List[str]] = None,
    ) -> Generator[str, None, None]:
        """
        Streams response text from the model in real-time.

        Args:
            conversation (Conversation): Conversation object with message history.
            temperature (float): Sampling temperature for response diversity.
            max_tokens (int): Maximum tokens for the model's response.
            top_p (float): Cumulative probability for nucleus sampling.
            stop (Optional[List[str]]): List of stop sequences for response termination.

        Yields:
            str: Partial response content from the model.
        """
        formatted_messages = self._format_messages(conversation.history)
        payload = {
            "model": self.name,
            "messages": formatted_messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "top_p": top_p,
            "stream": True,
            "stop": stop or [],
        }

        response = self._client.post(self._BASE_URL, json=payload)
        response.raise_for_status()

        message_content = ""
        for line in response.iter_lines():
            # Server-sent event lines are prefixed with "data: "; strip it before decoding.
            json_str = line.replace("data: ", "")
            try:
                if json_str:
                    chunk = json.loads(json_str)
                    if chunk["choices"][0]["delta"]:
                        # A delta may omit "content" (e.g. a role-only chunk), so default to "".
                        delta = chunk["choices"][0]["delta"].get("content", "")
                        message_content += delta
                        yield delta
            except json.JSONDecodeError:
                # Non-JSON lines such as the final "[DONE]" sentinel are ignored.
                pass

        conversation.add_message(AgentMessage(content=message_content))
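
    # Hypothetical streaming usage (added comment, not part of the original commit):
    #   for token in model.stream(conversation):
    #       print(token, end="", flush=True)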

    @retry_on_status_codes((429, 529), max_retries=1)
    async def astream(
        self,
        conversation: Conversation,
        temperature: float = 0.7,
        max_tokens: int = 2048,
        top_p: float = 0.9,
        stop: Optional[List[str]] = None,
    ) -> AsyncGenerator[str, None]:
        """
        Async generator that streams response text from the model in real-time.

        Args:
            conversation (Conversation): Conversation object with message history.
            temperature (float): Sampling temperature for response diversity.
            max_tokens (int): Maximum tokens for the model's response.
            top_p (float): Cumulative probability for nucleus sampling.
            stop (Optional[List[str]]): List of stop sequences for response termination.

        Yields:
            str: Partial response content from the model.
        """
        formatted_messages = self._format_messages(conversation.history)
        payload = {
            "model": self.name,
            "messages": formatted_messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "top_p": top_p,
            "stream": True,
            "stop": stop or [],
        }

        async with httpx.AsyncClient() as async_client:
            response = await async_client.post(
                self._BASE_URL, json=payload, headers=self._headers
            )
            response.raise_for_status()

            message_content = ""
            async for line in response.aiter_lines():
                # Server-sent event lines are prefixed with "data: "; strip it before decoding.
                json_str = line.replace("data: ", "")
                try:
                    if json_str:
                        chunk = json.loads(json_str)
                        if chunk["choices"][0]["delta"]:
                            # A delta may omit "content" (e.g. a role-only chunk), so default to "".
                            delta = chunk["choices"][0]["delta"].get("content", "")
                            message_content += delta
                            yield delta
                except json.JSONDecodeError:
                    # Non-JSON lines such as the final "[DONE]" sentinel are ignored.
                    pass

        conversation.add_message(AgentMessage(content=message_content))

    def batch(
        self,
        conversations: List[Conversation],
        temperature: float = 0.7,
        max_tokens: int = 2048,
        top_p: float = 0.9,
        stop: Optional[List[str]] = None,
    ) -> List[Conversation]:
        """
        Processes a batch of conversations and generates responses for each sequentially.

        Args:
            conversations (List[Conversation]): List of conversations to process.
            temperature (float): Sampling temperature for response diversity.
            max_tokens (int): Maximum tokens for each response.
            top_p (float): Cumulative probability for nucleus sampling.
            stop (Optional[List[str]]): List of stop sequences for response termination.

        Returns:
            List[Conversation]: List of updated conversations with model responses.
        """
        results = []
        for conversation in conversations:
            result_conversation = self.predict(
                conversation,
                temperature=temperature,
                max_tokens=max_tokens,
                top_p=top_p,
                stop=stop,
            )
            results.append(result_conversation)
        return results
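
    # Hypothetical batch usage (added comment, not part of the original commit):
    #   updated = model.batch([conversation_a, conversation_b], max_tokens=512)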

    async def abatch(
        self,
        conversations: List[Conversation],
        temperature: float = 0.7,
        max_tokens: int = 2048,
        top_p: float = 0.9,
        stop: Optional[List[str]] = None,
        max_concurrent: int = 5,
    ) -> List[Conversation]:
        """
        Async method for processing a batch of conversations concurrently.

        Args:
            conversations (List[Conversation]): List of conversations to process.
            temperature (float): Sampling temperature for response diversity.
            max_tokens (int): Maximum tokens for each response.
            top_p (float): Cumulative probability for nucleus sampling.
            stop (Optional[List[str]]): List of stop sequences for response termination.
            max_concurrent (int): Maximum number of concurrent requests.

        Returns:
            List[Conversation]: List of updated conversations with model responses.
        """
        semaphore = asyncio.Semaphore(max_concurrent)

        async def process_conversation(conv: Conversation) -> Conversation:
            async with semaphore:
                return await self.apredict(
                    conv,
                    temperature=temperature,
                    max_tokens=max_tokens,
                    top_p=top_p,
                    stop=stop,
                )

        tasks = [process_conversation(conv) for conv in conversations]
        return await asyncio.gather(*tasks)
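
A minimal end-to-end sketch of how this class might be driven (not part of the committed file; it assumes a valid Hyperbolic API key, local image files, and that swarmauri exposes a HumanMessage class alongside AgentMessage):

import asyncio

from swarmauri.conversations.concrete.Conversation import Conversation
# HumanMessage import path is assumed to mirror AgentMessage's; adjust if it differs.
from swarmauri.messages.concrete.HumanMessage import HumanMessage
from swarmauri.llms.concrete.HyperbolicVisionModel import HyperbolicVisionModel


async def main():
    model = HyperbolicVisionModel(api_key="YOUR_HYPERBOLIC_API_KEY")

    # Build one conversation per image; file_path items are converted to base64 data URLs.
    conversations = []
    for path in ["cat.jpg", "dog.jpg"]:
        conv = Conversation()
        conv.add_message(
            HumanMessage(
                content=[
                    {"type": "text", "text": "Describe this image in one sentence."},
                    {"type": "image_url", "file_path": path},
                ]
            )
        )
        conversations.append(conv)

    # Process both conversations concurrently, with at most two requests in flight.
    results = await model.abatch(conversations, max_concurrent=2)
    for conv in results:
        print(conv.history[-1].content)


if __name__ == "__main__":
    asyncio.run(main())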