Skip to content

Commit

Permalink
Address feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
henrytwo committed Sep 26, 2024
1 parent e45d53a commit 689e71a
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 7 deletions.
15 changes: 10 additions & 5 deletions autogen/oai/cerebras.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,25 +38,26 @@
"llama3.1-70b": (0.60 / 1000, 0.60 / 1000),
}


class CerebrasClient:
"""Client for Cerebras's API."""

def __init__(self, **kwargs):
def __init__(self, api_key = None, **kwargs):
"""Requires api_key or environment variable to be set
Args:
api_key (str): The API key for using Cerebras (or environment variable CEREBRAS_API_KEY needs to be set)
"""
# Ensure we have the api_key upon instantiation
self.api_key = kwargs.get("api_key", None)
self.api_key = api_key
if not self.api_key:
self.api_key = os.getenv("CEREBRAS_API_KEY")

assert (
self.api_key
), "Please include the api_key in your config list entry for Cerebras or set the CEREBRAS_API_KEY env variable."

def message_retrieval(self, response) -> List:
def message_retrieval(self, response: ChatCompletion) -> List:
"""
Retrieve and return a list of strings or a list of Choice.Message from the response.
Expand All @@ -65,11 +66,12 @@ def message_retrieval(self, response) -> List:
"""
return [choice.message for choice in response.choices]

def cost(self, response) -> float:
def cost(self, response: ChatCompletion) -> float:
# Note: This field isn't explicitly in `ChatCompletion`, but is injected during chat creation.
return response.cost

@staticmethod
def get_usage(response) -> Dict:
def get_usage(response: ChatCompletion) -> Dict:
"""Return usage summary of the response using RESPONSE_USAGE_KEYS."""
# ... # pragma: no cover
return {
Expand Down Expand Up @@ -144,6 +146,7 @@ def create(self, params: Dict) -> ChatCompletion:
# multiple chunks if more than one suggested
ans = ""
for chunk in response:
# Grab first choice, which _should_ always be generated.
ans = ans + (chunk.choices[0].delta.content or "")

if chunk.choices[0].delta.tool_calls:
Expand Down Expand Up @@ -227,6 +230,8 @@ def create(self, params: Dict) -> ChatCompletion:
completion_tokens=completion_tokens,
total_tokens=total_tokens,
),
# Note: This seems to be a field that isn't in the schema of `ChatCompletion`, so Pydantic
# just adds it dynamically.
cost=calculate_cerebras_cost(prompt_tokens, completion_tokens, cerebras_params["model"]),
)

Expand Down
4 changes: 2 additions & 2 deletions website/docs/topics/non-openai-models/cloud-cerebras.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
"source": [
"## Getting Started\n",
"\n",
"Cerebras provides a number of models to use, included below. See the list of [models here](https://inference-docs.cerebras.ai/introduction).\n",
"Cerebras provides a number of models to use. See the list of [models here](https://inference-docs.cerebras.ai/introduction).\n",
"\n",
"See the sample `OAI_CONFIG_LIST` below showing how the Cerebras AI client class is used by specifying the `api_type` as `cerebras`.\n",
"```python\n",
Expand Down Expand Up @@ -223,7 +223,7 @@
]
},
{
"name": "stdin",
"name": "stdout",
"output_type": "stream",
"text": [
"Replying as User. Provide feedback to Cerebras Assistant. Press enter to skip and use auto-reply, or type 'exit' to end the conversation: \n"
Expand Down

0 comments on commit 689e71a

Please sign in to comment.