Update num tokens from text #149

Merged: 6 commits, Oct 9, 2023

Changes from all commits

autogen/agentchat/contrib/retrieve_user_proxy_agent.py (6 changes: 5 additions & 1 deletion)

@@ -128,6 +128,9 @@ def __init__(
             - update_context (Optional, bool): if False, will not apply `Update Context` for interactive retrieval. Default is True.
             - get_or_create (Optional, bool): if True, will create/recreate a collection for the retrieve chat.
                 This is the same as that used in chromadb. Default is False.
+            - custom_token_count_function (Optional, Callable): a custom function to count the number of tokens in a string.
+                The function should take a string as input and return three integers: (token_count, tokens_per_message, tokens_per_name).
+                Default is None; in that case tiktoken is used, which may not be accurate for non-OpenAI models.
         **kwargs (dict): other kwargs in [UserProxyAgent](../user_proxy_agent#__init__).
         """
         super().__init__(
@@ -152,6 +155,7 @@ def __init__(
         self.customized_answer_prefix = self._retrieve_config.get("customized_answer_prefix", "").upper()
         self.update_context = self._retrieve_config.get("update_context", True)
         self._get_or_create = self._retrieve_config.get("get_or_create", False)
+        self.custom_token_count_function = self._retrieve_config.get("custom_token_count_function", None)
         self._context_max_tokens = self._max_tokens * 0.8
         self._collection = False  # the collection is not created
         self._ipython = get_ipython()
@@ -191,7 +195,7 @@ def _get_context(self, results):
                 continue
             if results["ids"][0][idx] in self._doc_ids:
                 continue
-            _doc_tokens = num_tokens_from_text(doc)
+            _doc_tokens = num_tokens_from_text(doc, custom_token_count_function=self.custom_token_count_function)
             if _doc_tokens > self._context_max_tokens:
                 func_print = f"Skip doc_id {results['ids'][0][idx]} as it is too long to fit in the context."
                 print(colored(func_print, "green"), flush=True)
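
For readers wondering how the new config key is consumed end to end, here is a minimal usage sketch. The Hugging Face tokenizer, the `./docs` path, and the helper name `hf_token_count` are illustrative assumptions, not part of this PR; only `custom_token_count_function` itself is the new surface.

```python
# Hypothetical usage sketch: plugging a Hugging Face tokenizer into
# RetrieveUserProxyAgent via the new retrieve_config key.
from transformers import AutoTokenizer  # assumed extra dependency, not required by this PR

from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in for any non-OpenAI tokenizer


def hf_token_count(text: str):
    # Contract from the docstring above: return three integers
    # (token_count, tokens_per_message, tokens_per_name).
    return len(tokenizer.encode(text)), 3, 1


ragproxyagent = RetrieveUserProxyAgent(
    name="ragproxyagent",
    retrieve_config={
        "task": "qa",
        "docs_path": "./docs",  # illustrative path
        "custom_token_count_function": hf_token_count,
    },
)
```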

autogen/retrieve_utils.py (93 changes: 55 additions & 38 deletions)

@@ -1,4 +1,4 @@
-from typing import List, Union, Dict, Tuple
+from typing import List, Union, Dict, Tuple, Callable
 import os
 import requests
 from urllib.parse import urlparse
@@ -33,59 +33,76 @@


 def num_tokens_from_text(
-    text: str, model: str = "gpt-3.5-turbo-0613", return_tokens_per_name_and_message: bool = False
+    text: str,
+    model: str = "gpt-3.5-turbo-0613",
+    return_tokens_per_name_and_message: bool = False,
+    custom_token_count_function: Callable = None,
 ) -> Union[int, Tuple[int, int, int]]:
-    """Return the number of tokens used by a text."""
-    # https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
-    try:
-        encoding = tiktoken.encoding_for_model(model)
-    except KeyError:
-        logger.debug("Warning: model not found. Using cl100k_base encoding.")
-        encoding = tiktoken.get_encoding("cl100k_base")
-    if model in {
-        "gpt-3.5-turbo-0613",
-        "gpt-3.5-turbo-16k-0613",
-        "gpt-4-0314",
-        "gpt-4-32k-0314",
-        "gpt-4-0613",
-        "gpt-4-32k-0613",
-    }:
-        tokens_per_message = 3
-        tokens_per_name = 1
-    elif model == "gpt-3.5-turbo-0301":
-        tokens_per_message = 4  # every message follows <|start|>{role/name}\n{content}<|end|>\n
-        tokens_per_name = -1  # if there's a name, the role is omitted
-    elif "gpt-3.5-turbo" in model or "gpt-35-turbo" in model:
-        logger.warning("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.")
-        return num_tokens_from_text(text, model="gpt-3.5-turbo-0613")
-    elif "gpt-4" in model:
-        logger.warning("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
-        return num_tokens_from_text(text, model="gpt-4-0613")
-    else:
-        raise NotImplementedError(
-            f"""num_tokens_from_text() is not implemented for model {model}. See """
-            f"""https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are """
-            f"""converted to tokens."""
-        )
+    """Return the number of tokens used by a text.
+
+    Args:
+        text (str): The text to count tokens for.
+        model (Optional, str): The model to use for tokenization. Default is "gpt-3.5-turbo-0613".
+        return_tokens_per_name_and_message (Optional, bool): Whether to return the number of tokens per name and per
+            message. Default is False.
+        custom_token_count_function (Optional, Callable): A custom function to count tokens. Default is None.
+
+    Returns:
+        int: The number of tokens used by the text.
+        int: The number of tokens per message. Only returned if return_tokens_per_name_and_message is True.
+        int: The number of tokens per name. Only returned if return_tokens_per_name_and_message is True.
+    """
+    if isinstance(custom_token_count_function, Callable):
+        token_count, tokens_per_message, tokens_per_name = custom_token_count_function(text)
+    else:
+        # https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+        try:
+            encoding = tiktoken.encoding_for_model(model)
+        except KeyError:
+            logger.debug("Warning: model not found. Using cl100k_base encoding.")
+            encoding = tiktoken.get_encoding("cl100k_base")
+        known_models = {
+            "gpt-3.5-turbo": (3, 1),
+            "gpt-35-turbo": (3, 1),
+            "gpt-3.5-turbo-0613": (3, 1),
+            "gpt-3.5-turbo-16k-0613": (3, 1),
+            "gpt-3.5-turbo-0301": (4, -1),
+            "gpt-4": (3, 1),
+            "gpt-4-0314": (3, 1),
+            "gpt-4-32k-0314": (3, 1),
+            "gpt-4-0613": (3, 1),
+            "gpt-4-32k-0613": (3, 1),
+        }
+        tokens_per_message, tokens_per_name = known_models.get(model, (3, 1))
+        token_count = len(encoding.encode(text))

     if return_tokens_per_name_and_message:
-        return len(encoding.encode(text)), tokens_per_message, tokens_per_name
+        return token_count, tokens_per_message, tokens_per_name
     else:
-        return len(encoding.encode(text))
+        return token_count
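
The rewrite collapses the old if/elif ladder into a `known_models` lookup table and lets a user-supplied counter bypass tiktoken entirely. One behavioral change worth noting: a model name missing from the table no longer raises `NotImplementedError`; it silently falls back to `(3, 1)`. A small sketch of both paths — the character-based counter is a toy assumption:

```python
from autogen.retrieve_utils import num_tokens_from_text

# Default path: tiktoken does the counting; an unknown model name now falls
# back to (tokens_per_message=3, tokens_per_name=1) instead of raising.
count = num_tokens_from_text("hello world", model="gpt-4-0613")


def char_counter(text: str):
    # Toy counter: one "token" per character. A real replacement would call
    # the target model's tokenizer. It must return all three integers.
    return len(text), 3, 1


# Custom path: tiktoken is never consulted.
count, per_message, per_name = num_tokens_from_text(
    "hello world",
    custom_token_count_function=char_counter,
    return_tokens_per_name_and_message=True,
)
```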


-def num_tokens_from_messages(messages: dict, model: str = "gpt-3.5-turbo-0613"):
+def num_tokens_from_messages(
+    messages: dict,
+    model: str = "gpt-3.5-turbo-0613",
+    custom_token_count_function: Callable = None,
+    custom_prime_count: int = 3,
+):
     """Return the number of tokens used by a list of messages."""
     num_tokens = 0
     for message in messages:
         for key, value in message.items():
             _num_tokens, tokens_per_message, tokens_per_name = num_tokens_from_text(
-                value, model=model, return_tokens_per_name_and_message=True
+                value,
+                model=model,
+                return_tokens_per_name_and_message=True,
+                custom_token_count_function=custom_token_count_function,
             )
             num_tokens += _num_tokens
             if key == "name":
                 num_tokens += tokens_per_name
         num_tokens += tokens_per_message
-    num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
+    num_tokens += custom_prime_count  # with ChatGPT, every reply is primed with <|start|>assistant<|message|>
     return num_tokens
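
`custom_prime_count` generalizes the previously hard-coded `+3` reply priming, which is specific to ChatGPT's `<|start|>assistant<|message|>` format. A minimal sketch, assuming a hypothetical model whose replies carry no priming tokens; the messages and counter are illustrative:

```python
from autogen.retrieve_utils import num_tokens_from_messages

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "name": "alice", "content": "What is RAG?"},
]

# Default behavior is unchanged: tiktoken counting plus 3 priming tokens.
total = num_tokens_from_messages(messages, model="gpt-3.5-turbo-0613")


def char_counter(text: str):
    return len(text), 3, 1  # toy counter, as above


# Assumed model format with no reply priming at all.
total = num_tokens_from_messages(
    messages, custom_token_count_function=char_counter, custom_prime_count=0
)
```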



test/test_retrieve_utils.py (9 changes: 9 additions & 0 deletions)

@@ -31,6 +31,15 @@


 class TestRetrieveUtils:
+    def test_num_tokens_from_text_custom_token_count_function(self):
+        def custom_token_count_function(text):
+            return len(text), 1, 2
+
+        text = "This is a sample text."
+        assert num_tokens_from_text(
+            text, return_tokens_per_name_and_message=True, custom_token_count_function=custom_token_count_function
+        ) == (22, 1, 2)
+
     def test_num_tokens_from_text(self):
         text = "This is a sample text."
         assert num_tokens_from_text(text) == len(tiktoken.get_encoding("cl100k_base").encode(text))
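
The expected tuple in the new test follows directly from the stub: `len("This is a sample text.")` is 22 characters, and the stub's fixed per-message/per-name values `(1, 2)` pass through unchanged, confirming that a custom counter fully bypasses tiktoken.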