Merge branch 'mono/dev' into comm_pyproject

cobycloud authored Nov 20, 2024
2 parents 7b91934 + 547032c commit 1dd0d85
Showing 2 changed files with 539 additions and 0 deletions.
381 changes: 381 additions & 0 deletions pkgs/swarmauri/swarmauri/llms/concrete/HyperbolicVisionModel.py
@@ -0,0 +1,381 @@
import json
from pydantic import PrivateAttr
import httpx
from typing import List, Optional, Dict, Literal, Any, AsyncGenerator, Generator
import asyncio

from swarmauri_core.typing import SubclassUnion
from swarmauri.conversations.concrete.Conversation import Conversation
from swarmauri.messages.base.MessageBase import MessageBase
from swarmauri.messages.concrete.AgentMessage import AgentMessage
from swarmauri.llms.base.LLMBase import LLMBase
from swarmauri.messages.concrete.AgentMessage import UsageData
from swarmauri.utils.retry_decorator import retry_on_status_codes
from swarmauri.utils.file_path_to_base64 import file_path_to_base64


class HyperbolicVisionModel(LLMBase):
"""
HyperbolicVisionModel class for interacting with the Hyperbolic vision language models API. This class
provides synchronous and asynchronous methods to send conversation data to the
model, receive predictions, and stream responses.
Attributes:
api_key (str): API key for authenticating requests to the Hyperbolic API.
allowed_models (List[str]): List of allowed model names that can be used.
name (str): The default model name to use for predictions.
type (Literal["HyperbolicVisionModel"]): The type identifier for this class.
Allowed models: https://app.hyperbolic.xyz/models
API keys: https://app.hyperbolic.xyz/settings
"""

api_key: str
allowed_models: List[str] = [
"Qwen/Qwen2-VL-72B-Instruct",
"mistralai/Pixtral-12B-2409",
"Qwen/Qwen2-VL-7B-Instruct",
]
name: str = "Qwen/Qwen2-VL-72B-Instruct"
type: Literal["HyperbolicVisionModel"] = "HyperbolicVisionModel"
_headers: Dict[str, str] = PrivateAttr(default=None)
_client: httpx.Client = PrivateAttr(default=None)
_BASE_URL: str = PrivateAttr(
default="https://api.hyperbolic.xyz/v1/chat/completions"
)

def __init__(self, **data):
"""
Initialize the HyperbolicVisionModel class with the provided data.
Args:
**data: Arbitrary keyword arguments containing initialization data.
"""
super().__init__(**data)
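# Build the auth headers and a reusable synchronous httpx client for the Hyperbolic endpoint.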
self._headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}",
}
self._client = httpx.Client(
headers=self._headers,
base_url=self._BASE_URL,
)

def _format_messages(
self,
messages: List[SubclassUnion[MessageBase]],
) -> List[Dict[str, Any]]:
"""
Formats conversation messages into the structure expected by the API.
Args:
messages (List[MessageBase]): List of message objects from the conversation history.
Returns:
List[Dict[str, Any]]: List of formatted message dictionaries.
"""
formatted_messages = []
for message in messages:
formatted_message = message.model_dump(
include=["content", "role", "name"], exclude_none=True
)

if isinstance(formatted_message["content"], list):
formatted_content = []
for item in formatted_message["content"]:
if item["type"] == "image_url" and "file_path" in item:
# Convert file path to base64
base64_img = file_path_to_base64(item["file_path"])
formatted_content.append(
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_img}"
},
}
)
else:
formatted_content.append(item)
formatted_message["content"] = formatted_content

formatted_messages.append(formatted_message)
return formatted_messages

def _prepare_usage_data(self, usage_data) -> UsageData:
"""
Prepares and validates usage data received from the API response.
Args:
usage_data (dict): Raw usage data from the API response.
Returns:
UsageData: Validated usage data instance.
"""
return UsageData.model_validate(usage_data)

@retry_on_status_codes((429, 529), max_retries=1)
def predict(
self,
conversation: Conversation,
temperature: float = 0.7,
max_tokens: int = 2048,
top_p: float = 0.9,
stop: Optional[List[str]] = None,
) -> Conversation:
"""
Generates a response from the model based on the given conversation.
Args:
conversation (Conversation): Conversation object with message history.
temperature (float): Sampling temperature for response diversity.
max_tokens (int): Maximum tokens for the model's response.
top_p (float): Cumulative probability for nucleus sampling.
stop (Optional[List[str]]): List of stop sequences for response termination.
Returns:
Conversation: Updated conversation with the model's response.
"""
formatted_messages = self._format_messages(conversation.history)
payload = {
"model": self.name,
"messages": formatted_messages,
"temperature": temperature,
"max_tokens": max_tokens,
"top_p": top_p,
"stop": stop or [],
}

response = self._client.post(self._BASE_URL, json=payload)
response.raise_for_status()

response_data = response.json()

message_content = response_data["choices"][0]["message"]["content"]
usage_data = response_data.get("usage", {})

usage = self._prepare_usage_data(usage_data)
conversation.add_message(AgentMessage(content=message_content, usage=usage))
return conversation

@retry_on_status_codes((429, 529), max_retries=1)
async def apredict(
self,
conversation: Conversation,
temperature: float = 0.7,
max_tokens: int = 2048,
top_p: float = 0.9,
stop: Optional[List[str]] = None,
) -> Conversation:
"""
Async method to generate a response from the model based on the given conversation.
Args:
conversation (Conversation): Conversation object with message history.
temperature (float): Sampling temperature for response diversity.
max_tokens (int): Maximum tokens for the model's response.
top_p (float): Cumulative probability for nucleus sampling.
stop (Optional[List[str]]): List of stop sequences for response termination.
Returns:
Conversation: Updated conversation with the model's response.
"""
formatted_messages = self._format_messages(conversation.history)
payload = {
"model": self.name,
"messages": formatted_messages,
"temperature": temperature,
"max_tokens": max_tokens,
"top_p": top_p,
"stop": stop or [],
}

async with httpx.AsyncClient() as async_client:
response = await async_client.post(
self._BASE_URL, json=payload, headers=self._headers
)
response.raise_for_status()

response_data = response.json()

message_content = response_data["choices"][0]["message"]["content"]
usage_data = response_data.get("usage", {})

usage = self._prepare_usage_data(usage_data)
conversation.add_message(AgentMessage(content=message_content, usage=usage))
return conversation

@retry_on_status_codes((429, 529), max_retries=1)
def stream(
self,
conversation: Conversation,
temperature: float = 0.7,
max_tokens: int = 2048,
top_p: float = 0.9,
stop: Optional[List[str]] = None,
) -> Generator[str, None, None]:
"""
Streams response text from the model in real-time.
Args:
conversation (Conversation): Conversation object with message history.
temperature (float): Sampling temperature for response diversity.
max_tokens (int): Maximum tokens for the model's response.
top_p (float): Cumulative probability for nucleus sampling.
stop (Optional[List[str]]): List of stop sequences for response termination.
Yields:
str: Partial response content from the model.
"""
formatted_messages = self._format_messages(conversation.history)
payload = {
"model": self.name,
"messages": formatted_messages,
"temperature": temperature,
"max_tokens": max_tokens,
"top_p": top_p,
"stream": True,
"stop": stop or [],
}

response = self._client.post(self._BASE_URL, json=payload)
response.raise_for_status()

message_content = ""
# The API streams Server-Sent Events; each payload line is prefixed with "data: ".
for line in response.iter_lines():
json_str = line.replace("data: ", "")
try:
if json_str:
chunk = json.loads(json_str)
if chunk["choices"][0]["delta"]:
delta = chunk["choices"][0]["delta"].get("content", "")
message_content += delta
yield delta
except json.JSONDecodeError:
# Non-JSON lines (including the terminating "[DONE]" sentinel) are skipped.
pass

conversation.add_message(AgentMessage(content=message_content))

@retry_on_status_codes((429, 529), max_retries=1)
async def astream(
self,
conversation: Conversation,
temperature: float = 0.7,
max_tokens: int = 2048,
top_p: float = 0.9,
stop: Optional[List[str]] = None,
) -> AsyncGenerator[str, None]:
"""
Async generator that streams response text from the model in real-time.
Args:
conversation (Conversation): Conversation object with message history.
temperature (float): Sampling temperature for response diversity.
max_tokens (int): Maximum tokens for the model's response.
top_p (float): Cumulative probability for nucleus sampling.
stop (Optional[List[str]]): List of stop sequences for response termination.
Yields:
str: Partial response content from the model.
"""
formatted_messages = self._format_messages(conversation.history)
payload = {
"model": self.name,
"messages": formatted_messages,
"temperature": temperature,
"max_tokens": max_tokens,
"top_p": top_p,
"stream": True,
"stop": stop or [],
}

async with httpx.AsyncClient() as async_client:
response = await async_client.post(
self._BASE_URL, json=payload, headers=self._headers
)
response.raise_for_status()

message_content = ""
# The API streams Server-Sent Events; each payload line is prefixed with "data: ".
async for line in response.aiter_lines():
json_str = line.replace("data: ", "")
try:
if json_str:
chunk = json.loads(json_str)
if chunk["choices"][0]["delta"]:
delta = chunk["choices"][0]["delta"].get("content", "")
message_content += delta
yield delta
except json.JSONDecodeError:
# Non-JSON lines (including the terminating "[DONE]" sentinel) are skipped.
pass

conversation.add_message(AgentMessage(content=message_content))

def batch(
self,
conversations: List[Conversation],
temperature: float = 0.7,
max_tokens: int = 2048,
top_p: float = 0.9,
stop: Optional[List[str]] = None,
) -> List[Conversation]:
"""
Processes a batch of conversations and generates responses for each sequentially.
Args:
conversations (List[Conversation]): List of conversations to process.
temperature (float): Sampling temperature for response diversity.
max_tokens (int): Maximum tokens for each response.
top_p (float): Cumulative probability for nucleus sampling.
stop (Optional[List[str]]): List of stop sequences for response termination.
Returns:
List[Conversation]: List of updated conversations with model responses.
"""
results = []
for conversation in conversations:
result_conversation = self.predict(
conversation,
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
stop=stop,
)
results.append(result_conversation)
return results

async def abatch(
self,
conversations: List[Conversation],
temperature: float = 0.7,
max_tokens: int = 2048,
top_p: float = 0.9,
stop: Optional[List[str]] = None,
max_concurrent=5,
) -> List[Conversation]:
"""
Async method for processing a batch of conversations concurrently.
Args:
conversations (List[Conversation]): List of conversations to process.
temperature (float): Sampling temperature for response diversity.
max_tokens (int): Maximum tokens for each response.
top_p (float): Cumulative probability for nucleus sampling.
stop (Optional[List[str]]): List of stop sequences for response termination.
max_concurrent (int): Maximum number of concurrent requests.
Returns:
List[Conversation]: List of updated conversations with model responses.
"""
semaphore = asyncio.Semaphore(max_concurrent)

async def process_conversation(conv: Conversation) -> Conversation:
async with semaphore:
return await self.apredict(
conv,
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
stop=stop,
)

tasks = [process_conversation(conv) for conv in conversations]
return await asyncio.gather(*tasks)
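A minimal usage sketch (not part of the commit): it assumes a HumanMessage class at swarmauri.messages.concrete.HumanMessage whose content field accepts the list format handled by _format_messages above; the import path, content schema, and file name are illustrative assumptions.

# Hypothetical usage sketch: the HumanMessage import path and content schema
# are assumptions based on the imports and _format_messages logic above.
from swarmauri.conversations.concrete.Conversation import Conversation
from swarmauri.messages.concrete.HumanMessage import HumanMessage  # assumed path
from swarmauri.llms.concrete.HyperbolicVisionModel import HyperbolicVisionModel

model = HyperbolicVisionModel(api_key="YOUR_HYPERBOLIC_API_KEY")

conversation = Conversation()
conversation.add_message(
    HumanMessage(
        content=[
            {"type": "text", "text": "Describe this image."},
            # A local file path is converted to a base64 data URL by _format_messages.
            {"type": "image_url", "file_path": "example.jpg"},
        ]
    )
)

conversation = model.predict(conversation, temperature=0.7, max_tokens=512)
print(conversation.history[-1].content)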