-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(openai): add token usage and cost tracking callback (#349)
* feat(openai): add token and cost tracking callback * test(openai_info): add openai info and callback tests * docs: mention how to count the tokens usage --------- Co-authored-by: Gabriele Venturi <lele.venturi@gmail.com>
- Loading branch information
1 parent
de0332f
commit 24ac098
Showing
4 changed files
with
243 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
from contextlib import contextmanager | ||
from contextvars import ContextVar | ||
from typing import Optional, Generator | ||
|
||
from openai.openai_object import OpenAIObject | ||
|
||
# Price in USD per 1,000 tokens, keyed by lower-case model name.
# Insertion order is kept as-is because it is the order in which the known
# models are listed in the "Unknown model" error message.
MODEL_COST_PER_1K_TOKENS = {
    # GPT-4 family (input rate)
    "gpt-4": 0.03,
    "gpt-4-0613": 0.03,
    "gpt-4-32k": 0.06,
    "gpt-4-32k-0613": 0.06,
    # GPT-3.5 family (input rate)
    "gpt-3.5-turbo": 0.0015,
    "gpt-3.5-turbo-0613": 0.0015,
    "gpt-3.5-turbo-16k": 0.003,
    "gpt-3.5-turbo-16k-0613": 0.003,
    # Everything else
    "gpt-35-turbo": 0.002,  # Azure OpenAI version of ChatGPT
    "text-davinci-003": 0.02,
    "text-davinci-002": 0.02,
    "code-davinci-002": 0.02,
}
|
||
|
||
def get_openai_token_cost_for_model(
    model_name: str,
    num_tokens: int,
) -> float:
    """
    Get the cost in USD for a given model and number of tokens.

    The lookup is case-insensitive: ``model_name`` is lower-cased before it
    is matched against ``MODEL_COST_PER_1K_TOKENS``.

    Args:
        model_name: Name of the model.
        num_tokens: Number of tokens.

    Returns:
        Cost in USD, i.e. the model's per-1K-token rate times
        ``num_tokens / 1000``.

    Raises:
        ValueError: If ``model_name`` has no entry in the price table.
    """
    model_name = model_name.lower()
    try:
        cost_per_1k = MODEL_COST_PER_1K_TOKENS[model_name]
    except KeyError:
        known_models = ", ".join(MODEL_COST_PER_1K_TOKENS)
        # Each sentence ends with a space so the concatenated message reads
        # correctly ("... model name. Known models are: ...").
        # `from None` keeps the internal KeyError out of the traceback.
        raise ValueError(
            f"Unknown model: {model_name}. "
            "Please provide a valid OpenAI model name. "
            f"Known models are: {known_models}"
        ) from None
    return cost_per_1k * (num_tokens / 1000)
|
||
|
||
class OpenAICallbackHandler:
    """Callback Handler that tracks OpenAI info.

    Call an instance with each API response to accumulate the token counts
    reported in ``response.usage`` and a running cost in USD.
    """

    # Class-level defaults; the ``+=`` updates in ``__call__`` rebind these
    # names on the instance, so each handler keeps its own counters.
    total_tokens: int = 0
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_cost: float = 0.0

    def __repr__(self) -> str:
        """Return a multi-line summary of the tokens used and the total cost."""
        return (
            f"Tokens Used: {self.total_tokens}\n"
            f"\tPrompt Tokens: {self.prompt_tokens}\n"
            f"\tCompletion Tokens: {self.completion_tokens}\n"
            f"Total Cost (USD): ${self.total_cost:9.6f}"
        )

    def __call__(self, response: OpenAIObject) -> None:
        """Collect token usage from ``response``.

        Responses without usage information (no ``usage``, or a ``usage``
        lacking ``total_tokens``) are ignored. The cost is only updated when
        the response's model has an entry in ``MODEL_COST_PER_1K_TOKENS``;
        token counters are updated either way.

        Args:
            response: API response exposing ``usage`` and ``model``.
        """
        usage = getattr(response, "usage", None)
        if not usage or "total_tokens" not in usage:
            return None

        model_name = getattr(response, "model", None)
        # Compare in lower case, matching the normalisation done by
        # get_openai_token_cost_for_model, so mixed-case names are priced too.
        if (
            isinstance(model_name, str)
            and model_name.lower() in MODEL_COST_PER_1K_TOKENS
        ):
            # The cost is derived from total_tokens at the model's single rate.
            self.total_cost += get_openai_token_cost_for_model(
                model_name,
                usage.total_tokens,
            )

        self.total_tokens += usage.total_tokens
        # Not every usage payload is guaranteed to carry both breakdown
        # fields; treat a missing one as zero instead of raising midway
        # through the update.
        self.prompt_tokens += getattr(usage, "prompt_tokens", 0)
        self.completion_tokens += getattr(usage, "completion_tokens", 0)
        return None

    def __copy__(self) -> "OpenAICallbackHandler":
        """Return a copy of the callback handler.

        The "copy" is this very instance, so ``copy.copy(handler)`` shares
        its counters with the original.
        NOTE(review): presumably deliberate so that copies keep accumulating
        into one place -- confirm.
        """
        return self
|
||
|
||
# Context-local slot for the handler installed by `get_openai_callback()`;
# it holds None whenever no such `with` block is active.
openai_callback_var: ContextVar[Optional[OpenAICallbackHandler]] = ContextVar(
    "openai_callback",
    default=None,
)
|
||
|
||
@contextmanager
def get_openai_callback() -> Generator[OpenAICallbackHandler, None, None]:
    """Get the OpenAI callback handler in a context manager,
    which conveniently exposes token and cost information.

    A fresh handler is stored in ``openai_callback_var`` for the duration of
    the ``with`` block. On exit, including when the body raises, the
    variable is set back to the value it had on entry, so nested blocks do
    not clear an outer handler.

    Yields:
        OpenAICallbackHandler: The OpenAI callback handler.

    Example:
        >>> with get_openai_callback() as cb:
        ...     # Use the OpenAI callback handler
        ...     pass
    """
    previous = openai_callback_var.get()
    cb = OpenAICallbackHandler()
    openai_callback_var.set(cb)
    try:
        yield cb
    finally:
        # Runs on normal exit and on exceptions alike, so a failing body can
        # no longer leave a stale handler installed.
        openai_callback_var.set(previous)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
import pytest | ||
import openai | ||
from openai.openai_object import OpenAIObject | ||
|
||
from pandasai import PandasAI | ||
from pandasai.helpers.openai_info import ( | ||
OpenAICallbackHandler, | ||
get_openai_callback, | ||
) | ||
import pandas as pd | ||
|
||
from pandasai.llm.openai import OpenAI | ||
|
||
|
||
@pytest.fixture
def handler() -> OpenAICallbackHandler:
    """Provide a newly constructed ``OpenAICallbackHandler``."""
    fresh_handler = OpenAICallbackHandler()
    return fresh_handler
|
||
|
||
class TestOpenAIInfo:
    """Unit tests for OpenAI Info Callback"""

    @staticmethod
    def _usage_payload() -> dict:
        """Return the usage section shared by every fake response."""
        return {
            "prompt_tokens": 2,
            "completion_tokens": 1,
            "total_tokens": 3,
        }

    def test_handler(self, handler: OpenAICallbackHandler) -> None:
        """A priced model updates every counter and yields a positive cost."""
        payload = {
            "usage": self._usage_payload(),
            "model": "gpt-35-turbo",
        }
        response = OpenAIObject.construct_from(payload)

        handler(response)

        assert handler.total_tokens == 3
        assert handler.prompt_tokens == 2
        assert handler.completion_tokens == 1
        assert handler.total_cost > 0

    def test_handler_unknown_model(self, handler: OpenAICallbackHandler) -> None:
        """An unpriced model still counts tokens but adds no cost."""
        payload = {
            "usage": self._usage_payload(),
            "model": "foo-bar",
        }
        response = OpenAIObject.construct_from(payload)

        handler(response)

        assert handler.total_tokens == 3
        assert handler.prompt_tokens == 2
        assert handler.completion_tokens == 1
        # cost must be 0.0 for unknown model
        assert handler.total_cost == 0.0

    def test_openai_callback(self, mocker):
        """The context manager yields a new handler that counts each call."""
        df = pd.DataFrame([1, 2, 3])
        llm = OpenAI(api_token="test")
        only_choice = {
            "text": "```df.sum()```",
            "index": 0,
            "logprobs": None,
            "finish_reason": "stop",
            "start_text": "",
        }
        llm_response = OpenAIObject.construct_from(
            {
                "choices": [only_choice],
                "model": llm.model,
                "usage": self._usage_payload(),
            }
        )
        # Every ChatCompletion.create call returns the canned response above.
        mocker.patch.object(
            openai.ChatCompletion,
            "create",
            return_value=llm_response,
        )

        pandas_ai = PandasAI(llm, enable_cache=False)

        # One question inside the block.
        with get_openai_callback() as cb:
            _ = pandas_ai(df, "some question")
            assert cb.total_tokens == 3
            assert cb.prompt_tokens == 2
            assert cb.completion_tokens == 1
            assert cb.total_cost > 0

        total_tokens = cb.total_tokens

        # Two questions inside a new block: totals start again from zero.
        with get_openai_callback() as cb:
            for _ in range(2):
                pandas_ai(df, "some question")

        assert cb.total_tokens == total_tokens * 2

        # Three questions inside another new block.
        with get_openai_callback() as cb:
            for _ in range(3):
                pandas_ai(df, "some question")

        assert cb.total_tokens == total_tokens * 3