Skip to content

Commit

Permalink
microsoft#63 Add support for different models in num_tokens_from_text function
Browse files Browse the repository at this point in the history
  • Loading branch information
vidhula17 committed Oct 3, 2023
1 parent 1fda4b2 commit 45ebbbe
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 33 deletions.
67 changes: 34 additions & 33 deletions autogen/retrieve_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,42 +34,43 @@
def num_tokens_from_text(
    text: str, model: str = "gpt-3.5-turbo-0613", return_tokens_per_name_and_message: bool = False
) -> Union[int, Tuple[int, int, int]]:
    """Return the number of tokens used by a text for a given OpenAI model.

    Based on:
    https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb

    Args:
        text: The text to tokenize.
        model: Name of the OpenAI model whose tokenization rules to apply.
        return_tokens_per_name_and_message: If True, also return the
            per-message and per-name token overheads used when counting
            tokens for chat messages.

    Returns:
        The token count of ``text``, or a
        ``(token_count, tokens_per_message, tokens_per_name)`` tuple when
        ``return_tokens_per_name_and_message`` is True.

    Raises:
        NotImplementedError: If ``model`` is not a recognized model name.
    """
    # Per-message / per-name token overheads for known chat models.
    known_models = {
        "gpt-3.5-turbo-0613": (3, 1),
        "gpt-3.5-turbo-16k-0613": (3, 1),
        "gpt-4-0314": (3, 1),
        "gpt-4-32k-0314": (3, 1),
        "gpt-4-0613": (3, 1),
        "gpt-4-32k-0613": (3, 1),
        # Every message follows <|start|>{role/name}\n{content}<|end|>\n (4 tokens);
        # if there's a name, the role is omitted (hence -1).
        "gpt-3.5-turbo-0301": (4, -1),
    }

    if model in known_models:
        tokens_per_message, tokens_per_name = known_models[model]
    elif "gpt-3.5-turbo" in model or "gpt-35-turbo" in model:
        # Unpinned model names may update over time; assume the 0613 snapshot.
        logger.warning("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.")
        return num_tokens_from_text(
            text, model="gpt-3.5-turbo-0613", return_tokens_per_name_and_message=return_tokens_per_name_and_message
        )
    elif "gpt-4" in model:
        logger.warning("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
        return num_tokens_from_text(
            text, model="gpt-4-0613", return_tokens_per_name_and_message=return_tokens_per_name_and_message
        )
    else:
        raise NotImplementedError(
            f"num_tokens_from_text() is not implemented for model {model}. See "
            f"https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are "
            f"converted to tokens."
        )

    # Fall back to cl100k_base when tiktoken does not recognize the model name.
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        logger.debug("Warning: model not found. Using cl100k_base encoding.")
        encoding = tiktoken.get_encoding("cl100k_base")

    token_count = len(encoding.encode(text))
    if return_tokens_per_name_and_message:
        return token_count, tokens_per_message, tokens_per_name
    return token_count



def num_tokens_from_messages(messages: dict, model: str = "gpt-3.5-turbo-0613"):
Expand Down
21 changes: 21 additions & 0 deletions test/test_num_tokens_from_text.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""Unit tests for autogen.retrieve_utils.num_tokens_from_text."""
import unittest

from autogen.retrieve_utils import num_tokens_from_text


class TestNumTokensFromText(unittest.TestCase):
    """Token counting for known and unknown model names."""

    def test_known_model(self):
        """A known model yields the expected token count for a short text."""
        sample = "This is a test message."
        # 6 tokens under the encoding used for gpt-3.5-turbo-0613.
        self.assertEqual(num_tokens_from_text(sample, "gpt-3.5-turbo-0613"), 6)

    def test_unknown_model(self):
        """An unrecognized model name raises NotImplementedError."""
        with self.assertRaises(NotImplementedError):
            num_tokens_from_text("This is a test message.", "unknown-model")


if __name__ == "__main__":
    unittest.main()

0 comments on commit 45ebbbe

Please sign in to comment.