Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow costs and window size limits to be specified in the config_list #1682

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 30 additions & 14 deletions autogen/oai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,9 @@ def message_retrieval(
"""
... # pragma: no cover

def cost(self, response: ModelClientResponseProtocol) -> float:
def cost(
self, response: ModelClientResponseProtocol, token_cost_1k: Optional[Union[float, Dict[str, float]]] = None
) -> float:
... # pragma: no cover

@staticmethod
Expand Down Expand Up @@ -275,21 +277,32 @@ def create(self, params: Dict[str, Any]) -> ChatCompletion:

return response

def cost(self, response: Union[ChatCompletion, Completion]) -> float:
def cost(
self,
response: Union[ChatCompletion, Completion],
token_cost_1k: Optional[Union[float, Dict[str, float]]] = None,
) -> float:
Copy link
Collaborator

@maxim-saplin maxim-saplin Feb 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not critical, yet it might be reasonable to switch to the Decimal type (float smells in currency/money values :) ) - openai_utils.py included

"""Calculate the cost of the response."""
model = response.model
if model not in OAI_PRICE1K:
# TODO: add logging to warn that the model is not found
logger.debug(f"Model {model} is not found. The cost will be 0.", exc_info=True)
return 0

# No cost specified. Use the default if possible
if token_cost_1k is None:
model = response.model
if model not in OAI_PRICE1K:
# TODO: add logging to warn that the model is not found
logger.debug(f"Model {model} is not found. The cost will be 0.", exc_info=True)
return 0
token_cost_1k = OAI_PRICE1K[model]
if isinstance(token_cost_1k, tuple):
token_cost_1k = {"input": token_cost_1k[0], "output": token_cost_1k[1]}

# Read the token use
n_input_tokens = response.usage.prompt_tokens if response.usage is not None else 0 # type: ignore [union-attr]
n_output_tokens = response.usage.completion_tokens if response.usage is not None else 0 # type: ignore [union-attr]
tmp_price1K = OAI_PRICE1K[model]
# First value is input token rate, second value is output token rate
if isinstance(tmp_price1K, tuple):
return (tmp_price1K[0] * n_input_tokens + tmp_price1K[1] * n_output_tokens) / 1000 # type: ignore [no-any-return]
return tmp_price1K * (n_input_tokens + n_output_tokens) / 1000 # type: ignore [operator]

# Compute final cost
if isinstance(token_cost_1k, dict):
return (token_cost_1k["input"] * n_input_tokens + token_cost_1k["output"] * n_output_tokens) / 1000 # type: ignore [no-any-return]
return token_cost_1k * (n_input_tokens + n_output_tokens) / 1000 # type: ignore [operator]

@staticmethod
def get_usage(response: Union[ChatCompletion, Completion]) -> Dict:
Expand All @@ -314,6 +327,8 @@ class OpenAIWrapper:
"api_version",
"api_type",
"tags",
"window_size",
Copy link
Member

@olgavrou olgavrou Feb 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we doing anything with window size, or planning to do anything with it?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. Both :)

I have yet to modify WebSurfer (this is still a draft PR), but it will use WindowSize to understand how to break up pages.

I also use it in the complex_tasks branch with SocietyOfMind and will port those changes to main once this is ready and merged.

"token_cost_1k",
}

openai_kwargs = set(inspect.getfullargspec(OpenAI.__init__).kwonlyargs)
Expand Down Expand Up @@ -551,6 +566,7 @@ def yes_or_no_filter(context, response):
cache = extra_kwargs.get("cache")
filter_func = extra_kwargs.get("filter_func")
context = extra_kwargs.get("context")
token_cost_1k = extra_kwargs.get("token_cost_1k")

total_usage = None
actual_usage = None
Expand All @@ -575,7 +591,7 @@ def yes_or_no_filter(context, response):
response.cost # type: ignore [attr-defined]
except AttributeError:
# update attribute if cost is not calculated
response.cost = client.cost(response)
response.cost = client.cost(response, token_cost_1k=token_cost_1k)
cache.set(key, response)
total_usage = client.get_usage(response)
# check the filter
Expand Down Expand Up @@ -605,7 +621,7 @@ def yes_or_no_filter(context, response):
raise
else:
# add cost calculation before caching no matter filter is passed or not
response.cost = client.cost(response)
response.cost = client.cost(response, token_cost_1k=token_cost_1k)
actual_usage = client.get_usage(response)
total_usage = actual_usage.copy() if actual_usage is not None else total_usage
self._update_usage(actual_usage=actual_usage, total_usage=total_usage)
Expand Down
8 changes: 4 additions & 4 deletions test/oai/test_custom_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def create(self, params):
def message_retrieval(self, response):
return [response.choices[0].message.content]

def cost(self, response) -> float:
def cost(self, response, token_cost_1k=None) -> float:
"""Calculate the cost of the response."""
response.cost = TEST_COST
return TEST_COST
Expand Down Expand Up @@ -97,7 +97,7 @@ def create(self, params):
def message_retrieval(self, response):
return []

def cost(self, response) -> float:
def cost(self, response, token_cost_1k=None) -> float:
return 0

@staticmethod
Expand Down Expand Up @@ -127,7 +127,7 @@ def create(self, params):
def message_retrieval(self, response):
return []

def cost(self, response) -> float:
def cost(self, response, token_cost_1k=None) -> float:
return 0

@staticmethod
Expand Down Expand Up @@ -179,7 +179,7 @@ def create(self, params):
def message_retrieval(self, response):
return []

def cost(self, response) -> float:
def cost(self, response, token_cost_1k=None) -> float:
"""Calculate the cost of the response."""
return 0

Expand Down
Loading