Skip to content

Commit

Permalink
Cerebras: Updated streaming, moved exception handling to client class
Browse files Browse the repository at this point in the history
  • Loading branch information
marklysze authored Oct 2, 2024
1 parent 3d59156 commit fe2bee9
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 38 deletions.
70 changes: 32 additions & 38 deletions autogen/oai/cerebras.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,45 +137,41 @@ def create(self, params: Dict) -> ChatCompletion:
streaming_tool_calls = []

ans = None
try:
response = client.chat.completions.create(**cerebras_params)
except Exception as e:
raise RuntimeError(f"Cerebras exception occurred: {e}")
else:

if cerebras_params["stream"]:
# Read in the chunks as they stream, taking in tool_calls which may be across
# multiple chunks if more than one suggested
ans = ""
for chunk in response:
# Grab first choice, which _should_ always be generated.
ans = ans + (chunk.choices[0].delta.content or "")

if chunk.choices[0].delta.tool_calls:
# We have a tool call recommendation
for tool_call in chunk.choices[0].delta.tool_calls:
streaming_tool_calls.append(
ChatCompletionMessageToolCall(
id=tool_call.id,
function={
"name": tool_call.function.name,
"arguments": tool_call.function.arguments,
},
type="function",
)
response = client.chat.completions.create(**cerebras_params)

if cerebras_params["stream"]:
# Read in the chunks as they stream, taking in tool_calls which may be across
# multiple chunks if more than one suggested
ans = ""
for chunk in response:
# Grab first choice, which _should_ always be generated.
ans = ans + (getattr(chunk.choices[0].delta, "content", None) or "")

if "tool_calls" in chunk.choices[0].delta:
# We have a tool call recommendation
for tool_call in chunk.choices[0].delta["tool_calls"]:
streaming_tool_calls.append(
ChatCompletionMessageToolCall(
id=tool_call["id"],
function={
"name": tool_call["function"]["name"],
"arguments": tool_call["function"]["arguments"],
},
type="function",
)
)

if chunk.choices[0].finish_reason:
prompt_tokens = chunk.x_cerebras.usage.prompt_tokens
completion_tokens = chunk.x_cerebras.usage.completion_tokens
total_tokens = chunk.x_cerebras.usage.total_tokens
else:
# Non-streaming finished
ans: str = response.choices[0].message.content
if chunk.choices[0].finish_reason:
prompt_tokens = chunk.usage.prompt_tokens
completion_tokens = chunk.usage.completion_tokens
total_tokens = chunk.usage.total_tokens
else:
# Non-streaming finished
ans: str = response.choices[0].message.content

prompt_tokens = response.usage.prompt_tokens
completion_tokens = response.usage.completion_tokens
total_tokens = response.usage.total_tokens
prompt_tokens = response.usage.prompt_tokens
completion_tokens = response.usage.completion_tokens
total_tokens = response.usage.total_tokens

if response is not None:
if isinstance(response, Stream):
Expand Down Expand Up @@ -209,8 +205,6 @@ def create(self, params: Dict) -> ChatCompletion:

response_content = response.choices[0].message.content
response_id = response.id
else:
raise RuntimeError("Failed to get response from Cerebras after retrying 5 times.")

# 3. convert output
message = ChatCompletionMessage(
Expand Down
9 changes: 9 additions & 0 deletions autogen/oai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@
ERROR = None

try:
from cerebras.cloud.sdk import ( # noqa
AuthenticationError as cerebras_AuthenticationError,
InternalServerError as cerebras_InternalServerError,
RateLimitError as cerebras_RateLimitError,
)

from autogen.oai.cerebras import CerebrasClient

cerebras_import_exception: Optional[ImportError] = None
Expand Down Expand Up @@ -868,6 +874,9 @@ def yes_or_no_filter(context, response):
ollama_ResponseError,
bedrock_BotoCoreError,
bedrock_ClientError,
cerebras_AuthenticationError,
cerebras_InternalServerError,
cerebras_RateLimitError,
):
logger.debug(f"config {i} failed", exc_info=True)
if i == last:
Expand Down

0 comments on commit fe2bee9

Please sign in to comment.