From d23f4c31cc4402515170bb28c30800a6746c603f Mon Sep 17 00:00:00 2001
From: 3rdSon
Date: Fri, 11 Oct 2024 12:04:47 +0100
Subject: [PATCH] implement async, streaming, and batch methods in AnthropicToolModel

---
 .../llms/concrete/AnthropicToolModel.py       | 160 +++++++++++++++++
 .../vector_stores/concrete/__init__.py        |   3 +-
 .../unit/llms/AnthropicToolModel_unit_test.py | 167 ++++++++++++++++++
 3 files changed, 329 insertions(+), 1 deletion(-)
 create mode 100644 pkgs/swarmauri/tests/unit/llms/AnthropicToolModel_unit_test.py

diff --git a/pkgs/swarmauri/swarmauri/llms/concrete/AnthropicToolModel.py b/pkgs/swarmauri/swarmauri/llms/concrete/AnthropicToolModel.py
index 6b375867e..397404ad8 100644
--- a/pkgs/swarmauri/swarmauri/llms/concrete/AnthropicToolModel.py
+++ b/pkgs/swarmauri/swarmauri/llms/concrete/AnthropicToolModel.py
@@ -1,4 +1,6 @@
 import json
+import asyncio
+from typing import AsyncIterator, Iterator
 from typing import List, Dict, Literal, Any
 import logging
 import anthropic
@@ -91,3 +93,161 @@ def predict(
         conversation.add_message(agent_message)
         logging.info(f"conversation: {conversation}")
         return conversation
+
+    async def apredict(
+        self,
+        conversation,
+        toolkit=None,
+        tool_choice=None,
+        temperature=0.7,
+        max_tokens=1024,
+    ):
+        client = anthropic.AsyncAnthropic(api_key=self.api_key)
+        formatted_messages = self._format_messages(conversation.history)
+
+        if toolkit and not tool_choice:
+            tool_choice = {"type": "auto"}
+
+        tool_response = await client.messages.create(
+            model=self.name,
+            messages=formatted_messages,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            tools=self._schema_convert_tools(toolkit.tools) if toolkit else None,
+            tool_choice=tool_choice,
+        )
+
+        logging.info(f"tool_response: {tool_response}")
+        tool_text_response = None
+        if tool_response.content[0].type == "text":
+            tool_text_response = tool_response.content[0].text
+            logging.info(f"tool_text_response: {tool_text_response}")
+
+        func_result = None
+        for tool_call in tool_response.content:
+            if tool_call.type == "tool_use":
+                func_name = tool_call.name
+                func_call = toolkit.get_tool_by_name(func_name)
+                func_args = tool_call.input
+                func_result = func_call(**func_args)
+
+        if tool_text_response:
+            agent_response = f"{tool_text_response} {func_result}"
+        else:
+            agent_response = f"{func_result}"
+
+        agent_message = AgentMessage(content=agent_response)
+        conversation.add_message(agent_message)
+        return conversation
+
+    def stream(
+        self,
+        conversation,
+        toolkit=None,
+        tool_choice=None,
+        temperature=0.7,
+        max_tokens=1024,
+    ) -> Iterator[str]:
+        client = anthropic.Anthropic(api_key=self.api_key)
+        formatted_messages = self._format_messages(conversation.history)
+
+        if toolkit and not tool_choice:
+            tool_choice = {"type": "auto"}
+
+        stream = client.messages.create(
+            model=self.name,
+            messages=formatted_messages,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            tools=self._schema_convert_tools(toolkit.tools) if toolkit else None,
+            tool_choice=tool_choice,
+            stream=True,
+        )
+
+        collected_content = []
+        for chunk in stream:
+            if chunk.type == "content_block_delta":
+                if chunk.delta.type == "text_delta":
+                    collected_content.append(chunk.delta.text)
+                    yield chunk.delta.text
+
+        full_content = "".join(collected_content)
+        conversation.add_message(AgentMessage(content=full_content))
+
+    async def astream(
+        self,
+        conversation,
+        toolkit=None,
+        tool_choice=None,
+        temperature=0.7,
+        max_tokens=1024,
+    ) -> AsyncIterator[str]:
+        client = anthropic.AsyncAnthropic(api_key=self.api_key)
+        formatted_messages = self._format_messages(conversation.history)
+
+        if toolkit and not tool_choice:
+            tool_choice = {"type": "auto"}
+
+        stream = await client.messages.create(
+            model=self.name,
+            messages=formatted_messages,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            tools=self._schema_convert_tools(toolkit.tools) if toolkit else None,
+            tool_choice=tool_choice,
+            stream=True,
+        )
+
+        collected_content = []
+        async for chunk in stream:
+            if chunk.type == "content_block_delta":
+                if chunk.delta.type == "text_delta":
+                    collected_content.append(chunk.delta.text)
+                    yield chunk.delta.text
+
+        full_content = "".join(collected_content)
+        conversation.add_message(AgentMessage(content=full_content))
+
+    def batch(
+        self,
+        conversations: List,
+        toolkit=None,
+        tool_choice=None,
+        temperature=0.7,
+        max_tokens=1024,
+    ) -> List:
+        results = []
+        for conv in conversations:
+            result = self.predict(
+                conversation=conv,
+                toolkit=toolkit,
+                tool_choice=tool_choice,
+                temperature=temperature,
+                max_tokens=max_tokens,
+            )
+            results.append(result)
+        return results
+
+    async def abatch(
+        self,
+        conversations: List,
+        toolkit=None,
+        tool_choice=None,
+        temperature=0.7,
+        max_tokens=1024,
+        max_concurrent=5,
+    ) -> List:
+        semaphore = asyncio.Semaphore(max_concurrent)
+
+        async def process_conversation(conv):
+            async with semaphore:
+                return await self.apredict(
+                    conv,
+                    toolkit=toolkit,
+                    tool_choice=tool_choice,
+                    temperature=temperature,
+                    max_tokens=max_tokens,
+                )
+
+        tasks = [process_conversation(conv) for conv in conversations]
+        return await asyncio.gather(*tasks)
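For reference, a minimal sketch of how these entry points might be driven end to end, mirroring the fixtures in the unit tests below. It assumes `ANTHROPIC_API_KEY` is set and that the swarmauri import paths match the ones used in this patch's test file:

```python
import asyncio
import os

from swarmauri.llms.concrete.AnthropicToolModel import AnthropicToolModel
from swarmauri.conversations.concrete.Conversation import Conversation
from swarmauri.messages.concrete import HumanMessage
from swarmauri.toolkits.concrete.Toolkit import Toolkit
from swarmauri.tools.concrete.AdditionTool import AdditionTool


async def main():
    llm = AnthropicToolModel(api_key=os.getenv("ANTHROPIC_API_KEY"))
    toolkit = Toolkit()
    toolkit.add_tool(AdditionTool())

    conversation = Conversation()
    conversation.add_message(
        HumanMessage(content=[{"type": "text", "text": "Add 512+671"}])
    )

    # Non-streaming async call: returns the updated conversation.
    conversation = await llm.apredict(conversation=conversation, toolkit=toolkit)
    print(conversation.get_last().content)

    # Streaming async call: yields text deltas as they arrive.
    async for token in llm.astream(conversation=conversation, toolkit=toolkit):
        print(token, end="")


asyncio.run(main())
```

The synchronous `stream` method follows the same event protocol as `astream`, iterating the SDK's `content_block_delta` events in a plain `for` loop.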
diff --git a/pkgs/swarmauri/swarmauri/vector_stores/concrete/__init__.py b/pkgs/swarmauri/swarmauri/vector_stores/concrete/__init__.py
index b48a7de7a..2c955c2d0 100644
--- a/pkgs/swarmauri/swarmauri/vector_stores/concrete/__init__.py
+++ b/pkgs/swarmauri/swarmauri/vector_stores/concrete/__init__.py
@@ -2,6 +2,7 @@
 from swarmauri.vector_stores.concrete.Doc2VecVectorStore import Doc2VecVectorStore
 from swarmauri.vector_stores.concrete.MlmVectorStore import MlmVectorStore
-from swarmauri.vector_stores.concrete.SpatialDocVectorStore import SpatialDocVectorStore
+
+# from swarmauri.vector_stores.concrete.SpatialDocVectorStore import SpatialDocVectorStore
 from swarmauri.vector_stores.concrete.SqliteVectorStore import SqliteVectorStore
 from swarmauri.vector_stores.concrete.TfidfVectorStore import TfidfVectorStore
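For background, the tool-use round trip that `predict` and `apredict` wrap looks roughly like this when driven against the `anthropic` SDK directly. The tool schema below is illustrative, not the literal output of `_schema_convert_tools`:

```python
import anthropic

client = anthropic.Anthropic()  # reads ANTHROPIC_API_KEY from the environment

add_tool = {
    "name": "add",
    "description": "Add two integers.",
    "input_schema": {
        "type": "object",
        "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}},
        "required": ["a", "b"],
    },
}

response = client.messages.create(
    model="claude-3-haiku-20240307",
    max_tokens=1024,
    tools=[add_tool],
    tool_choice={"type": "auto"},
    messages=[{"role": "user", "content": "Add 512+671"}],
)

# A single response can interleave text blocks and tool_use blocks.
for block in response.content:
    if block.type == "tool_use":
        print(block.name, block.input)  # e.g. add {'a': 512, 'b': 671}
    elif block.type == "text":
        print(block.text)
```

Because the model interleaves `text` and `tool_use` blocks in one response, the methods above scan all of `tool_response.content` rather than assuming a single block.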
diff --git a/pkgs/swarmauri/tests/unit/llms/AnthropicToolModel_unit_test.py b/pkgs/swarmauri/tests/unit/llms/AnthropicToolModel_unit_test.py
new file mode 100644
index 000000000..5102b77b9
--- /dev/null
+++ b/pkgs/swarmauri/tests/unit/llms/AnthropicToolModel_unit_test.py
@@ -0,0 +1,167 @@
+import logging
+import pytest
+import os
+from swarmauri.llms.concrete.AnthropicToolModel import AnthropicToolModel as LLM
+from swarmauri.conversations.concrete.Conversation import Conversation
+from swarmauri.messages.concrete import HumanMessage
+from swarmauri.tools.concrete.AdditionTool import AdditionTool
+from swarmauri.toolkits.concrete.Toolkit import Toolkit
+from swarmauri.agents.concrete.ToolAgent import ToolAgent
+from dotenv import load_dotenv
+
+load_dotenv()
+
+API_KEY = os.getenv("ANTHROPIC_API_KEY")
+
+
+@pytest.fixture(scope="module")
+def anthropic_tool_model():
+    if not API_KEY:
+        pytest.skip("Skipping due to environment variable not set")
+    llm = LLM(api_key=API_KEY)
+    return llm
+
+
+def get_allowed_models():
+    if not API_KEY:
+        return []
+    llm = LLM(api_key=API_KEY)
+    return llm.allowed_models
+
+
+@pytest.fixture(scope="module")
+def toolkit():
+    toolkit = Toolkit()
+    tool = AdditionTool()
+    toolkit.add_tool(tool)
+    return toolkit
+
+
+@pytest.fixture(scope="module")
+def conversation():
+    conversation = Conversation()
+    input_data = {"type": "text", "text": "Add 512+671"}
+    human_message = HumanMessage(content=[input_data])
+    conversation.add_message(human_message)
+    return conversation
+
+
+@pytest.mark.unit
+def test_ubc_resource(anthropic_tool_model):
+    assert anthropic_tool_model.resource == "LLM"
+
+
+@pytest.mark.unit
+def test_ubc_type(anthropic_tool_model):
+    assert anthropic_tool_model.type == "AnthropicToolModel"
+
+
+@pytest.mark.unit
+def test_serialization(anthropic_tool_model):
+    assert (
+        anthropic_tool_model.id
+        == LLM.model_validate_json(anthropic_tool_model.model_dump_json()).id
+    )
+
+
+@pytest.mark.unit
+def test_default_name(anthropic_tool_model):
+    assert anthropic_tool_model.name == "claude-3-haiku-20240307"
+
+
+@pytest.mark.unit
+@pytest.mark.parametrize("model_name", get_allowed_models())
+def test_agent_exec(anthropic_tool_model, toolkit, conversation, model_name):
+    anthropic_tool_model.name = model_name
+    agent = ToolAgent(
+        llm=anthropic_tool_model, conversation=conversation, toolkit=toolkit
+    )
+    result = agent.exec("Add 512+671")
+    assert isinstance(result, str)
+
+
+@pytest.mark.unit
+@pytest.mark.parametrize("model_name", get_allowed_models())
+def test_predict(anthropic_tool_model, toolkit, conversation, model_name):
+    anthropic_tool_model.name = model_name
+    conversation = anthropic_tool_model.predict(
+        conversation=conversation, toolkit=toolkit
+    )
+    logging.info(conversation.get_last().content)
+    assert isinstance(conversation.get_last().content, str)
+
+
+@pytest.mark.unit
+@pytest.mark.parametrize("model_name", get_allowed_models())
+def test_stream(anthropic_tool_model, toolkit, conversation, model_name):
+    anthropic_tool_model.name = model_name
+    collected_tokens = []
+    for token in anthropic_tool_model.stream(
+        conversation=conversation, toolkit=toolkit
+    ):
+        assert isinstance(token, str)
+        collected_tokens.append(token)
+    full_response = "".join(collected_tokens)
+    assert len(full_response) > 0
+    assert conversation.get_last().content == full_response
+
+
+@pytest.mark.unit
+@pytest.mark.parametrize("model_name", get_allowed_models())
+def test_batch(anthropic_tool_model, toolkit, model_name):
+    anthropic_tool_model.name = model_name
+    conversations = []
+    for prompt in ["20+20", "100+50", "500+500"]:
+        conv = Conversation()
+        conv.add_message(HumanMessage(content=[{"type": "text", "text": prompt}]))
+        conversations.append(conv)
+    results = anthropic_tool_model.batch(conversations=conversations, toolkit=toolkit)
+    assert len(results) == len(conversations)
+    for result in results:
+        assert isinstance(result.get_last().content, str)
+
+
+@pytest.mark.unit
+@pytest.mark.asyncio(loop_scope="session")
+@pytest.mark.parametrize("model_name", get_allowed_models())
+async def test_apredict(anthropic_tool_model, toolkit, conversation, model_name):
+    anthropic_tool_model.name = model_name
+    result = await anthropic_tool_model.apredict(
+        conversation=conversation, toolkit=toolkit
+    )
+    prediction = result.get_last().content
+    assert isinstance(prediction, str)
+
+
+@pytest.mark.unit
+@pytest.mark.asyncio(loop_scope="session")
+@pytest.mark.parametrize("model_name", get_allowed_models())
+async def test_astream(anthropic_tool_model, toolkit, conversation, model_name):
+    anthropic_tool_model.name = model_name
+    collected_tokens = []
+    async for token in anthropic_tool_model.astream(
+        conversation=conversation, toolkit=toolkit
+    ):
+        assert isinstance(token, str)
+        collected_tokens.append(token)
+    full_response = "".join(collected_tokens)
+    assert len(full_response) > 0
+    assert conversation.get_last().content == full_response
+
+
+@pytest.mark.unit
+@pytest.mark.asyncio(loop_scope="session")
+@pytest.mark.parametrize("model_name", get_allowed_models())
+async def test_abatch(anthropic_tool_model, toolkit, model_name):
+    anthropic_tool_model.name = model_name
+    conversations = []
+    for prompt in ["20+20", "100+50", "500+500"]:
+        conv = Conversation()
+        conv.add_message(HumanMessage(content=[{"type": "text", "text": prompt}]))
+        conversations.append(conv)
+    results = await anthropic_tool_model.abatch(
+        conversations=conversations, toolkit=toolkit
+    )
+    assert len(results) == len(conversations)
+    for result in results:
+        assert isinstance(result.get_last().content, str)
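As a standalone illustration of the bounded-concurrency pattern `abatch` relies on, here is a minimal sketch; `fake_apredict` and `bounded_batch` are hypothetical names standing in for the real network call:

```python
import asyncio


async def fake_apredict(i: int) -> str:
    await asyncio.sleep(0.1)  # stands in for an API round trip
    return f"result {i}"


async def bounded_batch(items, max_concurrent=5):
    # The semaphore caps how many calls are in flight at once.
    semaphore = asyncio.Semaphore(max_concurrent)

    async def worker(item):
        async with semaphore:
            return await fake_apredict(item)

    # gather preserves the input order of results.
    return await asyncio.gather(*(worker(i) for i in items))


print(asyncio.run(bounded_batch(range(10), max_concurrent=3)))
```

Capping concurrency keeps a large batch from opening one connection per conversation simultaneously, while `gather` still returns results in input order.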