swarm - Add async stream and batch processing to GeminiToolModel.py #637

Merged · 2 commits · Oct 14, 2024
183 changes: 178 additions & 5 deletions pkgs/swarmauri/swarmauri/llms/concrete/GeminiToolModel.py
@@ -1,8 +1,8 @@
import asyncio
import logging
import json
from typing import List, Literal, Dict, Any
import google.generativeai as genai
from google.generativeai.protos import FunctionDeclaration
from swarmauri.conversations.concrete import Conversation
from swarmauri_core.typing import SubclassUnion

from swarmauri.messages.base.MessageBase import MessageBase
@@ -14,6 +14,8 @@
)
import google.generativeai as genai

from swarmauri.toolkits.concrete.Toolkit import Toolkit


class GeminiToolModel(LLMBase):
"""
@@ -112,9 +114,8 @@ def predict(self, conversation, toolkit=None, temperature=0.7, max_tokens=256):
formatted_messages = self._format_messages(conversation.history)
tools = self._schema_convert_tools(toolkit.tools)


logging.info(f"formatted_messages: {formatted_messages}")
logging.info(f"tools: {tools}")

tool_response = client.generate_content(
formatted_messages,
@@ -168,3 +169,175 @@ def predict(self, conversation, toolkit=None, temperature=0.7, max_tokens=256):

logging.info(f"conversation: {conversation}")
return conversation

async def apredict(
self, conversation, toolkit=None, temperature=0.7, max_tokens=256
):
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
None, self.predict, conversation, toolkit, temperature, max_tokens
)

def stream(self, conversation, toolkit=None, temperature=0.7, max_tokens=256):
genai.configure(api_key=self.api_key)
generation_config = {
"temperature": temperature,
"top_p": 0.95,
"top_k": 0,
"max_output_tokens": max_tokens,
}

safety_settings = [
{
"category": "HARM_CATEGORY_HARASSMENT",
"threshold": "BLOCK_MEDIUM_AND_ABOVE",
},
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"threshold": "BLOCK_MEDIUM_AND_ABOVE",
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"threshold": "BLOCK_MEDIUM_AND_ABOVE",
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"threshold": "BLOCK_MEDIUM_AND_ABOVE",
},
]

tool_config = {
"function_calling_config": {"mode": "ANY"},
}

client = genai.GenerativeModel(
model_name=self.name,
safety_settings=safety_settings,
generation_config=generation_config,
tool_config=tool_config,
)

formatted_messages = self._format_messages(conversation.history)
tools = self._schema_convert_tools(toolkit.tools)

logging.info(f"formatted_messages: {formatted_messages}")
logging.info(f"tools: {tools}")

tool_response = client.generate_content(
formatted_messages,
tools=tools,
)
logging.info(f"tool_response: {tool_response}")

formatted_messages.append(tool_response.candidates[0].content)

logging.info(
f"tool_response.candidates[0].content: {tool_response.candidates[0].content}"
)

tool_calls = tool_response.candidates[0].content.parts

tool_results = {}
for tool_call in tool_calls:
func_name = tool_call.function_call.name
func_args = tool_call.function_call.args
logging.info(f"func_name: {func_name}")
logging.info(f"func_args: {func_args}")

func_call = toolkit.get_tool_by_name(func_name)
func_result = func_call(**func_args)
logging.info(f"func_result: {func_result}")
tool_results[func_name] = func_result

formatted_messages.append(
genai.protos.Content(
role="function",
parts=[
genai.protos.Part(
function_response=genai.protos.FunctionResponse(
name=fn,
response={
"result": val, # Return the API response to Gemini
},
)
)
for fn, val in tool_results.items()
],
)
)

logging.info(f"formatted_messages: {formatted_messages}")

stream_response = client.generate_content(formatted_messages, stream=True)

full_response = ""
for chunk in stream_response:
chunk_text = chunk.text
full_response += chunk_text
yield chunk_text

logging.info(f"agent_response: {full_response}")
conversation.add_message(AgentMessage(content=full_response))

async def astream(
self, conversation, toolkit=None, temperature=0.7, max_tokens=256
):
loop = asyncio.get_event_loop()
stream_gen = self.stream(conversation, toolkit, temperature, max_tokens)

def safe_next(gen):
try:
return next(gen), False
except StopIteration:
return None, True

while True:
try:
chunk, done = await loop.run_in_executor(None, safe_next, stream_gen)
if done:
break
yield chunk
except Exception as e:
print(f"Error in astream: {e}")
break

def batch(
self,
conversations: List[Conversation],
toolkit: Toolkit = None,
temperature: float = 0.7,
max_tokens: int = 256,
) -> List:
"""Synchronously process multiple conversations"""
return [
self.predict(
conv,
toolkit=toolkit,
temperature=temperature,
max_tokens=max_tokens,
)
for conv in conversations
]

async def abatch(
self,
conversations: List[Conversation],
toolkit: Toolkit = None,
temperature: float = 0.7,
max_tokens: int = 256,
max_concurrent: int = 5,
) -> List:
"""Process multiple conversations in parallel with controlled concurrency"""
semaphore = asyncio.Semaphore(max_concurrent)

async def process_conversation(conv):
async with semaphore:
return await self.apredict(
conv,
toolkit=toolkit,
temperature=temperature,
max_tokens=max_tokens,
)

tasks = [process_conversation(conv) for conv in conversations]
return await asyncio.gather(*tasks)
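
For context, here is a minimal usage sketch of the new entry points (apredict, astream, batch, abatch). It is not part of the diff; the API key placeholder and the Toolkit/AdditionTool setup are assumptions borrowed from the unit tests below.

import asyncio

from swarmauri.llms.concrete.GeminiToolModel import GeminiToolModel
from swarmauri.conversations.concrete.Conversation import Conversation
from swarmauri.messages.concrete import HumanMessage
from swarmauri.tools.concrete.AdditionTool import AdditionTool
from swarmauri.toolkits.concrete.Toolkit import Toolkit


async def main():
    llm = GeminiToolModel(api_key="YOUR_GEMINI_API_KEY")  # placeholder key

    toolkit = Toolkit()
    toolkit.add_tool(AdditionTool())

    conversation = Conversation()
    conversation.add_message(HumanMessage(content="Add 512+671"))

    # Async prediction: runs the synchronous predict() in a thread executor.
    conversation = await llm.apredict(conversation=conversation, toolkit=toolkit)
    print(conversation.get_last().content)

    # Async streaming: tokens are yielded as the sync stream() generator is drained.
    async for token in llm.astream(conversation=conversation, toolkit=toolkit):
        print(token, end="")

    # Batch processing: abatch bounds concurrency with an asyncio.Semaphore.
    conversations = []
    for prompt in ["20+20", "100+50", "500+500"]:
        conv = Conversation()
        conv.add_message(HumanMessage(content=prompt))
        conversations.append(conv)
    results = await llm.abatch(
        conversations=conversations, toolkit=toolkit, max_concurrent=5
    )
    for result in results:
        print(result.get_last().content)


asyncio.run(main())

The synchronous batch() offers the same interface without concurrency, processing the conversations sequentially.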
136 changes: 131 additions & 5 deletions pkgs/swarmauri/tests/unit/llms/GeminiToolModel_unit_test.py
@@ -1,21 +1,55 @@
import logging

import pytest
import os
from swarmauri.llms.concrete.GeminiToolModel import GeminiToolModel as LLM
from swarmauri.conversations.concrete.Conversation import Conversation
from swarmauri.messages.concrete import HumanMessage
from swarmauri.tools.concrete.AdditionTool import AdditionTool
from swarmauri.toolkits.concrete.Toolkit import Toolkit
from swarmauri.agents.concrete.ToolAgent import ToolAgent
from dotenv import load_dotenv

load_dotenv()

API_KEY = os.getenv("GEMINI_API_KEY")


@pytest.fixture(scope="module")
def gemini_tool_model():
API_KEY = os.getenv("GEMINI_API_KEY")
if not API_KEY:
pytest.skip("Skipping due to environment variable not set")
llm = LLM(api_key=API_KEY)
return llm


@pytest.fixture(scope="module")
def toolkit():
toolkit = Toolkit()
tool = AdditionTool()
toolkit.add_tool(tool)

return toolkit


@pytest.fixture(scope="module")
def conversation():
conversation = Conversation()

input_data = "Add 512+671"
human_message = HumanMessage(content=input_data)
conversation.add_message(human_message)

return conversation


def get_allowed_models():
if not API_KEY:
return []
llm = LLM(api_key=API_KEY)
return llm.allowed_models


@pytest.mark.unit
def test_ubc_resource(gemini_tool_model):
assert gemini_tool_model.resource == "LLM"
@@ -40,13 +74,105 @@ def test_default_name(gemini_tool_model):


@pytest.mark.unit
@pytest.mark.parametrize("model_name", get_allowed_models())
def test_agent_exec(gemini_tool_model, toolkit, model_name):
gemini_tool_model.name = model_name
conversation = Conversation()
toolkit = Toolkit()
tool = AdditionTool()
toolkit.add_tool(tool)

# Use the gemini_tool_model from the fixture
agent = ToolAgent(llm=gemini_tool_model, conversation=conversation, toolkit=toolkit)
result = agent.exec("Add 512+671")
assert type(result) == str


@pytest.mark.unit
@pytest.mark.parametrize("model_name", get_allowed_models())
def test_predict(gemini_tool_model, toolkit, conversation, model_name):
gemini_tool_model.name = model_name

conversation = gemini_tool_model.predict(conversation=conversation, toolkit=toolkit)

assert type(conversation.get_last().content) == str


@pytest.mark.unit
@pytest.mark.parametrize("model_name", get_allowed_models())
def test_stream(gemini_tool_model, toolkit, conversation, model_name):
gemini_tool_model.name = model_name

collected_tokens = []
for token in gemini_tool_model.stream(conversation=conversation, toolkit=toolkit):
assert isinstance(token, str)
collected_tokens.append(token)

full_response = "".join(collected_tokens)
assert len(full_response) > 0
assert conversation.get_last().content == full_response


@pytest.mark.unit
@pytest.mark.parametrize("model_name", get_allowed_models())
def test_batch(gemini_tool_model, toolkit, model_name):
gemini_tool_model.name = model_name

conversations = []
for prompt in ["20+20", "100+50", "500+500"]:
conv = Conversation()
conv.add_message(HumanMessage(content=prompt))
conversations.append(conv)

results = gemini_tool_model.batch(conversations=conversations, toolkit=toolkit)
assert len(results) == len(conversations)
for result in results:
assert isinstance(result.get_last().content, str)


@pytest.mark.unit
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.parametrize("model_name", get_allowed_models())
async def test_apredict(gemini_tool_model, toolkit, conversation, model_name):
gemini_tool_model.name = model_name

result = await gemini_tool_model.apredict(
conversation=conversation, toolkit=toolkit
)
prediction = result.get_last().content
assert isinstance(prediction, str)


@pytest.mark.unit
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.parametrize("model_name", get_allowed_models())
async def test_astream(gemini_tool_model, toolkit, conversation, model_name):
gemini_tool_model.name = model_name

collected_tokens = []
async for token in gemini_tool_model.astream(
conversation=conversation, toolkit=toolkit
):
assert isinstance(token, str)
collected_tokens.append(token)

full_response = "".join(collected_tokens)
assert len(full_response) > 0
assert conversation.get_last().content == full_response


@pytest.mark.unit
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.parametrize("model_name", get_allowed_models())
async def test_abatch(gemini_tool_model, toolkit, model_name):
gemini_tool_model.name = model_name

conversations = []
for prompt in ["20+20", "100+50", "500+500"]:
conv = Conversation()
conv.add_message(HumanMessage(content=prompt))
conversations.append(conv)

results = await gemini_tool_model.abatch(
conversations=conversations, toolkit=toolkit
)
assert len(results) == len(conversations)
for result in results:
assert isinstance(result.get_last().content, str)
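
A minimal sketch of how these tests might be run locally (the marker selection and the pytest-asyncio dependency are assumptions, not part of the diff). It expects GEMINI_API_KEY in the environment or a .env file; otherwise the model-parametrized tests collect nothing and the fixtures skip.

import pytest

# Run only the unit-marked tests in this module; the async tests require pytest-asyncio.
raise SystemExit(
    pytest.main(["-m", "unit", "pkgs/swarmauri/tests/unit/llms/GeminiToolModel_unit_test.py"])
)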