implemented async, stream, predict in anthropictoolmodel #635

Merged
merged 2 commits into from Oct 14, 2024
Changes from 1 commit
160 changes: 160 additions & 0 deletions pkgs/swarmauri/swarmauri/llms/concrete/AnthropicToolModel.py
@@ -1,4 +1,6 @@
import json
import asyncio
from typing import AsyncIterator, Iterator
from typing import List, Dict, Literal, Any
import logging
import anthropic
@@ -91,3 +93,161 @@ def predict(
conversation.add_message(agent_message)
logging.info(f"conversation: {conversation}")
return conversation

async def apredict(
self,
conversation,
toolkit=None,
tool_choice=None,
temperature=0.7,
max_tokens=1024,
):
# Use the async client so the create call below can be awaited
client = anthropic.AsyncAnthropic(api_key=self.api_key)
formatted_messages = self._format_messages(conversation.history)

if toolkit and not tool_choice:
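# "auto" lets the model decide whether to invoke one of the supplied tools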
tool_choice = {"type": "auto"}

tool_response = await client.messages.create(
model=self.name,
messages=formatted_messages,
temperature=temperature,
max_tokens=max_tokens,
tools=self._schema_convert_tools(toolkit.tools) if toolkit else None,
tool_choice=tool_choice,
)

logging.info(f"tool_response: {tool_response}")
tool_text_response = None
if tool_response.content[0].type == "text":
tool_text_response = tool_response.content[0].text
logging.info(f"tool_text_response: {tool_text_response}")

func_result = None
for tool_call in tool_response.content:
if tool_call.type == "tool_use":
func_name = tool_call.name
func_call = toolkit.get_tool_by_name(func_name)
func_args = tool_call.input
func_result = func_call(**func_args)

if tool_text_response:
agent_response = f"{tool_text_response} {func_result}"
else:
agent_response = f"{func_result}"

agent_message = AgentMessage(content=agent_response)
conversation.add_message(agent_message)
return conversation

def stream(
self,
conversation,
toolkit=None,
tool_choice=None,
temperature=0.7,
max_tokens=1024,
) -> Iterator[str]:
client = anthropic.Anthropic(api_key=self.api_key)
formatted_messages = self._format_messages(conversation.history)

if toolkit and not tool_choice:
tool_choice = {"type": "auto"}

stream = client.messages.create(
model=self.name,
messages=formatted_messages,
temperature=temperature,
max_tokens=max_tokens,
tools=self._schema_convert_tools(toolkit.tools) if toolkit else None,
tool_choice=tool_choice,
stream=True,
)

collected_content = []
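# Collect the text deltas so the full reply can be stored on the conversation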
for chunk in stream:
if chunk.type == "content_block_delta":
# content_block_delta events carry text in a "text_delta" payload
if chunk.delta.type == "text_delta":
collected_content.append(chunk.delta.text)
yield chunk.delta.text

full_content = "".join(collected_content)
conversation.add_message(AgentMessage(content=full_content))

async def astream(
self,
conversation,
toolkit=None,
tool_choice=None,
temperature=0.7,
max_tokens=1024,
) -> AsyncIterator[str]:
client = anthropic.AsyncAnthropic(api_key=self.api_key)
formatted_messages = self._format_messages(conversation.history)

if toolkit and not tool_choice:
tool_choice = {"type": "auto"}

stream = await client.messages.create(
model=self.name,
messages=formatted_messages,
temperature=temperature,
max_tokens=max_tokens,
tools=self._schema_convert_tools(toolkit.tools) if toolkit else None,
tool_choice=tool_choice,
stream=True,
)

collected_content = []
async for chunk in stream:
if chunk.type == "content_block_delta":
if chunk.delta.type == "text_delta":
collected_content.append(chunk.delta.text)
yield chunk.delta.text

full_content = "".join(collected_content)
conversation.add_message(AgentMessage(content=full_content))

def batch(
self,
conversations: List,
toolkit=None,
tool_choice=None,
temperature=0.7,
max_tokens=1024,
) -> List:
results = []
for conv in conversations:
result = self.predict(
conversation=conv,
toolkit=toolkit,
tool_choice=tool_choice,
temperature=temperature,
max_tokens=max_tokens,
)
results.append(result)
return results

async def abatch(
self,
conversations: List,
toolkit=None,
tool_choice=None,
temperature=0.7,
max_tokens=1024,
max_concurrent=5,
) -> List:
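# The semaphore caps in-flight requests so large batches don't flood the API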
semaphore = asyncio.Semaphore(max_concurrent)

async def process_conversation(conv):
async with semaphore:
return await self.apredict(
conv,
toolkit=toolkit,
tool_choice=tool_choice,
temperature=temperature,
max_tokens=max_tokens,
)

tasks = [process_conversation(conv) for conv in conversations]
return await asyncio.gather(*tasks)
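
For reference, a minimal usage sketch of the new async and streaming entry points. It mirrors the fixtures in the unit tests below (Toolkit, AdditionTool, Conversation, HumanMessage) and assumes ANTHROPIC_API_KEY is set in the environment; it is illustrative only and not part of this diff:

import asyncio
import os

from swarmauri.llms.concrete.AnthropicToolModel import AnthropicToolModel
from swarmauri.conversations.concrete.Conversation import Conversation
from swarmauri.messages.concrete import HumanMessage
from swarmauri.toolkits.concrete.Toolkit import Toolkit
from swarmauri.tools.concrete.AdditionTool import AdditionTool


async def main():
    llm = AnthropicToolModel(api_key=os.environ["ANTHROPIC_API_KEY"])
    toolkit = Toolkit()
    toolkit.add_tool(AdditionTool())

    # Non-streaming async call; the tool result is folded into the reply
    conv = Conversation()
    conv.add_message(HumanMessage(content=[{"type": "text", "text": "Add 512+671"}]))
    conv = await llm.apredict(conversation=conv, toolkit=toolkit)
    print(conv.get_last().content)

    # Token-by-token streaming on a second conversation
    conv2 = Conversation()
    conv2.add_message(HumanMessage(content=[{"type": "text", "text": "Add 40+2"}]))
    async for token in llm.astream(conversation=conv2, toolkit=toolkit):
        print(token, end="", flush=True)


asyncio.run(main())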
3 changes: 2 additions & 1 deletion pkgs/swarmauri/swarmauri/vector_stores/concrete/__init__.py
@@ -2,6 +2,7 @@

from swarmauri.vector_stores.concrete.Doc2VecVectorStore import Doc2VecVectorStore
from swarmauri.vector_stores.concrete.MlmVectorStore import MlmVectorStore
from swarmauri.vector_stores.concrete.SpatialDocVectorStore import SpatialDocVectorStore

# from swarmauri.vector_stores.concrete.SpatialDocVectorStore import SpatialDocVectorStore
from swarmauri.vector_stores.concrete.SqliteVectorStore import SqliteVectorStore
from swarmauri.vector_stores.concrete.TfidfVectorStore import TfidfVectorStore
167 changes: 167 additions & 0 deletions pkgs/swarmauri/tests/unit/llms/AnthropicToolModel_unit_test.py
@@ -0,0 +1,167 @@
import logging
import pytest
import os
from swarmauri.llms.concrete.AnthropicToolModel import AnthropicToolModel as LLM
from swarmauri.conversations.concrete.Conversation import Conversation
from swarmauri.messages.concrete import HumanMessage
from swarmauri.tools.concrete.AdditionTool import AdditionTool
from swarmauri.toolkits.concrete.Toolkit import Toolkit
from swarmauri.agents.concrete.ToolAgent import ToolAgent
from dotenv import load_dotenv

load_dotenv()

API_KEY = os.getenv("ANTHROPIC_API_KEY")


@pytest.fixture(scope="module")
def anthropic_tool_model():
if not API_KEY:
pytest.skip("Skipping due to environment variable not set")
llm = LLM(api_key=API_KEY)
return llm


def get_allowed_models():
if not API_KEY:
return []
llm = LLM(api_key=API_KEY)
return llm.allowed_models


@pytest.fixture(scope="module")
def toolkit():
toolkit = Toolkit()
tool = AdditionTool()
toolkit.add_tool(tool)
return toolkit


@pytest.fixture(scope="module")
def conversation():
conversation = Conversation()
input_data = {"type": "text", "text": "Add 512+671"}
human_message = HumanMessage(content=[input_data])
conversation.add_message(human_message)
return conversation


@pytest.mark.unit
def test_ubc_resource(anthropic_tool_model):
assert anthropic_tool_model.resource == "LLM"


@pytest.mark.unit
def test_ubc_type(anthropic_tool_model):
assert anthropic_tool_model.type == "AnthropicToolModel"


@pytest.mark.unit
def test_serialization(anthropic_tool_model):
assert (
anthropic_tool_model.id
== LLM.model_validate_json(anthropic_tool_model.model_dump_json()).id
)


@pytest.mark.unit
def test_default_name(anthropic_tool_model):
assert anthropic_tool_model.name == "claude-3-haiku-20240307"


@pytest.mark.unit
@pytest.mark.parametrize("model_name", get_allowed_models())
def test_agent_exec(anthropic_tool_model, toolkit, conversation, model_name):
anthropic_tool_model.name = model_name
agent = ToolAgent(
llm=anthropic_tool_model, conversation=conversation, toolkit=toolkit
)
result = agent.exec("Add 512+671")
assert isinstance(result, str)


@pytest.mark.unit
@pytest.mark.parametrize("model_name", get_allowed_models())
def test_predict(anthropic_tool_model, toolkit, conversation, model_name):
anthropic_tool_model.name = model_name
conversation = anthropic_tool_model.predict(
conversation=conversation, toolkit=toolkit
)
logging.info(conversation.get_last().content)
assert isinstance(conversation.get_last().content, str)


@pytest.mark.unit
@pytest.mark.parametrize("model_name", get_allowed_models())
def test_stream(anthropic_tool_model, toolkit, conversation, model_name):
anthropic_tool_model.name = model_name
collected_tokens = []
for token in anthropic_tool_model.stream(
conversation=conversation, toolkit=toolkit
):
assert isinstance(token, str)
collected_tokens.append(token)
full_response = "".join(collected_tokens)
assert len(full_response) > 0
assert conversation.get_last().content == full_response


@pytest.mark.unit
@pytest.mark.parametrize("model_name", get_allowed_models())
def test_batch(anthropic_tool_model, toolkit, model_name):
anthropic_tool_model.name = model_name
conversations = []
for prompt in ["20+20", "100+50", "500+500"]:
conv = Conversation()
conv.add_message(HumanMessage(content=[{"type": "text", "text": prompt}]))
conversations.append(conv)
results = anthropic_tool_model.batch(conversations=conversations, toolkit=toolkit)
assert len(results) == len(conversations)
for result in results:
assert isinstance(result.get_last().content, str)


@pytest.mark.unit
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.parametrize("model_name", get_allowed_models())
async def test_apredict(anthropic_tool_model, toolkit, conversation, model_name):
anthropic_tool_model.name = model_name
result = await anthropic_tool_model.apredict(
conversation=conversation, toolkit=toolkit
)
prediction = result.get_last().content
assert isinstance(prediction, str)


@pytest.mark.unit
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.parametrize("model_name", get_allowed_models())
async def test_astream(anthropic_tool_model, toolkit, conversation, model_name):
anthropic_tool_model.name = model_name
collected_tokens = []
async for token in anthropic_tool_model.astream(
conversation=conversation, toolkit=toolkit
):
assert isinstance(token, str)
collected_tokens.append(token)
full_response = "".join(collected_tokens)
assert len(full_response) > 0
assert conversation.get_last().content == full_response


@pytest.mark.unit
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.parametrize("model_name", get_allowed_models())
async def test_abatch(anthropic_tool_model, toolkit, model_name):
anthropic_tool_model.name = model_name
conversations = []
for prompt in ["20+20", "100+50", "500+500"]:
conv = Conversation()
conv.add_message(HumanMessage(content=[{"type": "text", "text": prompt}]))
conversations.append(conv)
results = await anthropic_tool_model.abatch(
conversations=conversations, toolkit=toolkit
)
assert len(results) == len(conversations)
for result in results:
assert isinstance(result.get_last().content, str)
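
These tests can be run in isolation with pytest (a hypothetical invocation; it assumes pytest-asyncio is installed for the async cases, and the module-scoped fixture skips the suite unless ANTHROPIC_API_KEY is set):

# Programmatic equivalent of `pytest -m unit <test file>`
import sys

import pytest

sys.exit(
    pytest.main(
        ["-m", "unit", "pkgs/swarmauri/tests/unit/llms/AnthropicToolModel_unit_test.py"]
    )
)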