From ab4aca2c475543e96a1aafe589ee06e3c72605c7 Mon Sep 17 00:00:00 2001
From: 3rdSon
Date: Wed, 20 Nov 2024 15:08:23 +0100
Subject: [PATCH 1/2] swarm - implemented hyperbolicvision

---
 .../llms/concrete/HyperbolicVisionModel.py    | 378 ++++++++++++++++++
 .../llms/HyperbolicVisionModel_unit_test.py   | 158 ++++++++
 2 files changed, 536 insertions(+)
 create mode 100644 pkgs/swarmauri/swarmauri/llms/concrete/HyperbolicVisionModel.py
 create mode 100644 pkgs/swarmauri/tests/unit/llms/HyperbolicVisionModel_unit_test.py

diff --git a/pkgs/swarmauri/swarmauri/llms/concrete/HyperbolicVisionModel.py b/pkgs/swarmauri/swarmauri/llms/concrete/HyperbolicVisionModel.py
new file mode 100644
index 000000000..0dab82ce3
--- /dev/null
+++ b/pkgs/swarmauri/swarmauri/llms/concrete/HyperbolicVisionModel.py
@@ -0,0 +1,378 @@
+import json
+from pydantic import PrivateAttr
+import httpx
+from typing import List, Optional, Dict, Literal, Any, AsyncGenerator, Generator
+import asyncio
+
+from swarmauri_core.typing import SubclassUnion
+from swarmauri.conversations.concrete.Conversation import Conversation
+from swarmauri.messages.base.MessageBase import MessageBase
+from swarmauri.messages.concrete.AgentMessage import AgentMessage
+from swarmauri.llms.base.LLMBase import LLMBase
+from swarmauri.messages.concrete.AgentMessage import UsageData
+from swarmauri.utils.retry_decorator import retry_on_status_codes
+from swarmauri.utils.file_path_to_base64 import file_path_to_base64
+
+
+class HyperbolicVisionModel(LLMBase):
+    """
+    HyperbolicVisionModel class for interacting with the Hyperbolic vision language models API.
+    This class provides synchronous and asynchronous methods to send conversation data to the
+    model, receive predictions, and stream responses.
+
+    Attributes:
+        api_key (str): API key for authenticating requests to the Hyperbolic API.
+        allowed_models (List[str]): List of allowed model names that can be used.
+        name (str): The default model name to use for predictions.
+        type (Literal["HyperbolicVisionModel"]): The type identifier for this class.
+    """
+
+    api_key: str
+    allowed_models: List[str] = [
+        "Qwen/Qwen2-VL-72B-Instruct",
+        "mistralai/Pixtral-12B-2409",
+        "Qwen/Qwen2-VL-7B-Instruct",
+    ]
+    name: str = "Qwen/Qwen2-VL-72B-Instruct"
+    type: Literal["HyperbolicVisionModel"] = "HyperbolicVisionModel"
+    _headers: Dict[str, str] = PrivateAttr(default=None)
+    _client: httpx.Client = PrivateAttr(default=None)
+    _BASE_URL: str = PrivateAttr(
+        default="https://api.hyperbolic.xyz/v1/chat/completions"
+    )
+
+    def __init__(self, **data):
+        """
+        Initialize the HyperbolicVisionModel class with the provided data.
+
+        Args:
+            **data: Arbitrary keyword arguments containing initialization data.
+        """
+        super().__init__(**data)
+        self._headers = {
+            "Content-Type": "application/json",
+            "Authorization": f"Bearer {self.api_key}",
+        }
+        self._client = httpx.Client(
+            headers=self._headers,
+            base_url=self._BASE_URL,
+        )
+
+    def _format_messages(
+        self,
+        messages: List[SubclassUnion[MessageBase]],
+    ) -> List[Dict[str, Any]]:
+        """
+        Formats conversation messages into the structure expected by the API.
+
+        Args:
+            messages (List[MessageBase]): List of message objects from the conversation history.
+
+        Returns:
+            List[Dict[str, Any]]: List of formatted message dictionaries.
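+
+        Example (illustrative):
+            A content item of the form ``{"type": "image_url", "file_path": "img.jpg"}``
+            is rewritten to ``{"type": "image_url", "image_url": {"url":
+            "data:image/jpeg;base64,..."}}``; all other items pass through unchanged.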
+ """ + formatted_messages = [] + for message in messages: + formatted_message = message.model_dump( + include=["content", "role", "name"], exclude_none=True + ) + + if isinstance(formatted_message["content"], list): + formatted_content = [] + for item in formatted_message["content"]: + if item["type"] == "image_url" and "file_path" in item: + # Convert file path to base64 + base64_img = file_path_to_base64(item["file_path"]) + formatted_content.append( + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{base64_img}" + }, + } + ) + else: + formatted_content.append(item) + formatted_message["content"] = formatted_content + + formatted_messages.append(formatted_message) + return formatted_messages + + def _prepare_usage_data(self, usage_data) -> UsageData: + """ + Prepares and validates usage data received from the API response. + + Args: + usage_data (dict): Raw usage data from the API response. + + Returns: + UsageData: Validated usage data instance. + """ + return UsageData.model_validate(usage_data) + + @retry_on_status_codes((429, 529), max_retries=1) + def predict( + self, + conversation: Conversation, + temperature: float = 0.7, + max_tokens: int = 2048, + top_p: float = 0.9, + stop: Optional[List[str]] = None, + ) -> Conversation: + """ + Generates a response from the model based on the given conversation. + + Args: + conversation (Conversation): Conversation object with message history. + temperature (float): Sampling temperature for response diversity. + max_tokens (int): Maximum tokens for the model's response. + top_p (float): Cumulative probability for nucleus sampling. + stop (Optional[List[str]]): List of stop sequences for response termination. + + Returns: + Conversation: Updated conversation with the model's response. + """ + formatted_messages = self._format_messages(conversation.history) + payload = { + "model": self.name, + "messages": formatted_messages, + "temperature": temperature, + "max_tokens": max_tokens, + "top_p": top_p, + "stop": stop or [], + } + + response = self._client.post(self._BASE_URL, json=payload) + response.raise_for_status() + + response_data = response.json() + + message_content = response_data["choices"][0]["message"]["content"] + usage_data = response_data.get("usage", {}) + + usage = self._prepare_usage_data(usage_data) + conversation.add_message(AgentMessage(content=message_content, usage=usage)) + return conversation + + @retry_on_status_codes((429, 529), max_retries=1) + async def apredict( + self, + conversation: Conversation, + temperature: float = 0.7, + max_tokens: int = 2048, + top_p: float = 0.9, + stop: Optional[List[str]] = None, + ) -> Conversation: + """ + Async method to generate a response from the model based on the given conversation. + + Args: + conversation (Conversation): Conversation object with message history. + temperature (float): Sampling temperature for response diversity. + max_tokens (int): Maximum tokens for the model's response. + top_p (float): Cumulative probability for nucleus sampling. + stop (Optional[List[str]]): List of stop sequences for response termination. + + Returns: + Conversation: Updated conversation with the model's response. 
+ """ + formatted_messages = self._format_messages(conversation.history) + payload = { + "model": self.name, + "messages": formatted_messages, + "temperature": temperature, + "max_tokens": max_tokens, + "top_p": top_p, + "stop": stop or [], + } + + async with httpx.AsyncClient() as async_client: + response = await async_client.post( + self._BASE_URL, json=payload, headers=self._headers + ) + response.raise_for_status() + + response_data = response.json() + + message_content = response_data["choices"][0]["message"]["content"] + usage_data = response_data.get("usage", {}) + + usage = self._prepare_usage_data(usage_data) + conversation.add_message(AgentMessage(content=message_content, usage=usage)) + return conversation + + @retry_on_status_codes((429, 529), max_retries=1) + def stream( + self, + conversation: Conversation, + temperature: float = 0.7, + max_tokens: int = 2048, + top_p: float = 0.9, + stop: Optional[List[str]] = None, + ) -> Generator[str, None, None]: + """ + Streams response text from the model in real-time. + + Args: + conversation (Conversation): Conversation object with message history. + temperature (float): Sampling temperature for response diversity. + max_tokens (int): Maximum tokens for the model's response. + top_p (float): Cumulative probability for nucleus sampling. + stop (Optional[List[str]]): List of stop sequences for response termination. + + Yields: + str: Partial response content from the model. + """ + formatted_messages = self._format_messages(conversation.history) + payload = { + "model": self.name, + "messages": formatted_messages, + "temperature": temperature, + "max_tokens": max_tokens, + "top_p": top_p, + "stream": True, + "stop": stop or [], + } + + response = self._client.post(self._BASE_URL, json=payload) + response.raise_for_status() + + message_content = "" + for line in response.iter_lines(): + json_str = line.replace("data: ", "") + try: + if json_str: + chunk = json.loads(json_str) + if chunk["choices"][0]["delta"]: + delta = chunk["choices"][0]["delta"]["content"] + message_content += delta + yield delta + except json.JSONDecodeError: + pass + + conversation.add_message(AgentMessage(content=message_content)) + + @retry_on_status_codes((429, 529), max_retries=1) + async def astream( + self, + conversation: Conversation, + temperature: float = 0.7, + max_tokens: int = 2048, + top_p: float = 0.9, + stop: Optional[List[str]] = None, + ) -> AsyncGenerator[str, None]: + """ + Async generator that streams response text from the model in real-time. + + Args: + conversation (Conversation): Conversation object with message history. + temperature (float): Sampling temperature for response diversity. + max_tokens (int): Maximum tokens for the model's response. + top_p (float): Cumulative probability for nucleus sampling. + stop (Optional[List[str]]): List of stop sequences for response termination. + + Yields: + str: Partial response content from the model. 
+ """ + formatted_messages = self._format_messages(conversation.history) + payload = { + "model": self.name, + "messages": formatted_messages, + "temperature": temperature, + "max_tokens": max_tokens, + "top_p": top_p, + "stream": True, + "stop": stop or [], + } + + async with httpx.AsyncClient as async_client: + response = await async_client.post( + self._BASE_URL, json=payload, headers=self._headers + ) + response.raise_for_status() + + message_content = "" + async for line in response.aiter_lines(): + json_str = line.replace("data: ", "") + try: + if json_str: + chunk = json.loads(json_str) + if chunk["choices"][0]["delta"]: + delta = chunk["choices"][0]["delta"]["content"] + message_content += delta + yield delta + except json.JSONDecodeError: + pass + + conversation.add_message(AgentMessage(content=message_content)) + + def batch( + self, + conversations: List[Conversation], + temperature: float = 0.7, + max_tokens: int = 2048, + top_p: float = 0.9, + stop: Optional[List[str]] = None, + ) -> List[Conversation]: + """ + Processes a batch of conversations and generates responses for each sequentially. + + Args: + conversations (List[Conversation]): List of conversations to process. + temperature (float): Sampling temperature for response diversity. + max_tokens (int): Maximum tokens for each response. + top_p (float): Cumulative probability for nucleus sampling. + stop (Optional[List[str]]): List of stop sequences for response termination. + + Returns: + List[Conversation]: List of updated conversations with model responses. + """ + results = [] + for conversation in conversations: + result_conversation = self.predict( + conversation, + temperature=temperature, + max_tokens=max_tokens, + top_p=top_p, + stop=stop, + ) + results.append(result_conversation) + return results + + async def abatch( + self, + conversations: List[Conversation], + temperature: float = 0.7, + max_tokens: int = 2048, + top_p: float = 0.9, + stop: Optional[List[str]] = None, + max_concurrent=5, + ) -> List[Conversation]: + """ + Async method for processing a batch of conversations concurrently. + + Args: + conversations (List[Conversation]): List of conversations to process. + temperature (float): Sampling temperature for response diversity. + max_tokens (int): Maximum tokens for each response. + top_p (float): Cumulative probability for nucleus sampling. + stop (Optional[List[str]]): List of stop sequences for response termination. + max_concurrent (int): Maximum number of concurrent requests. + + Returns: + List[Conversation]: List of updated conversations with model responses. 
+ """ + semaphore = asyncio.Semaphore(max_concurrent) + + async def process_conversation(conv: Conversation) -> Conversation: + async with semaphore: + return await self.apredict( + conv, + temperature=temperature, + max_tokens=max_tokens, + top_p=top_p, + stop=stop, + ) + + tasks = [process_conversation(conv) for conv in conversations] + return await asyncio.gather(*tasks) diff --git a/pkgs/swarmauri/tests/unit/llms/HyperbolicVisionModel_unit_test.py b/pkgs/swarmauri/tests/unit/llms/HyperbolicVisionModel_unit_test.py new file mode 100644 index 000000000..495341aae --- /dev/null +++ b/pkgs/swarmauri/tests/unit/llms/HyperbolicVisionModel_unit_test.py @@ -0,0 +1,158 @@ +import pytest +import os +from swarmauri.llms.concrete.HyperbolicVisionModel import HyperbolicVisionModel +from swarmauri.conversations.concrete.Conversation import Conversation +from swarmauri.messages.concrete.HumanMessage import HumanMessage +from dotenv import load_dotenv +from swarmauri.utils.timeout_wrapper import timeout + +load_dotenv() + +API_KEY = os.getenv("HYPERBOLIC_API_KEY") + + +@pytest.fixture(scope="module") +def hyperbolic_vision_model(): + if not API_KEY: + pytest.skip("Skipping due to environment variable not set") + model = HyperbolicVisionModel(api_key=API_KEY) + return model + + +def get_allowed_models(): + if not API_KEY: + return [] + model = HyperbolicVisionModel(api_key=API_KEY) + return model.allowed_models + + +@timeout(5) +@pytest.mark.unit +def test_ubc_resource(hyperbolic_vision_model): + assert hyperbolic_vision_model.resource == "LLM" + + +@timeout(5) +@pytest.mark.unit +def test_ubc_type(hyperbolic_vision_model): + assert hyperbolic_vision_model.type == "HyperbolicVisionModel" + + +@timeout(5) +@pytest.mark.unit +def test_serialization(hyperbolic_vision_model): + assert ( + hyperbolic_vision_model.id + == HyperbolicVisionModel.model_validate_json( + hyperbolic_vision_model.model_dump_json() + ).id + ) + + +@timeout(5) +@pytest.mark.unit +def test_default_model_name(hyperbolic_vision_model): + assert hyperbolic_vision_model.name == "Qwen/Qwen2-VL-72B-Instruct" + + +def create_test_conversation(image_url, prompt): + conversation = Conversation() + conversation.add_message( + HumanMessage( + content=[ + {"type": "text", "text": prompt}, + {"type": "image_url", "image_url": {"url": image_url}}, + ] + ) + ) + return conversation + + +@pytest.mark.parametrize("model_name", get_allowed_models()) +@timeout(5) +@pytest.mark.unit +def test_predict(hyperbolic_vision_model, model_name): + model = hyperbolic_vision_model + model.name = model_name + + image_url = "https://llava-vl.github.io/static/images/monalisa.jpg" + prompt = "Who painted this artwork?" + conversation = create_test_conversation(image_url, prompt) + + result = model.predict(conversation) + + assert result.history[-1].content is not None + assert isinstance(result.history[-1].content, str) + assert len(result.history[-1].content) > 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", get_allowed_models()) +@timeout(5) +@pytest.mark.unit +async def test_apredict(hyperbolic_vision_model, model_name): + model = hyperbolic_vision_model + model.name = model_name + + image_url = "https://llava-vl.github.io/static/images/monalisa.jpg" + prompt = "Describe the woman in the painting." 
+    conversation = create_test_conversation(image_url, prompt)
+
+    result = await model.apredict(conversation)
+
+    assert result.history[-1].content is not None
+    assert isinstance(result.history[-1].content, str)
+    assert len(result.history[-1].content) > 0
+
+
+@timeout(5)
+@pytest.mark.unit
+def test_batch(hyperbolic_vision_model):
+    image_urls = [
+        "https://llava-vl.github.io/static/images/monalisa.jpg",
+        "https://llava-vl.github.io/static/images/monalisa.jpg",
+    ]
+    prompts = [
+        "Who painted this artwork?",
+        "Describe the woman in the painting.",
+    ]
+
+    conversations = [
+        create_test_conversation(image_url, prompt)
+        for image_url, prompt in zip(image_urls, prompts)
+    ]
+
+    results = hyperbolic_vision_model.batch(conversations)
+
+    assert len(results) == len(image_urls)
+    for result in results:
+        assert result.history[-1].content is not None
+        assert isinstance(result.history[-1].content, str)
+        assert len(result.history[-1].content) > 0
+
+
+@pytest.mark.asyncio
+@timeout(5)
+@pytest.mark.unit
+async def test_abatch(hyperbolic_vision_model):
+    image_urls = [
+        "https://llava-vl.github.io/static/images/monalisa.jpg",
+        "https://llava-vl.github.io/static/images/monalisa.jpg",
+    ]
+    prompts = [
+        "Who painted this artwork?",
+        "Describe the woman in the painting.",
+    ]
+
+    conversations = [
+        create_test_conversation(image_url, prompt)
+        for image_url, prompt in zip(image_urls, prompts)
+    ]
+
+    results = await hyperbolic_vision_model.abatch(conversations)
+
+    assert len(results) == len(image_urls)
+    for result in results:
+        assert result.history[-1].content is not None
+        assert isinstance(result.history[-1].content, str)
+        assert len(result.history[-1].content) > 0

From a17005800b6411640bba6f7bae1a3358c500b1ef Mon Sep 17 00:00:00 2001
From: 3rdSon
Date: Wed, 20 Nov 2024 15:11:45 +0100
Subject: [PATCH 2/2] swarm - hyperbolic

---
 .../swarmauri/swarmauri/llms/concrete/HyperbolicVisionModel.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/pkgs/swarmauri/swarmauri/llms/concrete/HyperbolicVisionModel.py b/pkgs/swarmauri/swarmauri/llms/concrete/HyperbolicVisionModel.py
index 0dab82ce3..14e2d196a 100644
--- a/pkgs/swarmauri/swarmauri/llms/concrete/HyperbolicVisionModel.py
+++ b/pkgs/swarmauri/swarmauri/llms/concrete/HyperbolicVisionModel.py
@@ -25,6 +25,9 @@ class HyperbolicVisionModel(LLMBase):
         allowed_models (List[str]): List of allowed model names that can be used.
         name (str): The default model name to use for predictions.
         type (Literal["HyperbolicVisionModel"]): The type identifier for this class.
+
+    Link to allowed models: https://app.hyperbolic.xyz/models
+    Link to API keys: https://app.hyperbolic.xyz/settings
     """
 
     api_key: str
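
Reviewer note: a minimal usage sketch for the new model, mirroring the unit tests
above. It assumes the swarmauri package from this PR is installed and that
HYPERBOLIC_API_KEY is set in the environment; the image URL and prompt are
illustrative, not part of the patch.

    import os

    from swarmauri.llms.concrete.HyperbolicVisionModel import HyperbolicVisionModel
    from swarmauri.conversations.concrete.Conversation import Conversation
    from swarmauri.messages.concrete.HumanMessage import HumanMessage

    # Authenticate against the Hyperbolic API
    model = HyperbolicVisionModel(api_key=os.environ["HYPERBOLIC_API_KEY"])

    # Build a conversation holding one multimodal (text + image) message
    conversation = Conversation()
    conversation.add_message(
        HumanMessage(
            content=[
                {"type": "text", "text": "Who painted this artwork?"},
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://llava-vl.github.io/static/images/monalisa.jpg"
                    },
                },
            ]
        )
    )

    # Synchronous prediction; the model's reply is appended to the history
    result = model.predict(conversation)
    print(result.history[-1].content)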