swarm - implemented async, batch, abatch on MistralModel.py #627

Merged
merged 1 commit into from
Oct 10, 2024
148 changes: 146 additions & 2 deletions pkgs/swarmauri/swarmauri/llms/concrete/MistralModel.py
@@ -1,6 +1,8 @@
import asyncio
import json
from typing import List, Literal, Dict
from mistralai import Mistral
from swarmauri.conversations.concrete import Conversation
from swarmauri_core.typing import SubclassUnion

from swarmauri.messages.base.MessageBase import MessageBase
@@ -45,10 +47,9 @@ def predict(
enable_json: bool = False,
safe_prompt: bool = False,
):

formatted_messages = self._format_messages(conversation.history)

client = Mistral(api_key=self.api_key)

if enable_json:
response = client.chat.complete(
model=self.name,
@@ -74,3 +75,146 @@ def predict(
conversation.add_message(AgentMessage(content=message_content))

return conversation

async def apredict(
self,
conversation,
temperature: float = 0.7,
max_tokens: int = 256,
top_p: int = 1,
enable_json: bool = False,
safe_prompt: bool = False,
):
formatted_messages = self._format_messages(conversation.history)
client = Mistral(api_key=self.api_key)

if enable_json:
response = await client.chat.complete_async(
model=self.name,
messages=formatted_messages,
temperature=temperature,
response_format={"type": "json_object"},
max_tokens=max_tokens,
top_p=top_p,
safe_prompt=safe_prompt,
)
else:
response = await client.chat.complete_async(
model=self.name,
messages=formatted_messages,
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
safe_prompt=safe_prompt,
)

result = json.loads(response.json())
message_content = result["choices"][0]["message"]["content"]
conversation.add_message(AgentMessage(content=message_content))

return conversation

def stream(
self,
conversation,
temperature: float = 0.7,
max_tokens: int = 256,
top_p: int = 1,
safe_prompt: bool = False,
):
formatted_messages = self._format_messages(conversation.history)
client = Mistral(api_key=self.api_key)

stream_response = client.chat.stream(
model=self.name,
messages=formatted_messages,
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
safe_prompt=safe_prompt,
)
message_content = ""

for chunk in stream_response:
if chunk.data.choices[0].delta.content:
message_content += chunk.data.choices[0].delta.content
yield chunk.data.choices[0].delta.content

conversation.add_message(AgentMessage(content=message_content))

async def astream(
self,
conversation,
temperature: float = 0.7,
max_tokens: int = 256,
top_p: int = 1,
safe_prompt: bool = False,
):
formatted_messages = self._format_messages(conversation.history)
client = Mistral(api_key=self.api_key)

stream_response = await client.chat.stream_async(
model=self.name,
messages=formatted_messages,
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
safe_prompt=safe_prompt,
)
message_content = ""

for chunk in stream_response:
if chunk.data.choices[0].delta.content:
message_content += chunk.data.choices[0].delta.content
yield chunk.data.choices[0].delta.content

conversation.add_message(AgentMessage(content=message_content))

def batch(
self,
conversations: List[Conversation],
temperature: float = 0.7,
max_tokens: int = 256,
top_p: int = 1,
enable_json: bool = False,
safe_prompt: bool = False,
) -> List:
"""Synchronously process multiple conversations"""
return [
self.predict(
conv,
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
enable_json=enable_json,
safe_prompt=safe_prompt,
)
for conv in conversations
]

async def abatch(
self,
conversations: List[Conversation],
temperature: float = 0.7,
max_tokens: int = 256,
top_p: int = 1,
enable_json: bool = False,
safe_prompt: bool = False,
max_concurrent: int = 5,
) -> List:
"""Process multiple conversations in parallel with controlled concurrency"""
semaphore = asyncio.Semaphore(max_concurrent)

async def process_conversation(conv):
async with semaphore:
return await self.apredict(
conv,
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
enable_json=enable_json,
safe_prompt=safe_prompt,
)

tasks = [process_conversation(conv) for conv in conversations]
return await asyncio.gather(*tasks)
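
The abatch method above relies on a bounded-concurrency pattern: an asyncio.Semaphore caps how many apredict calls are in flight at once, while asyncio.gather returns the results in input order. A minimal stand-alone sketch of that pattern follows; fake_request is a hypothetical stand-in for the Mistral call, so the sketch runs without an API key and is not part of this PR.

import asyncio
import random


async def fake_request(prompt: str) -> str:
    # Hypothetical stand-in for client.chat.complete_async(); lets the sketch run offline.
    await asyncio.sleep(random.uniform(0.05, 0.2))
    return f"response to {prompt!r}"


async def bounded_batch(prompts: list[str], max_concurrent: int = 5) -> list[str]:
    semaphore = asyncio.Semaphore(max_concurrent)

    async def run_one(prompt: str) -> str:
        async with semaphore:  # at most max_concurrent requests awaited at any moment
            return await fake_request(prompt)

    # gather preserves the order of the inputs, regardless of completion order
    return await asyncio.gather(*(run_one(p) for p in prompts))


if __name__ == "__main__":
    print(asyncio.run(bounded_batch([f"prompt {i}" for i in range(10)], max_concurrent=3)))
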
97 changes: 97 additions & 0 deletions pkgs/swarmauri/tests/unit/llms/MistralModel_unit_test.py
@@ -86,3 +86,100 @@ def test_preamble_system_context(mistral_model, model_name):

assert type(prediction) == str
assert "Jeff" in prediction


@pytest.mark.parametrize("model_name", get_allowed_models())
@pytest.mark.unit
def test_stream(mistral_model, model_name):
model = mistral_model
model.name = model_name
conversation = Conversation()

input_data = "Write a short story about a cat."
human_message = HumanMessage(content=input_data)
conversation.add_message(human_message)

collected_tokens = []
for token in model.stream(conversation=conversation):
assert isinstance(token, str)
collected_tokens.append(token)

full_response = "".join(collected_tokens)
assert len(full_response) > 0
assert conversation.get_last().content == full_response


@pytest.mark.parametrize("model_name", get_allowed_models())
@pytest.mark.unit
def test_batch(mistral_model, model_name):
model = mistral_model
model.name = model_name

conversations = []
for prompt in ["Hello", "Hi there", "Good morning"]:
conv = Conversation()
conv.add_message(HumanMessage(content=prompt))
conversations.append(conv)

results = model.batch(conversations=conversations)
assert len(results) == len(conversations)
for result in results:
assert isinstance(result.get_last().content, str)


@pytest.mark.parametrize("model_name", get_allowed_models())
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.unit
async def test_apredict(mistral_model, model_name):
model = mistral_model
model.name = model_name
conversation = Conversation()

input_data = "Hello"
human_message = HumanMessage(content=input_data)
conversation.add_message(human_message)

result = await model.apredict(conversation=conversation)
prediction = result.get_last().content
assert isinstance(prediction, str)


@pytest.mark.parametrize("model_name", get_allowed_models())
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.unit
async def test_astream(mistral_model, model_name):
model = mistral_model
model.name = model_name
conversation = Conversation()

input_data = "Write a short story about a dog."
human_message = HumanMessage(content=input_data)
conversation.add_message(human_message)

collected_tokens = []
async for token in model.astream(conversation=conversation):
assert isinstance(token, str)
collected_tokens.append(token)

full_response = "".join(collected_tokens)
assert len(full_response) > 0
assert conversation.get_last().content == full_response


@pytest.mark.parametrize("model_name", get_allowed_models())
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.unit
async def test_abatch(mistral_model, model_name):
model = mistral_model
model.name = model_name

conversations = []
for prompt in ["Hello", "Hi there", "Good morning"]:
conv = Conversation()
conv.add_message(HumanMessage(content=prompt))
conversations.append(conv)

results = await model.abatch(conversations=conversations)
assert len(results) == len(conversations)
for result in results:
assert isinstance(result.get_last().content, str)
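

The tests above exercise apredict, astream, batch, and abatch through the mistral_model fixture. For a manual smoke test outside pytest, a minimal sketch is below; the import paths for MistralModel and HumanMessage and the api_key constructor argument are assumptions inferred from the diff and the tests, not confirmed by this PR.

import asyncio
import os

from swarmauri.conversations.concrete import Conversation
from swarmauri.llms.concrete.MistralModel import MistralModel  # assumed import path
from swarmauri.messages.concrete import HumanMessage  # assumed import path


async def main():
    # Assumed: the model accepts the API key as a constructor argument (the methods read self.api_key).
    model = MistralModel(api_key=os.environ["MISTRAL_API_KEY"])

    # Single async completion.
    conv = Conversation()
    conv.add_message(HumanMessage(content="Hello"))
    conv = await model.apredict(conversation=conv)
    print(conv.get_last().content)

    # Batch of three conversations with at most two requests in flight at a time.
    conversations = []
    for prompt in ["Hello", "Hi there", "Good morning"]:
        c = Conversation()
        c.add_message(HumanMessage(content=prompt))
        conversations.append(c)
    results = await model.abatch(conversations=conversations, max_concurrent=2)
    for result in results:
        print(result.get_last().content)


if __name__ == "__main__":
    asyncio.run(main())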