swarm - fixed cohere closed loop error #808

Merged 1 commit on Nov 15, 2024
130 changes: 92 additions & 38 deletions pkgs/swarmauri/swarmauri/llms/concrete/CohereModel.py
@@ -28,7 +28,6 @@ class CohereModel(LLMBase):

_BASE_URL: str = PrivateAttr("https://api.cohere.ai/v1")
_client: httpx.Client = PrivateAttr()
-    _async_client: httpx.AsyncClient = PrivateAttr()

api_key: str
allowed_models: List[str] = [
@@ -56,7 +55,13 @@ def __init__(self, **data):
"authorization": f"Bearer {self.api_key}",
}
self._client = httpx.Client(headers=headers, base_url=self._BASE_URL)
-        self._async_client = httpx.AsyncClient(headers=headers, base_url=self._BASE_URL)

+    def get_headers(self) -> Dict[str, str]:
+        return {
+            "accept": "application/json",
+            "content-type": "application/json",
+            "authorization": f"Bearer {self.api_key}",
+        }

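For context on the change above: an httpx.AsyncClient created once in __init__ gets bound to whichever event loop first drives it, so a later call from a fresh loop (for example, a second asyncio.run()) can fail with "RuntimeError: Event loop is closed" — the "closed loop error" this PR fixes. A minimal sketch of the failure mode and the per-request fix; the class name and URL here are hypothetical, not part of the PR:

import asyncio
import httpx

class SharedClientModel:
    """Hypothetical repro of the pre-PR pattern: one long-lived AsyncClient."""

    def __init__(self):
        # The client's connection pool gets bound to the first event loop
        # that drives it.
        self._async_client = httpx.AsyncClient(base_url="https://example.com")

    async def call(self) -> int:
        response = await self._async_client.get("/")
        return response.status_code

async def call_fixed() -> int:
    # The fix, as in this PR: a short-lived client per request always
    # lives and dies on the currently running loop.
    async with httpx.AsyncClient(base_url="https://example.com") as client:
        response = await client.get("/")
        return response.status_code

model = SharedClientModel()
print(asyncio.run(model.call()))       # first loop: fine
try:
    # A new loop starts, but pooled connections still reference the closed
    # one; this typically raises "RuntimeError: Event loop is closed".
    print(asyncio.run(model.call()))
except RuntimeError as err:
    print(f"shared client failed: {err}")

print(asyncio.run(call_fixed()))       # per-request client: works every time
print(asyncio.run(call_fixed()))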
def _format_messages(
self, messages: List[MessageBase]
@@ -220,24 +225,48 @@ async def apredict(self, conversation, temperature=0.7, max_tokens=256):
if system_message:
payload["preamble"] = system_message

-        with DurationManager() as prompt_timer:
-            response = await self._async_client.post("/chat", json=payload)
-            response.raise_for_status()
-            data = response.json()
+        async with httpx.AsyncClient(
+            headers=self.get_headers(), base_url=self._BASE_URL
+        ) as client:
+            with DurationManager() as prompt_timer:
+                response = await client.post("/chat", json=payload)
+                response.raise_for_status()
+                data = response.json()

-        with DurationManager() as completion_timer:
-            message_content = data["text"]
+            with DurationManager() as completion_timer:
+                message_content = data["text"]

-        usage_data = data.get("usage", {})
+            usage_data = data.get("usage", {})

-        usage = self._prepare_usage_data(
-            usage_data, prompt_timer.duration, completion_timer.duration
-        )
+            usage = self._prepare_usage_data(
+                usage_data, prompt_timer.duration, completion_timer.duration
+            )

-        conversation.add_message(AgentMessage(content=message_content, usage=usage))
-        return conversation
+            conversation.add_message(AgentMessage(content=message_content, usage=usage))
+            return conversation

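With the client now scoped to each call, apredict can be driven from separate event loops without tripping the closed-loop error. A rough usage sketch — the swarmauri import paths, constructor arguments, and the get_last() accessor are assumptions inferred from the diff, not confirmed by it:

import asyncio

# Assumed import paths; the actual module layout may differ.
from swarmauri.llms.concrete.CohereModel import CohereModel
from swarmauri.conversations.concrete.Conversation import Conversation
from swarmauri.messages.concrete.HumanMessage import HumanMessage

async def ask(question: str) -> None:
    llm = CohereModel(api_key="YOUR_COHERE_API_KEY")
    conv = Conversation()
    conv.add_message(HumanMessage(content=question))
    conv = await llm.apredict(conv, temperature=0.3, max_tokens=64)
    print(conv.get_last().content)  # get_last() is assumed here

# Each asyncio.run() starts a fresh event loop; with the per-request
# AsyncClient this now works repeatedly instead of failing on the second run.
asyncio.run(ask("Name one use of the Cohere chat API."))
asyncio.run(ask("And one more."))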
def stream(self, conversation, temperature=0.7, max_tokens=256) -> Iterator[str]:
"""
Stream responses from the model synchronously, yielding content as it becomes available.

This method processes the conversation and streams the model's response piece by piece,
allowing for real-time processing of the output. At the end of streaming, it adds the
complete response to the conversation history.

Args:
conversation: The conversation object containing message history
temperature (float, optional): Sampling temperature. Controls randomness in the response.
Higher values (e.g., 0.8) create more diverse outputs, while lower values (e.g., 0.2)
make outputs more deterministic. Defaults to 0.7.
max_tokens (int, optional): Maximum number of tokens to generate in the response.
Defaults to 256.

Yields:
str: Chunks of the model's response as they become available.

Returns:
None: The method updates the conversation object in place after completion.
"""
chat_history, system_message, message = self._format_messages(
conversation.history
)
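As the docstring above describes, stream() is a plain generator consumed synchronously. A brief sketch, reusing the hypothetical llm and conv objects from the apredict example:

# Print chunks as they arrive; flush so partial output shows immediately.
for chunk in llm.stream(conv, temperature=0.7, max_tokens=128):
    print(chunk, end="", flush=True)
print()

# Once the generator is exhausted, the complete reply has already been
# appended to conv as an AgentMessage carrying usage metadata.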
@@ -285,6 +314,28 @@ def stream(self, conversation, temperature=0.7, max_tokens=256) -> Iterator[str]
async def astream(
self, conversation, temperature=0.7, max_tokens=256
) -> AsyncIterator[str]:
"""
Stream responses from the model asynchronously, yielding content as it becomes available.

This method is the asynchronous version of `stream()`. It processes the conversation and
streams the model's response piece by piece using async/await syntax. The method creates
and manages its own AsyncClient instance to prevent event loop issues.

Args:
conversation: The conversation object containing message history
temperature (float, optional): Sampling temperature. Controls randomness in the response.
Higher values (e.g., 0.8) create more diverse outputs, while lower values (e.g., 0.2)
make outputs more deterministic. Defaults to 0.7.
max_tokens (int, optional): Maximum number of tokens to generate in the response.
Defaults to 256.

Yields:
str: Chunks of the model's response as they become available.

Returns:
None: The method updates the conversation object in place after completion.
"""

chat_history, system_message, message = self._format_messages(
conversation.history
)
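Because astream() is an async generator, a caller consumes it with `async for` inside a coroutine. A short sketch, under the same assumptions as the earlier examples:

import asyncio

async def stream_reply() -> None:
    async for chunk in llm.astream(conv, max_tokens=128):
        print(chunk, end="", flush=True)
    print()

# Safe to call from a fresh loop each time, because astream opens its
# own AsyncClient per invocation (see the body below).
asyncio.run(stream_reply())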
@@ -310,30 +361,33 @@ async def astream(
collected_content = []
usage_data = {}

-        with DurationManager() as prompt_timer:
-            response = await self._async_client.post("/chat", json=payload)
-            response.raise_for_status()
+        async with httpx.AsyncClient(
+            headers=self.get_headers(), base_url=self._BASE_URL
+        ) as client:
+            with DurationManager() as prompt_timer:
+                response = await client.post("/chat", json=payload)
+                response.raise_for_status()

-        with DurationManager() as completion_timer:
-            async for line in response.aiter_lines():
-                if line:
-                    try:
-                        chunk = json.loads(line)
-                        if "text" in chunk:
-                            content = chunk["text"]
-                            collected_content.append(content)
-                            yield content
-                        elif "usage" in chunk:
-                            usage_data = chunk["usage"]
-                    except json.JSONDecodeError:
-                        continue
+            with DurationManager() as completion_timer:
+                async for line in response.aiter_lines():
+                    if line:
+                        try:
+                            chunk = json.loads(line)
+                            if "text" in chunk:
+                                content = chunk["text"]
+                                collected_content.append(content)
+                                yield content
+                            elif "usage" in chunk:
+                                usage_data = chunk["usage"]
+                        except json.JSONDecodeError:
+                            continue

-        full_content = "".join(collected_content)
-        usage = self._prepare_usage_data(
-            usage_data, prompt_timer.duration, completion_timer.duration
-        )
+            full_content = "".join(collected_content)
+            usage = self._prepare_usage_data(
+                usage_data, prompt_timer.duration, completion_timer.duration
+            )

-        conversation.add_message(AgentMessage(content=full_content, usage=usage))
+            conversation.add_message(AgentMessage(content=full_content, usage=usage))

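DurationManager itself is not part of this diff. From the way it is used above — a context manager whose elapsed time is read as .duration after the block — a minimal compatible stand-in would be:

import time

class DurationManager:
    """Minimal stand-in inferred from usage in this diff; the real
    swarmauri implementation may differ."""

    def __enter__(self):
        self._start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc, tb):
        # Elapsed wall-clock seconds, read as `timer.duration` after the block.
        self.duration = time.perf_counter() - self._start
        return False  # never swallow exceptions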
def batch(self, conversations: List, temperature=0.7, max_tokens=256) -> List:
"""