diff --git a/pkgs/swarmauri-partner-clients/llms/GroqModel/GroqModel.py b/pkgs/swarmauri-partner-clients/llms/GroqModel/GroqModel.py new file mode 100644 index 000000000..b837a0549 --- /dev/null +++ b/pkgs/swarmauri-partner-clients/llms/GroqModel/GroqModel.py @@ -0,0 +1,382 @@ +import asyncio +import json +from swarmauri.conversations.concrete.Conversation import Conversation +from typing import Generator, List, Optional, Dict, Literal, Any, Union, AsyncGenerator + +from groq import Groq, AsyncGroq +from swarmauri_core.typing import SubclassUnion + +from swarmauri.messages.base.MessageBase import MessageBase +from swarmauri.messages.concrete.AgentMessage import AgentMessage +from swarmauri.llms.base.LLMBase import LLMBase + +from swarmauri.messages.concrete.AgentMessage import UsageData + + +class GroqModel(LLMBase): + """ + GroqModel class for interacting with the Groq language models API. This class + provides synchronous and asynchronous methods to send conversation data to the + model, receive predictions, and stream responses. + + Attributes: + api_key (str): API key for authenticating requests to the Groq API. + allowed_models (List[str]): List of allowed model names that can be used. + name (str): The default model name to use for predictions. + type (Literal["GroqModel"]): The type identifier for this class. + + + Allowed Models resources: https://console.groq.com/docs/models + """ + + api_key: str + allowed_models: List[str] = [ + "gemma-7b-it", + "gemma2-9b-it", + "llama-3.1-70b-versatile", + "llama-3.1-8b-instant", + "llama-3.2-11b-text-preview", + "llama-3.2-1b-preview", + "llama-3.2-3b-preview", + "llama-3.2-90b-text-preview", + "llama-guard-3-8b", + "llama3-70b-8192", + "llama3-8b-8192", + "llama3-groq-70b-8192-tool-use-preview", + "llama3-groq-8b-8192-tool-use-preview", + "llava-v1.5-7b-4096-preview", + "mixtral-8x7b-32768", + ] + name: str = "gemma-7b-it" + type: Literal["GroqModel"] = "GroqModel" + + def _format_messages( + self, + messages: List[SubclassUnion[MessageBase]], + ) -> List[Dict[str, Any]]: + """ + Formats conversation messages into the structure expected by the API. + + Args: + messages (List[MessageBase]): List of message objects from the conversation history. + + Returns: + List[Dict[str, Any]]: List of formatted message dictionaries. + """ + formatted_messages = [] + for message in messages: + formatted_message = message.model_dump( + include=["content", "role", "name"], exclude_none=True + ) + + if isinstance(formatted_message["content"], list): + formatted_message["content"] = [ + {"type": item["type"], **item} + for item in formatted_message["content"] + ] + + formatted_messages.append(formatted_message) + return formatted_messages + + def _prepare_usage_data( + self, + usage_data, + ) -> UsageData: + """ + Prepares and validates usage data received from the API response. + + Args: + usage_data (dict): Raw usage data from the API response. + + Returns: + UsageData: Validated usage data instance. + """ + + usage = UsageData.model_validate(usage_data) + return usage + + def predict( + self, + conversation, + temperature: float = 0.7, + max_tokens: int = 256, + top_p: float = 1.0, + enable_json: bool = False, + stop: Optional[List[str]] = None, + ) -> Conversation: + """ + Generates a response from the model based on the given conversation. + + Args: + conversation (Conversation): Conversation object with message history. + temperature (float): Sampling temperature for response diversity. + max_tokens (int): Maximum tokens for the model's response. 
+ top_p (float): Cumulative probability for nucleus sampling. + enable_json (bool): Whether to format the response as JSON. + stop (Optional[List[str]]): List of stop sequences for response termination. + + Returns: + Conversation: Updated conversation with the model's response. + """ + + formatted_messages = self._format_messages(conversation.history) + response_format = {"type": "json_object"} if enable_json else None + + kwargs = { + "model": self.name, + "messages": formatted_messages, + "temperature": temperature, + "max_tokens": max_tokens, + "top_p": top_p, + "response_format": response_format, + "stop": stop or [], + } + + client = Groq(api_key=self.api_key) + response = client.chat.completions.create(**kwargs) + + result = json.loads(response.model_dump_json()) + message_content = result["choices"][0]["message"]["content"] + usage_data = result.get("usage", {}) + + usage = self._prepare_usage_data(usage_data) + + conversation.add_message(AgentMessage(content=message_content, usage=usage)) + return conversation + + async def apredict( + self, + conversation, + temperature: float = 0.7, + max_tokens: int = 256, + top_p: float = 1.0, + enable_json: bool = False, + stop: Optional[List[str]] = None, + ) -> Conversation: + """ + Async method to generate a response from the model based on the given conversation. + + Args: + conversation (Conversation): Conversation object with message history. + temperature (float): Sampling temperature for response diversity. + max_tokens (int): Maximum tokens for the model's response. + top_p (float): Cumulative probability for nucleus sampling. + enable_json (bool): Whether to format the response as JSON. + stop (Optional[List[str]]): List of stop sequences for response termination. + + Returns: + Conversation: Updated conversation with the model's response. + """ + + formatted_messages = self._format_messages(conversation.history) + response_format = {"type": "json_object"} if enable_json else None + + kwargs = { + "model": self.name, + "messages": formatted_messages, + "temperature": temperature, + "max_tokens": max_tokens, + "top_p": top_p, + "response_format": response_format, + "stop": stop or [], + } + + client = AsyncGroq(api_key=self.api_key) + response = await client.chat.completions.create(**kwargs) + + result = json.loads(response.model_dump_json()) + message_content = result["choices"][0]["message"]["content"] + usage_data = result.get("usage", {}) + + usage = self._prepare_usage_data(usage_data) + + conversation.add_message(AgentMessage(content=message_content, usage=usage)) + return conversation + + def stream( + self, + conversation, + temperature: float = 0.7, + max_tokens: int = 256, + top_p: float = 1.0, + enable_json: bool = False, + stop: Optional[List[str]] = None, + ) -> Union[str, Generator[str, str, None]]: + """ + Streams response text from the model in real-time. + + Args: + conversation (Conversation): Conversation object with message history. + temperature (float): Sampling temperature for response diversity. + max_tokens (int): Maximum tokens for the model's response. + top_p (float): Cumulative probability for nucleus sampling. + enable_json (bool): Whether to format the response as JSON. + stop (Optional[List[str]]): List of stop sequences for response termination. + + Yields: + str: Partial response content from the model. 
+ """ + + formatted_messages = self._format_messages(conversation.history) + response_format = {"type": "json_object"} if enable_json else None + + kwargs = { + "model": self.name, + "messages": formatted_messages, + "temperature": temperature, + "max_tokens": max_tokens, + "top_p": top_p, + "response_format": response_format, + "stream": True, + "stop": stop or [], + # "stream_options": {"include_usage": True}, + } + + client = Groq(api_key=self.api_key) + stream = client.chat.completions.create(**kwargs) + message_content = "" + # usage_data = {} + + for chunk in stream: + if chunk.choices and chunk.choices[0].delta.content: + message_content += chunk.choices[0].delta.content + yield chunk.choices[0].delta.content + + # if hasattr(chunk, "usage") and chunk.usage is not None: + # usage_data = chunk.usage + + # usage = self._prepare_usage_data(usage_data) + + conversation.add_message(AgentMessage(content=message_content)) + + async def astream( + self, + conversation, + temperature: float = 0.7, + max_tokens: int = 256, + top_p: float = 1.0, + enable_json: bool = False, + stop: Optional[List[str]] = None, + ) -> AsyncGenerator[str, None]: + """ + Async generator that streams response text from the model in real-time. + + Args: + conversation (Conversation): Conversation object with message history. + temperature (float): Sampling temperature for response diversity. + max_tokens (int): Maximum tokens for the model's response. + top_p (float): Cumulative probability for nucleus sampling. + enable_json (bool): Whether to format the response as JSON. + stop (Optional[List[str]]): List of stop sequences for response termination. + + Yields: + str: Partial response content from the model. + """ + + formatted_messages = self._format_messages(conversation.history) + response_format = {"type": "json_object"} if enable_json else None + + kwargs = { + "model": self.name, + "messages": formatted_messages, + "temperature": temperature, + "max_tokens": max_tokens, + "top_p": top_p, + "response_format": response_format, + "stop": stop or [], + "stream": True, + # "stream_options": {"include_usage": True}, + } + + client = AsyncGroq(api_key=self.api_key) + stream = await client.chat.completions.create(**kwargs) + + message_content = "" + # usage_data = {} + + async for chunk in stream: + if chunk.choices and chunk.choices[0].delta.content: + message_content += chunk.choices[0].delta.content + yield chunk.choices[0].delta.content + + # if hasattr(chunk, "usage") and chunk.usage is not None: + # usage_data = chunk.usage + + # usage = self._prepare_usage_data(usage_data) + conversation.add_message(AgentMessage(content=message_content)) + + def batch( + self, + conversations: List[Conversation], + temperature: float = 0.7, + max_tokens: int = 256, + top_p: float = 1.0, + enable_json: bool = False, + stop: Optional[List[str]] = None, + ) -> List[Conversation]: + """ + Processes a batch of conversations and generates responses for each sequentially. + + Args: + conversations (List[Conversation]): List of conversations to process. + temperature (float): Sampling temperature for response diversity. + max_tokens (int): Maximum tokens for each response. + top_p (float): Cumulative probability for nucleus sampling. + enable_json (bool): Whether to format the response as JSON. + stop (Optional[List[str]]): List of stop sequences for response termination. + + Returns: + List[Conversation]: List of updated conversations with model responses. 
+ """ + return [ + self.predict( + conv, + temperature=temperature, + max_tokens=max_tokens, + top_p=top_p, + enable_json=enable_json, + stop=stop, + ) + for conv in conversations + ] + + async def abatch( + self, + conversations: List[Conversation], + temperature: float = 0.7, + max_tokens: int = 256, + top_p: float = 1.0, + enable_json: bool = False, + stop: Optional[List[str]] = None, + max_concurrent=5, + ) -> List[Conversation]: + """ + Async method for processing a batch of conversations concurrently. + + Args: + conversations (List[Conversation]): List of conversations to process. + temperature (float): Sampling temperature for response diversity. + max_tokens (int): Maximum tokens for each response. + top_p (float): Cumulative probability for nucleus sampling. + enable_json (bool): Whether to format the response as JSON. + stop (Optional[List[str]]): List of stop sequences for response termination. + max_concurrent (int): Maximum number of concurrent requests. + + Returns: + List[Conversation]: List of updated conversations with model responses. + """ + semaphore = asyncio.Semaphore(max_concurrent) + + async def process_conversation(conv) -> str | AsyncGenerator[str, None]: + async with semaphore: + return await self.apredict( + conv, + temperature=temperature, + max_tokens=max_tokens, + top_p=top_p, + enable_json=enable_json, + stop=stop, + ) + + tasks = [process_conversation(conv) for conv in conversations] + return await asyncio.gather(*tasks) diff --git a/pkgs/swarmauri-partner-clients/llms/GroqModel/__init__.py b/pkgs/swarmauri-partner-clients/llms/GroqModel/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pkgs/swarmauri/pyproject.toml b/pkgs/swarmauri/pyproject.toml index 745f1ef6c..ff4c3300a 100644 --- a/pkgs/swarmauri/pyproject.toml +++ b/pkgs/swarmauri/pyproject.toml @@ -24,7 +24,6 @@ cohere = "^5.11.0" gensim = "==4.3.3" google-generativeai = "^0.8.3" gradio = "==5.1.0" -groq = "^0.11.0" joblib = "^1.4.0" mistralai = "^1.1.0" nltk = "^3.9.1" diff --git a/pkgs/swarmauri/swarmauri/llms/concrete/GroqModel.py b/pkgs/swarmauri/swarmauri/llms/concrete/GroqModel.py index e511e1d20..f2abbe2ac 100644 --- a/pkgs/swarmauri/swarmauri/llms/concrete/GroqModel.py +++ b/pkgs/swarmauri/swarmauri/llms/concrete/GroqModel.py @@ -1,11 +1,12 @@ import asyncio import json +from pydantic import PrivateAttr +import httpx +import requests from swarmauri.conversations.concrete.Conversation import Conversation -from typing import List, Optional, Dict, Literal, Any, Union, AsyncGenerator +from typing import List, Optional, Dict, Literal, Any, AsyncGenerator, Generator -from groq import Groq, AsyncGroq from swarmauri_core.typing import SubclassUnion - from swarmauri.messages.base.MessageBase import MessageBase from swarmauri.messages.concrete.AgentMessage import AgentMessage from swarmauri.llms.base.LLMBase import LLMBase @@ -14,7 +15,20 @@ class GroqModel(LLMBase): - """Provider resources: https://console.groq.com/docs/models""" + """ + GroqModel class for interacting with the Groq language models API. This class + provides synchronous and asynchronous methods to send conversation data to the + model, receive predictions, and stream responses. + + Attributes: + api_key (str): API key for authenticating requests to the Groq API. + allowed_models (List[str]): List of allowed model names that can be used. + name (str): The default model name to use for predictions. + type (Literal["GroqModel"]): The type identifier for this class. 
+ + + Allowed Models resources: https://console.groq.com/docs/models + """ api_key: str allowed_models: List[str] = [ @@ -33,16 +47,25 @@ class GroqModel(LLMBase): "llama3-groq-8b-8192-tool-use-preview", "llava-v1.5-7b-4096-preview", "mixtral-8x7b-32768", - # multimodal models - "llama-3.2-11b-vision-preview", ] name: str = "gemma-7b-it" type: Literal["GroqModel"] = "GroqModel" + _api_url: str = PrivateAttr("https://api.groq.com/openai/v1/chat/completions") def _format_messages( self, messages: List[SubclassUnion[MessageBase]], ) -> List[Dict[str, Any]]: + """ + Formats conversation messages into the structure expected by the API. + + Args: + messages (List[MessageBase]): List of message objects from the conversation history. + + Returns: + List[Dict[str, Any]]: List of formatted message dictionaries. + """ + formatted_messages = [] for message in messages: formatted_message = message.model_dump( @@ -58,169 +81,243 @@ def _format_messages( formatted_messages.append(formatted_message) return formatted_messages - def _prepare_usage_data( - self, - usage_data, - ): + def _prepare_usage_data(self, usage_data) -> UsageData: """ - Prepares and extracts usage data and response timing. + Prepares and validates usage data received from the API response. + + Args: + usage_data (dict): Raw usage data from the API response. + + Returns: + UsageData: Validated usage data instance. """ + return UsageData.model_validate(usage_data) - usage = UsageData.model_validate(usage_data) - return usage + def _make_request(self, data: Dict[str, Any]) -> Dict[str, Any]: + """ + Sends a synchronous HTTP POST request to the API and retrieves the response. + + Args: + data (dict): Payload data to be sent in the API request. + + Returns: + dict: Parsed JSON response from the API. + """ + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {self.api_key}", + } + response = requests.post(self._api_url, headers=headers, json=data) + response.raise_for_status() # Raise an error for HTTP issues + return response.json() def predict( self, - conversation, + conversation: Conversation, temperature: float = 0.7, max_tokens: int = 256, top_p: float = 1.0, enable_json: bool = False, stop: Optional[List[str]] = None, - ): - + ) -> Conversation: + """ + Generates a response from the model based on the given conversation. + + Args: + conversation (Conversation): Conversation object with message history. + temperature (float): Sampling temperature for response diversity. + max_tokens (int): Maximum tokens for the model's response. + top_p (float): Cumulative probability for nucleus sampling. + enable_json (bool): Whether to format the response as JSON. + stop (Optional[List[str]]): List of stop sequences for response termination. + + Returns: + Conversation: Updated conversation with the model's response. 
+ """ formatted_messages = self._format_messages(conversation.history) - response_format = {"type": "json_object"} if enable_json else None - - kwargs = { + data = { "model": self.name, "messages": formatted_messages, "temperature": temperature, "max_tokens": max_tokens, "top_p": top_p, - "response_format": response_format, "stop": stop or [], } + if enable_json: + data["response_format"] = "json_object" - client = Groq(api_key=self.api_key) - response = client.chat.completions.create(**kwargs) - - result = json.loads(response.model_dump_json()) + result = self._make_request(data) message_content = result["choices"][0]["message"]["content"] usage_data = result.get("usage", {}) usage = self._prepare_usage_data(usage_data) - conversation.add_message(AgentMessage(content=message_content, usage=usage)) return conversation async def apredict( self, - conversation, + conversation: Conversation, temperature: float = 0.7, max_tokens: int = 256, top_p: float = 1.0, enable_json: bool = False, stop: Optional[List[str]] = None, - ) -> Union[str, AsyncGenerator[str, None]]: - + ) -> Conversation: + """ + Async method to generate a response from the model based on the given conversation. + + Args: + conversation (Conversation): Conversation object with message history. + temperature (float): Sampling temperature for response diversity. + max_tokens (int): Maximum tokens for the model's response. + top_p (float): Cumulative probability for nucleus sampling. + enable_json (bool): Whether to format the response as JSON. + stop (Optional[List[str]]): List of stop sequences for response termination. + + Returns: + Conversation: Updated conversation with the model's response. + """ formatted_messages = self._format_messages(conversation.history) - response_format = {"type": "json_object"} if enable_json else None - - kwargs = { + data = { "model": self.name, "messages": formatted_messages, "temperature": temperature, "max_tokens": max_tokens, "top_p": top_p, - "response_format": response_format, "stop": stop or [], } + if enable_json: + data["response_format"] = "json_object" - client = AsyncGroq(api_key=self.api_key) - response = await client.chat.completions.create(**kwargs) - - result = json.loads(response.model_dump_json()) + # Use asyncio's to_thread to call synchronous code in an async context + result = await asyncio.to_thread(self._make_request, data) message_content = result["choices"][0]["message"]["content"] usage_data = result.get("usage", {}) usage = self._prepare_usage_data(usage_data) - conversation.add_message(AgentMessage(content=message_content, usage=usage)) return conversation def stream( self, - conversation, + conversation: Conversation, temperature: float = 0.7, max_tokens: int = 256, top_p: float = 1.0, enable_json: bool = False, stop: Optional[List[str]] = None, - ) -> Union[str, AsyncGenerator[str, None]]: + ) -> Generator[str, None, None]: + """ + Streams response text from the model in real-time. + + Args: + conversation (Conversation): Conversation object with message history. + temperature (float): Sampling temperature for response diversity. + max_tokens (int): Maximum tokens for the model's response. + top_p (float): Cumulative probability for nucleus sampling. + enable_json (bool): Whether to format the response as JSON. + stop (Optional[List[str]]): List of stop sequences for response termination. + + Yields: + str: Partial response content from the model. 
+ """ formatted_messages = self._format_messages(conversation.history) - response_format = {"type": "json_object"} if enable_json else None - - kwargs = { + data = { "model": self.name, "messages": formatted_messages, "temperature": temperature, "max_tokens": max_tokens, "top_p": top_p, - "response_format": response_format, "stream": True, "stop": stop or [], - # "stream_options": {"include_usage": True}, } + if enable_json: + data["response_format"] = "json_object" - client = Groq(api_key=self.api_key) - stream = client.chat.completions.create(**kwargs) - message_content = "" - # usage_data = {} - - for chunk in stream: - if chunk.choices and chunk.choices[0].delta.content: - message_content += chunk.choices[0].delta.content - yield chunk.choices[0].delta.content - - # if hasattr(chunk, "usage") and chunk.usage is not None: - # usage_data = chunk.usage - - # usage = self._prepare_usage_data(usage_data) + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {self.api_key}", + } - conversation.add_message(AgentMessage(content=message_content)) + with requests.post( + self._api_url, headers=headers, json=data, stream=True + ) as response: + response.raise_for_status() + message_content = "" + for line in response.iter_lines(decode_unicode=True): + json_str = line.replace('data: ', '') + try: + if json_str: + chunk = json.loads(json_str) + if chunk["choices"][0]["delta"]: + delta = chunk["choices"][0]["delta"]["content"] + message_content += delta + yield delta + except json.JSONDecodeError: + pass + + conversation.add_message(AgentMessage(content=message_content)) async def astream( self, - conversation, + conversation: Conversation, temperature: float = 0.7, max_tokens: int = 256, top_p: float = 1.0, enable_json: bool = False, stop: Optional[List[str]] = None, ) -> AsyncGenerator[str, None]: + """ + Async generator that streams response text from the model in real-time. + + Args: + conversation (Conversation): Conversation object with message history. + temperature (float): Sampling temperature for response diversity. + max_tokens (int): Maximum tokens for the model's response. + top_p (float): Cumulative probability for nucleus sampling. + enable_json (bool): Whether to format the response as JSON. + stop (Optional[List[str]]): List of stop sequences for response termination. + + Yields: + str: Partial response content from the model. 
+ """ formatted_messages = self._format_messages(conversation.history) - response_format = {"type": "json_object"} if enable_json else None - - kwargs = { + data = { "model": self.name, "messages": formatted_messages, "temperature": temperature, "max_tokens": max_tokens, "top_p": top_p, - "response_format": response_format, - "stop": stop or [], "stream": True, - # "stream_options": {"include_usage": True}, + "stop": stop or [], } + if enable_json: + data["response_format"] = "json_object" - client = AsyncGroq(api_key=self.api_key) - stream = await client.chat.completions.create(**kwargs) - - message_content = "" - # usage_data = {} - - async for chunk in stream: - if chunk.choices and chunk.choices[0].delta.content: - message_content += chunk.choices[0].delta.content - yield chunk.choices[0].delta.content + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {self.api_key}", + } - # if hasattr(chunk, "usage") and chunk.usage is not None: - # usage_data = chunk.usage + async with httpx.AsyncClient() as client: + response = await client.post(self._api_url, headers=headers, json=data, timeout=None) + + response.raise_for_status() + message_content = "" + async for line in response.aiter_lines(): + json_str = line.replace('data: ', '') + try: + if json_str: + chunk = json.loads(json_str) + if chunk["choices"][0]["delta"]: + delta = chunk["choices"][0]["delta"]["content"] + message_content += delta + yield delta + except json.JSONDecodeError: + pass - # usage = self._prepare_usage_data(usage_data) conversation.add_message(AgentMessage(content=message_content)) def batch( @@ -231,19 +328,33 @@ def batch( top_p: float = 1.0, enable_json: bool = False, stop: Optional[List[str]] = None, - ) -> List: - """Synchronously process multiple conversations""" - return [ - self.predict( - conv, + ) -> List[Conversation]: + """ + Processes a batch of conversations and generates responses for each sequentially. + + Args: + conversations (List[Conversation]): List of conversations to process. + temperature (float): Sampling temperature for response diversity. + max_tokens (int): Maximum tokens for each response. + top_p (float): Cumulative probability for nucleus sampling. + enable_json (bool): Whether to format the response as JSON. + stop (Optional[List[str]]): List of stop sequences for response termination. + + Returns: + List[Conversation]: List of updated conversations with model responses. + """ + results = [] + for conversation in conversations: + result_conversation = self.predict( + conversation, temperature=temperature, max_tokens=max_tokens, top_p=top_p, enable_json=enable_json, stop=stop, ) - for conv in conversations - ] + results.append(result_conversation) + return results async def abatch( self, @@ -254,11 +365,25 @@ async def abatch( enable_json: bool = False, stop: Optional[List[str]] = None, max_concurrent=5, - ) -> List: - """Process multiple conversations in parallel with controlled concurrency""" + ) -> List[Conversation]: + """ + Async method for processing a batch of conversations concurrently. + + Args: + conversations (List[Conversation]): List of conversations to process. + temperature (float): Sampling temperature for response diversity. + max_tokens (int): Maximum tokens for each response. + top_p (float): Cumulative probability for nucleus sampling. + enable_json (bool): Whether to format the response as JSON. + stop (Optional[List[str]]): List of stop sequences for response termination. + max_concurrent (int): Maximum number of concurrent requests. 
+ + Returns: + List[Conversation]: List of updated conversations with model responses. + """ semaphore = asyncio.Semaphore(max_concurrent) - async def process_conversation(conv): + async def process_conversation(conv: Conversation) -> Conversation: async with semaphore: return await self.apredict( conv, diff --git a/pkgs/swarmauri/tests/unit/llms/GroqModel_unit_test.py b/pkgs/swarmauri/tests/unit/llms/GroqModel_unit_test.py index 46b5141de..68bf236b1 100644 --- a/pkgs/swarmauri/tests/unit/llms/GroqModel_unit_test.py +++ b/pkgs/swarmauri/tests/unit/llms/GroqModel_unit_test.py @@ -1,6 +1,5 @@ import json import logging -from time import sleep import pytest import os @@ -56,14 +55,11 @@ def get_allowed_models(): "llama-guard-3-8b", ] - # multimodal models - multimodal_models = ["llama-3.2-11b-vision-preview"] - # Filter out the failing models allowed_models = [ model for model in llm.allowed_models - if model not in failing_llms and model not in multimodal_models + if model not in failing_llms ] return allowed_models @@ -97,7 +93,6 @@ def test_default_name(groq_model): @pytest.mark.parametrize("model_name", get_allowed_models()) @pytest.mark.unit def test_no_system_context(groq_model, model_name): - sleep(6) model = groq_model model.name = model_name conversation = Conversation() @@ -111,7 +106,7 @@ def test_no_system_context(groq_model, model_name): prediction = conversation.get_last().content usage_data = conversation.get_last().usage logging.info(usage_data) - assert type(prediction) == str + assert type(prediction) is str assert isinstance(usage_data, UsageData) @@ -136,7 +131,7 @@ def test_preamble_system_context(groq_model, model_name): prediction = conversation.get_last().content usage_data = conversation.get_last().usage logging.info(usage_data) - assert type(prediction) == str + assert type(prediction) is str assert isinstance(usage_data, UsageData) @@ -158,49 +153,11 @@ def test_llama_guard_3_8b_no_system_context(llama_guard_model): llama_guard_model.predict(conversation=conversation) prediction = conversation.get_last().content usage_data = conversation.get_last().usage - assert type(prediction) == str + assert type(prediction) is str assert isinstance(usage_data, UsageData) assert "safe" in prediction.lower() -@timeout(5) -@pytest.mark.parametrize( - "model_name, input_data", - [ - ( - "llama-3.2-11b-vision-preview", - [ - {"type": "text", "text": "What’s in this image?"}, - { - "type": "image_url", - "image_url": { - "url": f"{image_url}", - }, - }, - ], - ), - ], -) -@pytest.mark.unit -def test_multimodal_models_no_system_context(groq_model, model_name, input_data): - """ - Test case specifically for the multimodal models. - This models are designed process a wide variety of inputs, including text, images, and audio, - as prompts and convert those prompts into various outputs, not just the source type. - - """ - conversation = Conversation() - groq_model.name = model_name - - human_message = HumanMessage(content=input_data) - conversation.add_message(human_message) - - groq_model.predict(conversation=conversation) - prediction = conversation.get_last().content - logging.info(prediction) - assert isinstance(prediction, str) - - @timeout(5) @pytest.mark.parametrize("model_name", get_allowed_models()) @pytest.mark.unit @@ -215,6 +172,7 @@ def test_stream(groq_model, model_name): collected_tokens = [] for token in model.stream(conversation=conversation): + logging.info(token) assert isinstance(token, str) collected_tokens.append(token)
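
The refactored client can be exercised end to end much like the unit tests above. The sketch below is a minimal usage example and not part of the diff: the GROQ_API_KEY environment variable, the HumanMessage import path, and the keyword-argument constructor are assumptions inferred from the modules this change touches.

# Usage sketch (assumptions: GROQ_API_KEY is set; HumanMessage lives at
# swarmauri.messages.concrete.HumanMessage; GroqModel accepts api_key/name kwargs).
import os

from swarmauri.llms.concrete.GroqModel import GroqModel
from swarmauri.conversations.concrete.Conversation import Conversation
from swarmauri.messages.concrete.HumanMessage import HumanMessage

llm = GroqModel(api_key=os.environ["GROQ_API_KEY"], name="llama3-8b-8192")

conversation = Conversation()
conversation.add_message(HumanMessage(content="Hello"))

# predict() POSTs to the chat completions endpoint and appends an AgentMessage
# carrying both the generated content and the validated usage data.
llm.predict(conversation=conversation, temperature=0.7, max_tokens=256)
print(conversation.get_last().content)
print(conversation.get_last().usage)

# stream() yields partial content as it is parsed off the streaming response.
conversation.add_message(HumanMessage(content="Write a short haiku about rivers."))
for token in llm.stream(conversation=conversation):
    print(token, end="", flush=True)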
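
Because the groq SDK dependency is dropped, stream() and astream() now parse the server-sent-events wire format themselves: each line's "data: " prefix is stripped and lines that are not valid JSON are skipped, which also covers the terminating "[DONE]" sentinel. The sketch below isolates that handling so it can run without a network call; the sample lines are illustrative, and the defensive .get("content") lookup is an assumption rather than a mirror of the diff.

# Sketch of the SSE line handling behind stream()/astream(); runs standalone.
import json
from typing import Iterable, Iterator


def iter_deltas(lines: Iterable[str]) -> Iterator[str]:
    # Yield content deltas from 'data: {...}' lines of a chat-completions stream.
    for line in lines:
        payload = line.replace("data: ", "", 1).strip()
        if not payload or payload == "[DONE]":
            continue  # keep-alive blanks and the sentinel carry no content
        try:
            chunk = json.loads(payload)
        except json.JSONDecodeError:
            continue  # ignore anything that is not a JSON chunk
        content = chunk["choices"][0].get("delta", {}).get("content")
        if content:
            yield content


sample = [
    'data: {"choices": [{"delta": {"role": "assistant", "content": ""}}]}',
    'data: {"choices": [{"delta": {"content": "Hel"}}]}',
    'data: {"choices": [{"delta": {"content": "lo"}}]}',
    "data: [DONE]",
]
assert "".join(iter_deltas(sample)) == "Hello"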
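
abatch() caps concurrency by wrapping each apredict() call in an asyncio.Semaphore. The following standalone sketch shows the same fan-out pattern with a stubbed coroutine in place of the real API call, so the bounded-concurrency behaviour can be observed without an API key; fake_predict and the 0.1s sleep are purely illustrative.

# Standalone sketch of the semaphore-bounded fan-out used by abatch().
import asyncio


async def fake_predict(conv_id: int) -> str:
    await asyncio.sleep(0.1)  # stand-in for the HTTP round trip
    return f"response for conversation {conv_id}"


async def bounded_batch(conv_ids, max_concurrent: int = 5) -> list[str]:
    semaphore = asyncio.Semaphore(max_concurrent)

    async def process(conv_id: int) -> str:
        async with semaphore:  # at most max_concurrent calls in flight
            return await fake_predict(conv_id)

    # gather() preserves input order even though completion order may differ.
    return await asyncio.gather(*(process(c) for c in conv_ids))


print(asyncio.run(bounded_batch(range(10), max_concurrent=3)))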