diff --git a/autogen/oai/cerebras.py b/autogen/oai/cerebras.py
index 92f355583a5e..c5e153c54657 100644
--- a/autogen/oai/cerebras.py
+++ b/autogen/oai/cerebras.py
@@ -38,17 +38,18 @@
     "llama3.1-70b": (0.60 / 1000, 0.60 / 1000),
 }
 
+
 class CerebrasClient:
     """Client for Cerebras's API."""
 
-    def __init__(self, **kwargs):
+    def __init__(self, api_key=None, **kwargs):
         """Requires api_key or environment variable to be set
 
         Args:
             api_key (str): The API key for using Cerebras (or environment variable CEREBRAS_API_KEY needs to be set)
         """
         # Ensure we have the api_key upon instantiation
-        self.api_key = kwargs.get("api_key", None)
+        self.api_key = api_key
         if not self.api_key:
             self.api_key = os.getenv("CEREBRAS_API_KEY")
 
@@ -56,7 +57,7 @@ def __init__(self, **kwargs):
             self.api_key
         ), "Please include the api_key in your config list entry for Cerebras or set the CEREBRAS_API_KEY env variable."
 
-    def message_retrieval(self, response) -> List:
+    def message_retrieval(self, response: ChatCompletion) -> List:
         """
         Retrieve and return a list of strings or a list of Choice.Message from the response.
 
@@ -65,11 +66,12 @@ def message_retrieval(self, response) -> List:
         """
         return [choice.message for choice in response.choices]
 
-    def cost(self, response) -> float:
+    def cost(self, response: ChatCompletion) -> float:
+        # Note: This field isn't explicitly in `ChatCompletion`, but is injected during chat creation.
         return response.cost
 
     @staticmethod
-    def get_usage(response) -> Dict:
+    def get_usage(response: ChatCompletion) -> Dict:
         """Return usage summary of the response using RESPONSE_USAGE_KEYS."""
         # ... # pragma: no cover
         return {
@@ -144,6 +146,7 @@ def create(self, params: Dict) -> ChatCompletion:
             # multiple chunks if more than one suggested
             ans = ""
             for chunk in response:
+                # Grab first choice, which _should_ always be generated.
                 ans = ans + (chunk.choices[0].delta.content or "")
 
                 if chunk.choices[0].delta.tool_calls:
@@ -227,6 +230,8 @@ def create(self, params: Dict) -> ChatCompletion:
                 completion_tokens=completion_tokens,
                 total_tokens=total_tokens,
             ),
+            # Note: This seems to be a field that isn't in the schema of `ChatCompletion`, so Pydantic
+            # just adds it dynamically.
             cost=calculate_cerebras_cost(prompt_tokens, completion_tokens, cerebras_params["model"]),
         )
 
diff --git a/website/docs/topics/non-openai-models/cloud-cerebras.ipynb b/website/docs/topics/non-openai-models/cloud-cerebras.ipynb
index d7354671853e..73c3cdafa4cb 100644
--- a/website/docs/topics/non-openai-models/cloud-cerebras.ipynb
+++ b/website/docs/topics/non-openai-models/cloud-cerebras.ipynb
@@ -36,7 +36,7 @@
    "source": [
     "## Getting Started\n",
     "\n",
-    "Cerebras provides a number of models to use, included below. See the list of [models here](https://inference-docs.cerebras.ai/introduction).\n",
+    "Cerebras provides a number of models to use. See the list of [models here](https://inference-docs.cerebras.ai/introduction).\n",
     "\n",
     "See the sample `OAI_CONFIG_LIST` below showing how the Cerebras AI client class is used by specifying the `api_type` as `cerebras`.\n",
     "```python\n",
@@ -223,7 +223,7 @@
      ]
     },
     {
-     "name": "stdin",
+     "name": "stdout",
      "output_type": "stream",
      "text": [
       "Replying as User. Provide feedback to Cerebras Assistant. Press enter to skip and use auto-reply, or type 'exit' to end the conversation: \n"