Skip to content

Commit

Permalink
Update tokenizer apply_chat_template functionality (#647)
Browse files Browse the repository at this point in the history
* Allow custom kwargs in `tokenizer.apply_chat_template`

* Update jinja dependency version

* Add `tokenizer_kwargs` options

* Add support for dictionaries of chat templates in the tokenizer config

* Add `CohereTokenizer`

* `apply_chat_template` is no longer async

* Add unit test for multiple chat templates

* Update tokenizers.js

* Also update when `chat_template` is undefined

* Support setting tokenizer and text from URL

* Update Claude tokenizer display name

* Add Cohere Command-R tokenizer to playground

* Add `Grok1Tokenizer`

* Throw error if chat template object is malformed

* Improved error checking

* Remove redundant error check

* `template_dict` can be a null-prototype object
  • Loading branch information
xenova authored Mar 20, 2024
1 parent 40cdd36 commit f0ef2e8
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 10 deletions.
17 changes: 15 additions & 2 deletions examples/tokenizer-playground/src/App.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,16 @@ import { Token } from './components/Token'


function App() {
// Allow user to set tokenizer and text via URL query parameters
const urlParams = new URLSearchParams(window.location.search);
const tokenizerParam = urlParams.get('tokenizer');
const textParam = urlParams.get('text');

const [tokenIds, setTokenIds] = useState([])
const [decodedTokens, setDecodedTokens] = useState([])
const [margins, setMargins] = useState([])
const [outputOption, setOutputOption] = useState('text');
const [tokenizer, setTokenizer] = useState('Xenova/gpt-4');
const [tokenizer, setTokenizer] = useState(tokenizerParam ?? 'Xenova/gpt-4');

const textareaRef = useRef(null);
const outputRef = useRef(null);
Expand Down Expand Up @@ -51,6 +55,12 @@ function App() {
worker.current.postMessage({ model_id, text });
}, [tokenizer]);

useEffect(() => {
if (textParam) {
onInputChange({ target: { value: textParam } });
}
}, [onInputChange, textParam]);

const onTokenizerChange = useCallback((e) => {
const model_id = e.target.value;
setTokenizer(model_id);
Expand All @@ -70,10 +80,12 @@ function App() {
<option value="Xenova/gpt-4">gpt-4 / gpt-3.5-turbo / text-embedding-ada-002</option>
<option value="Xenova/text-davinci-003">text-davinci-003 / text-davinci-002</option>
<option value="Xenova/gpt-3">gpt-3</option>
<option value="Xenova/claude-tokenizer">Claude 3</option>
<option value="Xenova/grok-1-tokenizer">Grok-1</option>
<option value="Xenova/claude-tokenizer">Claude</option>
<option value="Xenova/mistral-tokenizer">Mistral</option>
<option value="Xenova/gemma-tokenizer">Gemma</option>
<option value="Xenova/llama-tokenizer">LLaMA / Llama 2</option>
<option value="Xenova/c4ai-command-r-v01-tokenizer">Cohere Command-R</option>
<option value="Xenova/t5-small">T5</option>
<option value="Xenova/bert-base-cased">bert-base-cased</option>
</select>
Expand All @@ -86,6 +98,7 @@ function App() {
rows="8"
className="font-mono text-lg block w-full p-2.5 text-gray-900 bg-gray-50 rounded-lg border border-gray-200"
placeholder="Enter some text"
defaultValue={textParam ?? textareaRef.current?.value ?? ''}
></textarea>

<div className='flex justify-center gap-5'>
Expand Down
1 change: 1 addition & 0 deletions examples/tokenizer-playground/src/worker.js
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ self.addEventListener('message', async (event) => {
// NOTE: We just remove the StripDecoder from the llama tokenizer
switch (tokenizer.constructor.name) {
case 'LlamaTokenizer':
case 'Grok1Tokenizer':
// tokenizer.decoder.decoders.at(-1).constructor.name === 'StripDecoder'
tokenizer.decoder.decoders.pop();
break;
Expand Down
8 changes: 4 additions & 4 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
"dependencies": {
"onnxruntime-web": "1.14.0",
"sharp": "^0.32.0",
"@huggingface/jinja": "^0.2.1"
"@huggingface/jinja": "^0.2.2"
},
"optionalDependencies": {
"onnxruntime-node": "1.14.0"
Expand Down
52 changes: 51 additions & 1 deletion src/tokenizers.js
Original file line number Diff line number Diff line change
Expand Up @@ -2519,6 +2519,18 @@ export class PreTrainedTokenizer extends Callable {
this.legacy = false;

this.chat_template = tokenizerConfig.chat_template ?? null;
if (Array.isArray(this.chat_template)) {
// Chat templates are stored as lists of dicts with fixed key names,
// we reconstruct that into a single dict while loading them.
const chat_template = Object.create(null);
for (const { name, template } of this.chat_template) {
if (typeof name !== 'string' || typeof template !== 'string') {
throw new Error('Chat template must be a list of objects with "name" and "template" properties');
}
chat_template[name] = template;
}
this.chat_template = chat_template;
}
this._compiled_template_cache = new Map();
}

Expand Down Expand Up @@ -2995,6 +3007,7 @@ export class PreTrainedTokenizer extends Callable {
* @param {number} [options.max_length=null] Maximum length (in tokens) to use for padding or truncation. Has no effect if tokenize is false.
* If not specified, the tokenizer's `max_length` attribute will be used as a default.
* @param {boolean} [options.return_tensor=true] Whether to return the output as a Tensor or an Array. Has no effect if tokenize is false.
* @param {Object} [options.tokenizer_kwargs={}] Additional options to pass to the tokenizer.
* @returns {string | Tensor | number[]| number[][]} The tokenized output.
*/
apply_chat_template(conversation, {
Expand All @@ -3005,9 +3018,37 @@ export class PreTrainedTokenizer extends Callable {
truncation = false,
max_length = null,
return_tensor = true,
tokenizer_kwargs = {},
...kwargs
} = {}) {

chat_template ??= this.chat_template ?? this.default_chat_template;
// First, handle the cases when the model has a dict of multiple templates
if (
(this.chat_template && typeof this.chat_template === 'object') ||
(this.chat_template === null && this.default_chat_template && typeof this.default_chat_template === 'object')
) {
const template_dict = this.chat_template ?? this.default_chat_template; // Guaranteed to be a non-null object

if (chat_template !== null && Object.hasOwn(template_dict, chat_template)) {
// The user can pass the name of a template to the chat template argument instead of an entire template
chat_template = template_dict[chat_template];
} else if (chat_template === null && 'default' in template_dict) {
chat_template = template_dict['default'];
} else if (chat_template === null) {
throw Error(
`This model has multiple chat templates with no default specified! Please either pass a chat ` +
`template or the name of the template you wish to use to the 'chat_template' argument. Available ` +
`template names are ${Object.keys(template_dict).sort()}.`
)
}
} else {
// These are the cases when the model has a single template
// priority: `chat_template` argument > `tokenizer.chat_template` > `tokenizer.default_chat_template
chat_template ??= this.chat_template ?? this.default_chat_template;
}
if (typeof chat_template !== 'string') {
throw Error(`chat_template must be a string, but got ${typeof chat_template}`);
}

// Compilation function uses a cache to avoid recompiling the same template
let compiledTemplate = this._compiled_template_cache.get(chat_template);
Expand All @@ -3029,6 +3070,7 @@ export class PreTrainedTokenizer extends Callable {
add_generation_prompt: add_generation_prompt,

...special_tokens_map,
...kwargs,
});

if (tokenize) {
Expand All @@ -3038,6 +3080,7 @@ export class PreTrainedTokenizer extends Callable {
truncation,
max_length,
return_tensor,
...tokenizer_kwargs,
}).input_ids;
}

Expand Down Expand Up @@ -3208,6 +3251,8 @@ export class GemmaTokenizer extends PreTrainedTokenizer {
_default_chat_template = "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}"
}

export class Grok1Tokenizer extends PreTrainedTokenizer { }

/**
* Helper function to build translation inputs for an `NllbTokenizer` or `M2M100Tokenizer`.
* @param {PreTrainedTokenizer} self The tokenizer instance.
Expand Down Expand Up @@ -4263,6 +4308,9 @@ export class VitsTokenizer extends PreTrainedTokenizer {
this.decoder = new VitsDecoder({});
}
}

export class CohereTokenizer extends PreTrainedTokenizer { }

/**
* Helper class which is used to instantiate pretrained tokenizers with the `from_pretrained` function.
* The chosen tokenizer class is determined by the type specified in the tokenizer config.
Expand Down Expand Up @@ -4314,6 +4362,8 @@ export class AutoTokenizer {
VitsTokenizer,
Qwen2Tokenizer,
GemmaTokenizer,
Grok1Tokenizer,
CohereTokenizer,

// Base case:
PreTrainedTokenizer,
Expand Down
40 changes: 38 additions & 2 deletions tests/tokenizers.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,42 @@ describe('Chat templates', () => {
compare(input_ids, [1, 733, 16289, 28793, 22557, 28725, 910, 460, 368, 28804, 733, 28748, 16289, 28793, 28737, 28742, 28719, 2548, 1598, 28723, 1602, 541, 315, 1316, 368, 3154, 28804, 2, 28705, 733, 16289, 28793, 315, 28742, 28715, 737, 298, 1347, 805, 910, 10706, 5752, 1077, 3791, 28808, 733, 28748, 16289, 28793])
});

it('should support multiple chat templates', async () => {

const tokenizer = await AutoTokenizer.from_pretrained("Xenova/c4ai-command-r-v01-tokenizer")

// define conversation input:
const conversation = [
{ role: "user", content: "Whats the biggest penguin in the world?" }
]
// define documents to ground on:
const documents = [
{ title: "Tall penguins", text: "Emperor penguins are the tallest growing up to 122 cm in height." },
{ title: "Penguin habitats", text: "Emperor penguins only live in Antarctica." }
]

// render the RAG prompt as a string:
const grounded_generation_prompt = tokenizer.apply_chat_template(
conversation,
{
chat_template: "rag",
tokenize: false,
add_generation_prompt: true,

documents,
citation_mode: "accurate", // or "fast"
}
)
expect(grounded_generation_prompt).toEqual(
"<BOS_TOKEN><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|># Safety Preamble\nThe instructions in this section override those in the task description and style guide sections. Don't answer questions that are harmful or immoral.\n\n" +
"# System Preamble\n## Basic Rules\nYou are a powerful conversational AI trained by Cohere to help people. You are augmented by a number of tools, and your job is to use and consume the output of these tools to best help the user. You will see a conversation history between yourself and a user, ending with an utterance from the user. You will then see a specific instruction instructing you what kind of response to generate. When you answer the user's requests, you cite your sources in your answers, according to those instructions.\n\n" +
"# User Preamble\n## Task and Context\nYou help people answer their questions and other requests interactively. You will be asked a very wide array of requests on all kinds of topics. You will be equipped with a wide range of search engines or similar tools to help you, which you use to research your answer. You should focus on serving the user's needs as best you can, which will be wide-ranging.\n\n## Style Guide\nUnless the user asks for a different style of answer, you should answer in full sentences, using proper grammar and spelling.<|END_OF_TURN_TOKEN|>" +
"<|START_OF_TURN_TOKEN|><|USER_TOKEN|>Whats the biggest penguin in the world?<|END_OF_TURN_TOKEN|>" +
"<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><results>\nDocument: 0\ntitle: Tall penguins\ntext: Emperor penguins are the tallest growing up to 122 cm in height.\n\nDocument: 1\ntitle: Penguin habitats\ntext: Emperor penguins only live in Antarctica.\n</results><|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>Carefully perform the following instructions, in order, starting each with a new line.\nFirstly, Decide which of the retrieved documents are relevant to the user's last input by writing 'Relevant Documents:' followed by comma-separated list of document numbers. If none are relevant, you should instead write 'None'.\nSecondly, Decide which of the retrieved documents contain facts that should be cited in a good answer to the user's last input by writing 'Cited Documents:' followed a comma-separated list of document numbers. If you dont want to cite any of them, you should instead write 'None'.\nThirdly, Write 'Answer:' followed by a response to the user's last input in high quality natural english. Use the retrieved documents to help you. Do not insert any citations or grounding markup.\nFinally, Write 'Grounded answer:' followed by a response to the user's last input in high quality natural english. Use the symbols <co: doc> and </co: doc> to indicate when a fact comes from a document in the search result, e.g <co: 0>my fact</co: 0> for a fact from document 0.<|END_OF_TURN_TOKEN|>" +
"<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"
);
});

it('should support user-defined chat template', async () => {
const tokenizer = await AutoTokenizer.from_pretrained("Xenova/llama-tokenizer");

Expand Down Expand Up @@ -395,7 +431,7 @@ describe('Chat templates', () => {
.replaceAll('USE_DEFAULT_PROMPT', true)
.replaceAll('DEFAULT_SYSTEM_MESSAGE', 'You are a helpful, respectful and honest assistant.');

const text = await tokenizer.apply_chat_template(chat, { tokenize: false, return_tensor: false, chat_template });
const text = tokenizer.apply_chat_template(chat, { tokenize: false, return_tensor: false, chat_template });

expect(text).toEqual("<s>[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant.\n<</SYS>>\n\nHello, how are you? [/INST] I'm doing great. How can I help you today? </s><s>[INST] I'd like to show off how chat templating works! [/INST]");

Expand All @@ -412,7 +448,7 @@ describe('Chat templates', () => {

for (let { messages, add_generation_prompt, tokenize, target } of tests) {

const generated = await tokenizer.apply_chat_template(messages, {
const generated = tokenizer.apply_chat_template(messages, {
tokenize,
add_generation_prompt,
return_tensor: false,
Expand Down

0 comments on commit f0ef2e8

Please sign in to comment.