diff --git a/autogen/oai/client.py b/autogen/oai/client.py
index 19b937c0e2b..3b707cbe223 100644
--- a/autogen/oai/client.py
+++ b/autogen/oai/client.py
@@ -93,7 +93,9 @@ def message_retrieval(
         """
         ...  # pragma: no cover
 
-    def cost(self, response: ModelClientResponseProtocol) -> float:
+    def cost(
+        self, response: ModelClientResponseProtocol, token_cost_1k: Optional[Union[float, Dict[str, float]]] = None
+    ) -> float:
         ...  # pragma: no cover
 
     @staticmethod
@@ -275,21 +277,44 @@ def create(self, params: Dict[str, Any]) -> ChatCompletion:
 
         return response
 
-    def cost(self, response: Union[ChatCompletion, Completion]) -> float:
-        """Calculate the cost of the response."""
-        model = response.model
-        if model not in OAI_PRICE1K:
-            # TODO: add logging to warn that the model is not found
-            logger.debug(f"Model {model} is not found. The cost will be 0.", exc_info=True)
-            return 0
+    def cost(
+        self,
+        response: Union[ChatCompletion, Completion],
+        token_cost_1k: Optional[Union[float, Dict[str, float]]] = None,
+    ) -> float:
+        """Calculate the cost of the response.
+
+        Args:
+            response: The completion whose token usage is priced.
+            token_cost_1k: Price per 1000 tokens, either one rate applied to input and output tokens
+                alike, or a dict with "input" and "output" rates. If None, the OAI_PRICE1K entry for
+                `response.model` is used.
+
+        Returns:
+            The cost of the response, or 0 when no price is given and the model is not in OAI_PRICE1K.
+        """
+        # No cost specified. Use the default if possible
+        if token_cost_1k is None:
+            model = response.model
+            if model not in OAI_PRICE1K:
+                # TODO: add logging to warn that the model is not found
+                # No exc_info here: no exception is being handled at this point.
+                logger.debug(f"Model {model} is not found. The cost will be 0.")
+                return 0
+            token_cost_1k = OAI_PRICE1K[model]
+
+        # A tuple price is (input rate, output rate), as in OAI_PRICE1K; normalize it to the dict form
+        if isinstance(token_cost_1k, tuple):
+            token_cost_1k = {"input": token_cost_1k[0], "output": token_cost_1k[1]}
 
+        # Read the token use
         n_input_tokens = response.usage.prompt_tokens if response.usage is not None else 0  # type: ignore [union-attr]
         n_output_tokens = response.usage.completion_tokens if response.usage is not None else 0  # type: ignore [union-attr]
-        tmp_price1K = OAI_PRICE1K[model]
-        # First value is input token rate, second value is output token rate
-        if isinstance(tmp_price1K, tuple):
-            return (tmp_price1K[0] * n_input_tokens + tmp_price1K[1] * n_output_tokens) / 1000  # type: ignore [no-any-return]
-        return tmp_price1K * (n_input_tokens + n_output_tokens) / 1000  # type: ignore [operator]
+
+        # Compute final cost
+        if isinstance(token_cost_1k, dict):
+            return (token_cost_1k["input"] * n_input_tokens + token_cost_1k["output"] * n_output_tokens) / 1000  # type: ignore [no-any-return]
+        return token_cost_1k * (n_input_tokens + n_output_tokens) / 1000  # type: ignore [operator]
 
     @staticmethod
     def get_usage(response: Union[ChatCompletion, Completion]) -> Dict:
@@ -314,6 +339,8 @@ class OpenAIWrapper:
         "api_version",
         "api_type",
         "tags",
+        "window_size",
+        "token_cost_1k",
     }
 
     openai_kwargs = set(inspect.getfullargspec(OpenAI.__init__).kwonlyargs)
@@ -551,6 +578,9 @@ def yes_or_no_filter(context, response):
         cache = extra_kwargs.get("cache")
         filter_func = extra_kwargs.get("filter_func")
         context = extra_kwargs.get("context")
+        token_cost_1k = extra_kwargs.get("token_cost_1k")
+        # Only forward token_cost_1k when set, so clients with the old cost(response) signature still work
+        cost_kwargs: Dict[str, Any] = {} if token_cost_1k is None else {"token_cost_1k": token_cost_1k}
 
         total_usage = None
         actual_usage = None
@@ -575,7 +605,7 @@ def yes_or_no_filter(context, response):
                             response.cost  # type: ignore [attr-defined]
                         except AttributeError:
                             # update attribute if cost is not calculated
-                            response.cost = client.cost(response)
+                            response.cost = client.cost(response, **cost_kwargs)
                             cache.set(key, response)
                         total_usage = client.get_usage(response)
                         # check the filter
@@ -605,7 +635,7 @@ def yes_or_no_filter(context, response):
                     raise
             else:
                 # add cost calculation before caching no matter filter is passed or not
-                response.cost = client.cost(response)
+                response.cost = client.cost(response, **cost_kwargs)
                 actual_usage = client.get_usage(response)
                 total_usage = actual_usage.copy() if actual_usage is not None else total_usage
                 self._update_usage(actual_usage=actual_usage, total_usage=total_usage)
diff --git a/test/oai/test_custom_client.py b/test/oai/test_custom_client.py
index 04669a3e02f..dc024e8de57 100644
--- a/test/oai/test_custom_client.py
+++ b/test/oai/test_custom_client.py
@@ -49,7 +49,7 @@ def create(self, params):
         def message_retrieval(self, response):
             return [response.choices[0].message.content]
 
-        def cost(self, response) -> float:
+        def cost(self, response, token_cost_1k=None) -> float:
             """Calculate the cost of the response."""
             response.cost = TEST_COST
             return TEST_COST
@@ -97,7 +97,7 @@ def create(self, params):
         def message_retrieval(self, response):
             return []
 
-        def cost(self, response) -> float:
+        def cost(self, response, token_cost_1k=None) -> float:
             return 0
 
         @staticmethod
@@ -127,7 +127,7 @@ def create(self, params):
         def message_retrieval(self, response):
             return []
 
-        def cost(self, response) -> float:
+        def cost(self, response, token_cost_1k=None) -> float:
             return 0
 
         @staticmethod
@@ -179,7 +179,7 @@ def create(self, params):
         def message_retrieval(self, response):
             return []
 
-        def cost(self, response) -> float:
+        def cost(self, response, token_cost_1k=None) -> float:
             """Calculate the cost of the response."""
             return 0
 