Skip to content

Commit

Permalink
Feat: Add support for tokenizer’s or custom jinja chat_template (#1970)
Browse files Browse the repository at this point in the history
* Allow using tokenizer's default chat template with fallbacks

Summary of changes:

1. Adds `tokenizer_default` as option for `chat_template` in
   `chat_template` prompt strategy that allows using the chat template
   from tokenizer's config.json
2. Allows falling back to chat templates available in axolotl if
   tokenizer does not have a chat template
3. Adds a mistral chat template which supports system message - taken
   from https://github.com/chujiezheng/chat_templates/blob/main/chat_templates/mistral-instruct.jinja

---

Why?

Many popular models are not trained with chatml format. As a result for
the model to correctly learn chatml we have to turn on train_on_inputs
which requires more compute and time. If we can use the model's already
learned chat template we can just learn the output tokens

---

Todo:

- Write tests

* Add tests

* Fix lint and bug post merge from main

* Add option `chat_template_jinja` to provide a jinja template

* remove custom mistral template

* Address review comments and add docs

* Update docs/dataset-formats/conversation.qmd

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* fix: set default to tokenizer template

* Merge branch 'main' into cj_tokenizer_default_prompt_template

* chore: remove redundant function

* fix: re-arrange enum declaration position

* fix: refactor artifact left from main merge

* feat(doc): updated config with chat template options and clarified examples

* chore: clarify doc

* chore: added example for non-default template

* chore: refactor

* fix: test

* fix: config being dropped and unittest to catch that

* chore: lint

* chore: skip duplicate

* fix: rename var after merge

* feat: add test for levy's dpo case

* fix: remove default setting on edge case where chat template overriden in dataset section

* feat: handle sharegpt deprecation better in docs

* feat: add example using fallback

* feat: handles chat_template requiring specific user/assistant order

* fix: update test based on new defaults

* fix: imported name incorrectly updated on merge

* chore: lint

* fix: update dummy message to prevent potential overlap with real content

* fix(doc): formatting

* fix: update bradleyterry to use new chat_template

---------

Co-authored-by: Chirag Jain <jain.chirag925@gmail.com>
  • Loading branch information
NanoCode012 and chiragjn authored Oct 29, 2024
1 parent e1e0556 commit bfc77b0
Show file tree
Hide file tree
Showing 20 changed files with 900 additions and 118 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
- typescript
type: ... # unimplemented custom format
# fastchat conversation (deprecation soon, use chat_template)
# fastchat conversation (deprecation soon, use chat_template https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/conversation.html#chat_template)
# See 'conversation' options: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
- path: ...
type: sharegpt
Expand Down
57 changes: 53 additions & 4 deletions docs/config.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ lora_on_cpu: true
datasets:
# HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files
- path: vicgalle/alpaca-gpt4
# The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
# The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
ds_type: # Optional[str] (json|arrow|parquet|text|csv) defines the datatype when path is a file
data_files: # Optional[str] path to source data files
Expand Down Expand Up @@ -124,6 +124,48 @@ datasets:
# For `completion` datsets only, uses the provided field instead of `text` column
field:

# Using chat template
- path: ...
# Set type to `chat_template` to use this strategy
type: chat_template
# Specify the name of the chat template to use
# The name of the chat template to use for training, following values are supported:
# - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default.
# - alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py
# - tokenizer_default_fallback_*: where * is the name of the chat template to fallback to if the tokenizer does not have a chat template else default to tokenizer. E.g. tokenizer_default_fallback_chatml.
# - jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field.
chat_template: tokenizer_default
# Custom jinja template for chat template. This will be only used if `chat_template` is set to `jinja` or empty (in which case chat_template is automatically set to `jinja`).
chat_template_jinja:
# The key in the data example that contains the messages. Default is "messages".
field_messages: messages
# The key in the message turn that contains the role. Default is "role".
message_field_role: role
# The key in the message turn that contains the content. Default is "content".
message_field_content: content
# Optional[Dict[str, List]]. Roles mapping for the messages.
roles:
user: ["human", "user"]
assistant: ["gpt", "assistant", "ai"]
system: ["system"]

## NOTE: Leaving the below empty will default to using the simple legacy tokenization strategy where only last message is trained on.

# Optional[List[str]]. Roles to train on. The tokens from these roles will be considered for the loss.
roles_to_train: ["gpt", "assistant"]
# Optional[str]. Which EOS tokens to train on in the conversation. Possible values are:
# - all: train on all EOS tokens
# - turn: train on the EOS token at the end of each trainable turn
# - last: train on the last EOS token in the conversation
train_on_eos: last
# The key in the message turn that indicates via boolean whether tokens of a turn should be considered for training. Useful to selectively train on certain turns besides the `roles_to_train`.
message_field_training: training
# The key in the message turn that contains the training details. Useful to selectively train on certain tokens in a turn.
# The value of the key is a List[Dict] containing `begin_offset` (start character index in content), `end_offset` (end character index in content), and `train` (boolean whether to train).
# See example at `docs/dataset-formats/conversation.qmd`
message_field_training_detail: train_detail


# If false, the datasets will not be shuffled and will keep their original order in `datasets`.
# The same applies to the `test_datasets` option and the `pretraining_dataset` option. Default is true.
shuffle_merged_datasets: true
Expand All @@ -142,9 +184,16 @@ test_datasets:
# use RL training: 'dpo', 'ipo', 'kto'
rl:

# Saves the desired chat template to the tokenizer_config.json for easier inferencing
# Currently supports chatml and inst (mistral/mixtral)
chat_template: chatml
# The name of the chat template to use for training, following values are supported:
# - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default value.
# - alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py
# - tokenizer_default_fallback_*: where * is the name of the chat template to fallback to. E.g. tokenizer_default_fallback_chatml. This is useful when the chat template is not available in the tokenizer.
# - jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field.
# The selected chat template will be saved to the tokenizer_config.json for easier inferencing
# Note: It is recommended to set train_on_inputs to true when using a chat template that is different from the model's default chat template.
chat_template: tokenizer_default
# custom jinja template for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null.
chat_template_jinja: null
# Changes the default system message
default_system_message: You are a helpful assistant. Please give a long and detailed answer. # Currently only supports chatml.
# Axolotl attempts to save the dataset as an arrow after packing the data together so
Expand Down
137 changes: 137 additions & 0 deletions docs/dataset-formats/conversation.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ order: 3

## sharegpt

UPDATE: ShareGPT is being deprecated in the next release. Please see `chat_template` section below.

conversations where `from` is `human`/`gpt`. (optional: first row with role `system` to override default system prompt)

```{.json filename="data.jsonl"}
Expand Down Expand Up @@ -69,3 +71,138 @@ creates a chat where bot is asked to tell a joke, then explain why the joke is f
```{.json filename="data.jsonl"}
{"conversations": [{"title": "...", "text": "...", "explanation": "..."}]}
```


## chat_template

Chat Template strategy uses a jinja2 template that converts a list of messages into a prompt. Support using tokenizer's template, a supported template, or custom jinja2.

```{.json filename="data.jsonl"}
{"conversations": [{"role": "...", "content": "..."}]}
```

See `config.qmd` for full configs and supported templates.

### Migrating from sharegpt

Most configs can be adapted as follows:

```yaml
# old
chat_template: chatml
datasets:
- path: ...
type: sharegpt
conversation: chatml

# new (if using tokenizer's chat_template)
datasets:
- path: ...
type: chat_template

field_messages: conversations
message_field_role: from
message_field_content: value

# new (if setting a new chat_template like chatml, gemma, etc)
chat_template: chatml
datasets:
- path: ...
type: chat_template

field_messages: conversations
message_field_role: from
message_field_content: value
```
We recommend checking the below examples for other usecases.
### Examples
1. Using the default chat template in the tokenizer_config.json on OpenAI messages format, training on only last message.
```yaml
datasets:
- path: ...
type: chat_template
```
2. Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages.

```yaml
chat_template: gemma # this overwrites the tokenizer's chat_template
datasets:
- path: ...
type: chat_template
roles_to_train: ["assistant"]
```

3. Using the tokenizer_config.json's chat template or `chatml` as fallback if the former's chat template does not exist, on OpenAI messages format, training on all assistant messages.

```yaml
chat_template: tokenizer_default_fallback_chatml # this overwrites the tokenizer's chat_template
datasets:
- path: ...
type: chat_template
roles_to_train: ["assistant"]
```

4. Using a custom jinja template on OpenAI messages format, training on all assistant messages.

```yaml
# chat_template: jinja # `jinja` will be implied if the `chat_template_jinja` is set and this field is empty
chat_template_jinja: "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}"

datasets:
- path: ...
type: chat_template
roles_to_train: ["assistant"]
```
5. (Advanced) Using fine-grained control over tokens and turns to train in a conversation
For a data sample that looks like:
```{.json filename="data.jsonl"}
{
"conversations": [
{"from": "system", "value": "You are an AI assistant.", "train": false},
{"from": "human", "value": "Hello", "train": false},
{"from": "assistant", "value": "Hello", "train": true},
{"from": "human", "value": "How are you?", "train": true},
{
"from": "assistant",
"value": "I'm doing very well, thank you!",
"train_detail": [
{"begin_offset": 0, "end_offset": 8, "train": false},
{"begin_offset": 9, "end_offset": 18, "train": true},
{"begin_offset": 19, "end_offset": 30, "train": false},
],
},
{
"from": "human",
"value": "I'm doing very well, thank you!",
"train": true,
},
{"from": "assistant", "value": "Hi there!", "train": true}
]
}
```

The configuration would look like:

```yaml
datasets:
- path: ...
type: chat_template
chat_template: tokenizer_default
field_messages: conversations
message_field_role: from
message_field_content: value
roles_to_train: []
train_on_eos: turn
message_field_training: train
message_field_training_detail: train_detail
```
Tip: It is not necessary to use both `message_field_training` and `message_field_training_detail` at a time.
4 changes: 2 additions & 2 deletions src/axolotl/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from axolotl.integrations.base import PluginManager
from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta
from axolotl.utils.chat_templates import chat_templates
from axolotl.utils.chat_templates import get_chat_template
from axolotl.utils.comet_ import setup_comet_env_vars
from axolotl.utils.config import (
normalize_cfg_datasets,
Expand Down Expand Up @@ -272,7 +272,7 @@ def do_inference_gradio(
importlib.import_module("axolotl.prompters"), prompter
)
elif cfg.chat_template:
chat_template_str = chat_templates(cfg.chat_template)
chat_template_str = get_chat_template(cfg.chat_template)

model = model.to(cfg.device, dtype=cfg.torch_dtype)

Expand Down
4 changes: 2 additions & 2 deletions src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
log_prediction_callback_factory,
)
from axolotl.utils.callbacks.lisa import lisa_callback_factory
from axolotl.utils.chat_templates import chat_templates
from axolotl.utils.chat_templates import get_chat_template
from axolotl.utils.collators import (
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,
Expand Down Expand Up @@ -1594,7 +1594,7 @@ def build(self, total_num_steps):
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
if self.cfg.chat_template:
training_arguments_kwargs["chat_template"] = chat_templates(
training_arguments_kwargs["chat_template"] = get_chat_template(
self.cfg.chat_template
)

Expand Down
2 changes: 1 addition & 1 deletion src/axolotl/prompt_strategies/bradley_terry/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig

LOG = logging.getLogger("axolotl.prompt_strategies")
LOG = logging.getLogger("axolotl.prompt_strategies.bradley_terry")


def load(strategy, tokenizer, cfg, ds_cfg):
Expand Down
42 changes: 28 additions & 14 deletions src/axolotl/prompt_strategies/bradley_terry/chat_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,18 @@
Bradley-Terry model with chat template prompt strategy.
"""

import logging
from typing import Any, Dict, Optional

from axolotl.prompt_strategies.chat_template import (
ChatTemplatePrompter,
ChatTemplateStrategy,
)
from axolotl.utils.chat_templates import chat_templates
from axolotl.utils.chat_templates import get_chat_template_from_config

# Configure the logger
LOG = logging.getLogger("axolotl.prompt_strategies.bradley_terry.chat_template")
LOG.setLevel(logging.INFO)


class BTChatTemplateStrategy(ChatTemplateStrategy):
Expand All @@ -27,18 +32,24 @@ def tokenize_prompt(self, prompt):
# pylint: disable=duplicate-code
prompt[self.messages] = []
if prompt["system"]:
prompt[self.messages].append({"from": "system", "value": prompt["system"]})
prompt[self.messages].append({"from": "user", "value": prompt["input"]})
prompt[self.messages].append({"from": "assistant", "value": prompt["chosen"]})
prompt[self.messages].append(
{"role": "system", "content": prompt["system"]}
)
prompt[self.messages].append({"role": "user", "content": prompt["input"]})
prompt[self.messages].append({"role": "assistant", "content": prompt["chosen"]})
chosen_tokenized = super().tokenize_prompt(prompt)

self.messages = "rejected_messages"
# pylint: disable=duplicate-code
prompt[self.messages] = []
if prompt["system"]:
prompt[self.messages].append({"from": "system", "value": prompt["system"]})
prompt[self.messages].append({"from": "user", "value": prompt["input"]})
prompt[self.messages].append({"from": "assistant", "value": prompt["rejected"]})
prompt[self.messages].append(
{"role": "system", "content": prompt["system"]}
)
prompt[self.messages].append({"role": "user", "content": prompt["input"]})
prompt[self.messages].append(
{"role": "assistant", "content": prompt["rejected"]}
)
rejected_tokenized = super().tokenize_prompt(prompt)

return {
Expand All @@ -53,15 +64,18 @@ def tokenize_prompt(self, prompt):

def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
ds_cfg = ds_cfg or {}
chat_template_string = get_chat_template_from_config(
cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
)

prompter_params = {
"tokenizer": tokenizer,
"chat_template": chat_templates(ds_cfg.get("chat_template", "chatml")),
"message_field_role": ds_cfg.get("message_field_role", "from"),
"message_field_content": ds_cfg.get("message_field_content", "value"),
"message_field_training": ds_cfg.get("message_field_training", "training"),
"chat_template": chat_template_string,
"message_field_role": ds_cfg.get("message_field_role", "role"),
"message_field_content": ds_cfg.get("message_field_content", "content"),
"message_field_training": ds_cfg.get("message_field_training", None),
"message_field_training_detail": ds_cfg.get(
"message_field_training_detail", "train_detail"
"message_field_training_detail", None
),
"roles": ds_cfg.get("roles"),
"drop_system_message": ds_cfg.get("drop_system_message", False),
Expand All @@ -74,8 +88,8 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
strategy_params = {
"train_on_inputs": cfg.train_on_inputs,
"sequence_len": cfg.sequence_len,
"roles_to_train": ds_cfg.get("roles_to_train", ["gpt", "assistant"]),
"train_on_eos": ds_cfg.get("train_on_eos", "turn"),
"roles_to_train": ds_cfg.get("roles_to_train", []),
"train_on_eos": ds_cfg.get("train_on_eos", None),
}

strategy = BTChatTemplateStrategy(
Expand Down
8 changes: 6 additions & 2 deletions src/axolotl/prompt_strategies/chat_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
from axolotl.utils.chat_templates import chat_templates
from axolotl.utils.chat_templates import get_chat_template_from_config

# Configure the logger
LOG = logging.getLogger("axolotl")
Expand Down Expand Up @@ -405,10 +405,14 @@ def get_images(self, prompt):
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None):
# pylint: disable=duplicate-code
ds_cfg = ds_cfg or {}
chat_template_string = get_chat_template_from_config(
cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
)
LOG.info(f"Using chat template:\n---\n{chat_template_string!s}\n---")

prompter_params = {
"tokenizer": tokenizer,
"chat_template": chat_templates(ds_cfg.get("chat_template", "chatml")),
"chat_template": chat_template_string,
"message_field_role": ds_cfg.get("message_field_role", "role"),
"message_field_content": ds_cfg.get("message_field_content", "content"),
"message_field_training": ds_cfg.get("message_field_training", None),
Expand Down
Loading

0 comments on commit bfc77b0

Please sign in to comment.