Update num tokens from text (microsoft#149)
* Improve num_tokens_from_text

* Format

* Update comments

* Improve docstrings
thinkall authored Oct 9, 2023
1 parent dbd86b0 commit df8906b
Showing 3 changed files with 69 additions and 39 deletions.
6 changes: 5 additions & 1 deletion autogen/agentchat/contrib/retrieve_user_proxy_agent.py
@@ -128,6 +128,9 @@ def __init__(
                - update_context (Optional, bool): if False, will not apply `Update Context` for interactive retrieval. Default is True.
                - get_or_create (Optional, bool): if True, will create/recreate a collection for the retrieve chat.
                    This is the same as that used in chromadb. Default is False.
+               - custom_token_count_function (Optional, Callable): a custom function to count the number of tokens in a string.
+                   The function should take a string as input and return three integers (token_count, tokens_per_message, tokens_per_name).
+                   Default is None; tiktoken will be used, which may not be accurate for non-OpenAI models.
            **kwargs (dict): other kwargs in [UserProxyAgent](../user_proxy_agent#__init__).
"""
super().__init__(
@@ -152,6 +155,7 @@ def __init__(
        self.customized_answer_prefix = self._retrieve_config.get("customized_answer_prefix", "").upper()
        self.update_context = self._retrieve_config.get("update_context", True)
        self._get_or_create = self._retrieve_config.get("get_or_create", False)
+       self.custom_token_count_function = self._retrieve_config.get("custom_token_count_function", None)
        self._context_max_tokens = self._max_tokens * 0.8
        self._collection = False  # the collection is not created
        self._ipython = get_ipython()
@@ -191,7 +195,7 @@ def _get_context(self, results):
                continue
            if results["ids"][0][idx] in self._doc_ids:
                continue
-           _doc_tokens = num_tokens_from_text(doc)
+           _doc_tokens = num_tokens_from_text(doc, custom_token_count_function=self.custom_token_count_function)
            if _doc_tokens > self._context_max_tokens:
                func_print = f"Skip doc_id {results['ids'][0][idx]} as it is too long to fit in the context."
                print(colored(func_print, "green"), flush=True)
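For reference, a minimal sketch of how a caller might wire a non-OpenAI counter through `retrieve_config`. The HuggingFace model name and the other config keys shown are illustrative, not part of this commit; any callable that returns the documented `(token_count, tokens_per_message, tokens_per_name)` triple works:

```python
from transformers import AutoTokenizer  # assumed dependency; any tokenizer with encode() works

from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent

# Hypothetical non-OpenAI tokenizer; the model name is illustrative.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

def llama_token_count(text: str):
    # Contract from the docstring above:
    # return (token_count, tokens_per_message, tokens_per_name).
    return len(tokenizer.encode(text)), 3, 1

ragproxyagent = RetrieveUserProxyAgent(
    name="ragproxyagent",
    retrieve_config={
        "task": "qa",
        "docs_path": "./docs",
        "custom_token_count_function": llama_token_count,
    },
)
```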
93 changes: 55 additions & 38 deletions autogen/retrieve_utils.py
@@ -1,4 +1,4 @@
-from typing import List, Union, Dict, Tuple
+from typing import List, Union, Dict, Tuple, Callable
import os
import requests
from urllib.parse import urlparse
@@ -33,59 +33,76 @@


def num_tokens_from_text(
-    text: str, model: str = "gpt-3.5-turbo-0613", return_tokens_per_name_and_message: bool = False
+    text: str,
+    model: str = "gpt-3.5-turbo-0613",
+    return_tokens_per_name_and_message: bool = False,
+    custom_token_count_function: Callable = None,
) -> Union[int, Tuple[int, int, int]]:
"""Return the number of tokens used by a text."""
# https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
try:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
logger.debug("Warning: model not found. Using cl100k_base encoding.")
encoding = tiktoken.get_encoding("cl100k_base")
if model in {
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k-0613",
"gpt-4-0314",
"gpt-4-32k-0314",
"gpt-4-0613",
"gpt-4-32k-0613",
}:
tokens_per_message = 3
tokens_per_name = 1
elif model == "gpt-3.5-turbo-0301":
tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
tokens_per_name = -1 # if there's a name, the role is omitted
elif "gpt-3.5-turbo" in model or "gpt-35-turbo" in model:
logger.warning("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.")
return num_tokens_from_text(text, model="gpt-3.5-turbo-0613")
elif "gpt-4" in model:
logger.warning("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
return num_tokens_from_text(text, model="gpt-4-0613")
"""Return the number of tokens used by a text.
Args:
text (str): The text to count tokens for.
model (Optional, str): The model to use for tokenization. Default is "gpt-3.5-turbo-0613".
return_tokens_per_name_and_message (Optional, bool): Whether to return the number of tokens per name and per
message. Default is False.
custom_token_count_function (Optional, Callable): A custom function to count tokens. Default is None.
Returns:
int: The number of tokens used by the text.
int: The number of tokens per message. Only returned if return_tokens_per_name_and_message is True.
int: The number of tokens per name. Only returned if return_tokens_per_name_and_message is True.
"""
if isinstance(custom_token_count_function, Callable):
token_count, tokens_per_message, tokens_per_name = custom_token_count_function(text)
else:
raise NotImplementedError(
f"""num_tokens_from_text() is not implemented for model {model}. See """
f"""https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are """
f"""converted to tokens."""
)
# https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
try:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
logger.debug("Warning: model not found. Using cl100k_base encoding.")
encoding = tiktoken.get_encoding("cl100k_base")
known_models = {
"gpt-3.5-turbo": (3, 1),
"gpt-35-turbo": (3, 1),
"gpt-3.5-turbo-0613": (3, 1),
"gpt-3.5-turbo-16k-0613": (3, 1),
"gpt-3.5-turbo-0301": (4, -1),
"gpt-4": (3, 1),
"gpt-4-0314": (3, 1),
"gpt-4-32k-0314": (3, 1),
"gpt-4-0613": (3, 1),
"gpt-4-32k-0613": (3, 1),
}
tokens_per_message, tokens_per_name = known_models.get(model, (3, 1))
token_count = len(encoding.encode(text))

    if return_tokens_per_name_and_message:
-        return len(encoding.encode(text)), tokens_per_message, tokens_per_name
+        return token_count, tokens_per_message, tokens_per_name
    else:
-        return len(encoding.encode(text))
+        return token_count
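A quick sanity check of both paths through the rewritten function. The `char_counter` helper is hypothetical; the default-path assertion mirrors the existing test in this commit:

```python
import tiktoken

from autogen.retrieve_utils import num_tokens_from_text

text = "Hello, world!"

# Default path: tiktoken encodes the text; unknown models fall back to
# cl100k_base and the (3, 1) per-message/per-name overheads.
assert num_tokens_from_text(text) == len(tiktoken.get_encoding("cl100k_base").encode(text))

# Custom path: the callable alone decides all three values.
def char_counter(t):  # hypothetical counter: one "token" per character
    return len(t), 3, 1

assert num_tokens_from_text(
    text, return_tokens_per_name_and_message=True, custom_token_count_function=char_counter
) == (13, 3, 1)  # len("Hello, world!") == 13
```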


-def num_tokens_from_messages(messages: dict, model: str = "gpt-3.5-turbo-0613"):
+def num_tokens_from_messages(
+    messages: dict,
+    model: str = "gpt-3.5-turbo-0613",
+    custom_token_count_function: Callable = None,
+    custom_prime_count: int = 3,
+):
"""Return the number of tokens used by a list of messages."""
num_tokens = 0
for message in messages:
for key, value in message.items():
_num_tokens, tokens_per_message, tokens_per_name = num_tokens_from_text(
value, model=model, return_tokens_per_name_and_message=True
value,
model=model,
return_tokens_per_name_and_message=True,
custom_token_count_function=custom_token_count_function,
)
num_tokens += _num_tokens
if key == "name":
num_tokens += tokens_per_name
num_tokens += tokens_per_message
-    num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
+    num_tokens += custom_prime_count  # With ChatGPT, every reply is primed with <|start|>assistant<|message|>
    return num_tokens
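A worked example of the per-message bookkeeping, using an illustrative character-based counter with gpt-3.5-turbo-0301-style overheads:

```python
from autogen.retrieve_utils import num_tokens_from_messages

def char_counter(t):  # hypothetical: one "token" per character, (4, -1) overheads
    return len(t), 4, -1

messages = [{"role": "user", "name": "alice", "content": "Hi"}]
# Per message: len("user") + len("alice") + tokens_per_name + len("Hi") + tokens_per_message
#            = 4 + 5 + (-1) + 2 + 4 = 14, then + custom_prime_count (default 3) = 17.
assert num_tokens_from_messages(messages, custom_token_count_function=char_counter) == 17
```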


9 changes: 9 additions & 0 deletions test/test_retrieve_utils.py
@@ -31,6 +31,15 @@


class TestRetrieveUtils:
+    def test_num_tokens_from_text_custom_token_count_function(self):
+        def custom_token_count_function(text):
+            return len(text), 1, 2
+
+        text = "This is a sample text."
+        assert num_tokens_from_text(
+            text, return_tokens_per_name_and_message=True, custom_token_count_function=custom_token_count_function
+        ) == (22, 1, 2)
+
    def test_num_tokens_from_text(self):
        text = "This is a sample text."
        assert num_tokens_from_text(text) == len(tiktoken.get_encoding("cl100k_base").encode(text))
