Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow passing in custom pricing in config_list #2902

Merged
merged 9 commits into from
Jun 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 28 additions & 3 deletions autogen/oai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,8 +290,10 @@ def cost(self, response: Union[ChatCompletion, Completion]) -> float:
"""Calculate the cost of the response."""
model = response.model
if model not in OAI_PRICE1K:
# TODO: add logging to warn that the model is not found
logger.debug(f"Model {model} is not found. The cost will be 0.", exc_info=True)
# log warning that the model is not found
logger.warning(
f'Model {model} is not found. The cost will be 0. In your config_list, add field {{"price" : [prompt_price_per_1k, completion_token_price_per_1k]}} for customized pricing.'
)
return 0

n_input_tokens = response.usage.prompt_tokens if response.usage is not None else 0 # type: ignore [union-attr]
Expand Down Expand Up @@ -328,6 +330,7 @@ class OpenAIWrapper:
"api_version",
"api_type",
"tags",
"price",
}

openai_kwargs = set(inspect.getfullargspec(OpenAI.__init__).kwonlyargs)
Expand Down Expand Up @@ -592,6 +595,14 @@ def yes_or_no_filter(context, response):
filter_func = extra_kwargs.get("filter_func")
context = extra_kwargs.get("context")
agent = extra_kwargs.get("agent")
price = extra_kwargs.get("price", None)
if isinstance(price, list):
price = tuple(price)
elif isinstance(price, float) or isinstance(price, int):
logger.warning(
"Input price is a float/int. Using the same price for prompt and completion tokens. Use a list/tuple if prompt and completion token prices are different."
)
price = (price, price)

total_usage = None
actual_usage = None
Expand Down Expand Up @@ -678,7 +689,10 @@ def yes_or_no_filter(context, response):
raise
else:
# add cost calculation before caching no matter filter is passed or not
response.cost = client.cost(response)
if price is not None:
response.cost = self._cost_with_customized_price(response, price)
else:
response.cost = client.cost(response)
actual_usage = client.get_usage(response)
total_usage = actual_usage.copy() if actual_usage is not None else total_usage
self._update_usage(actual_usage=actual_usage, total_usage=total_usage)
Expand Down Expand Up @@ -712,6 +726,17 @@ def yes_or_no_filter(context, response):
continue # filter is not passed; try the next config
raise RuntimeError("Should not reach here.")

@staticmethod
def _cost_with_customized_price(
response: ModelClient.ModelClientResponseProtocol, price_1k: Tuple[float, float]
) -> None:
"""If a customized cost is passed, overwrite the cost in the response."""
n_input_tokens = response.usage.prompt_tokens if response.usage is not None else 0 # type: ignore [union-attr]
n_output_tokens = response.usage.completion_tokens if response.usage is not None else 0 # type: ignore [union-attr]
if n_output_tokens is None:
n_output_tokens = 0
return n_input_tokens * price_1k[0] + n_output_tokens * price_1k[1]

@staticmethod
def _update_dict_from_chunk(chunk: BaseModel, d: Dict[str, Any], field: str) -> int:
"""Update the dict from the chunk.
Expand Down
56 changes: 53 additions & 3 deletions notebook/agentchat_cost_token_tracking.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,20 @@
"\n",
"To gather usage data for a list of agents, we provide a utility function `autogen.gather_usage_summary(agents)` where you pass in a list of agents and gather the usage summary.\n",
"\n",
"## 3. Custom token price for up-to-date cost estimation\n",
"AutoGen tries to keep the token prices up-to-date. However, you can pass in a `price` field in `config_list` if the token price is not listed or up-to-date. Please create an issue or pull request to help us keep the token prices up-to-date!\n",
"\n",
"Note: in json files, the price should be a list of two floats.\n",
"\n",
"Example Usage:\n",
"```python\n",
"{\n",
" \"model\": \"gpt-3.5-turbo-xxxx\",\n",
" \"api_key\": \"YOUR_API_KEY\",\n",
" \"price\": [0.0005, 0.0015]\n",
"}\n",
"```\n",
"\n",
"## Caution when using Azure OpenAI!\n",
"If you are using azure OpenAI, the model returned from completion doesn't have the version information. The returned model is either 'gpt-35-turbo' or 'gpt-4'. From there, we are calculating the cost based on gpt-3.5-turbo-0125: (0.0005, 0.0015) per 1k prompt and completion tokens and gpt-4-0613: (0.03, 0.06). This means the cost can be wrong if you are using a different version from azure OpenAI.\n",
"\n",
Expand All @@ -55,7 +69,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -65,7 +79,7 @@
"config_list = autogen.config_list_from_json(\n",
" \"OAI_CONFIG_LIST\",\n",
" filter_dict={\n",
" \"tags\": [\"gpt-3.5-turbo\", \"gpt-3.5-turbo-16k\"], # comment out to get all\n",
" \"model\": [\"gpt-3.5-turbo\", \"gpt-3.5-turbo-16k\"], # comment out to get all\n",
" },\n",
")"
]
Expand Down Expand Up @@ -127,6 +141,42 @@
"print(response.cost)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## OpenAIWrapper with custom token price"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Price: 109\n"
]
}
],
"source": [
"# Adding price to the config_list\n",
"for i in range(len(config_list)):\n",
" config_list[i][\"price\"] = [\n",
" 1,\n",
" 1,\n",
" ] # Note: This price is just for demonstration purposes. Please replace it with the actual price of the model.\n",
"\n",
"client = OpenAIWrapper(config_list=config_list)\n",
"messages = [\n",
" {\"role\": \"user\", \"content\": \"Can you give me 3 useful tips on learning Python? Keep it simple and short.\"},\n",
"]\n",
"response = client.create(messages=messages, cache_seed=None)\n",
"print(\"Price:\", response.cost)"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -504,7 +554,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
"version": "3.9.19"
}
},
"nbformat": 4,
Expand Down
12 changes: 12 additions & 0 deletions test/oai/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,18 @@ def test_cost(cache_seed):
print(response.cost)


@pytest.mark.skipif(skip, reason="openai>=1 not installed")
def test_customized_cost():
    """Verify that a per-config ``price`` field overrides the built-in model pricing."""
    config_list = config_list_from_json(
        env_or_file=OAI_CONFIG_LIST, file_location=KEY_LOC, filter_dict={"tags": ["gpt-3.5-turbo-instruct"]}
    )
    # Attach a deliberately inflated custom price to every config entry so the
    # resulting cost is unmistakably larger than any real model pricing.
    for cfg in config_list:
        cfg["price"] = [1, 1]
    wrapper = OpenAIWrapper(config_list=config_list, cache_seed=None)
    response = wrapper.create(prompt="1+3=")
    assert response.cost >= 4, "Due to customized pricing, cost should be greater than 4"


@pytest.mark.skipif(skip, reason="openai>=1 not installed")
def test_usage_summary():
config_list = config_list_from_json(
Expand Down
Loading