
fix dpo_trainer bug for LLMs without bos_token in config #1885

Merged 4 commits into huggingface:main on Jul 31, 2024

Conversation

@DZ9 (Contributor) commented Jul 29, 2024

In dpo_trainer.py, the bos_token is automatically prepended whenever the first token of the tokenized sequence does not equal bos_token. This bos_token is read from the tokenizer config file, but in some LLMs, such as Qwen2, bos_token is left as None in the config, which results in None being inserted into the input_ids tensor, as shown below:

[screenshot: `None` prepended to the tokenized `input_ids`]

The following traceback is then raised when running DPO:

Traceback (most recent call last):
  File "examples/scripts/dpo.py", line 185, in <module>
    trainer.train()
  File "/usr/lib/python3.10/site-packages/transformers/trainer.py", line 1859, in train
    return inner_training_loop(
  File "/usr/lib/python3.10/site-packages/transformers/trainer.py", line 2165, in _inner_training_loop
    for step, inputs in enumerate(epoch_iterator):
  File "/usr/lib/python3.10/site-packages/accelerate/data_loader.py", line 452, in __iter__
    current_batch = next(dataloader_iter)
  File "/usr/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 631, in __next__
    data = self._next_data()
  File "/usr/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 675, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/usr/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 54, in fetch
    return self.collate_fn(data)
  File "/usr/lib/python3.10/site-packages/trl/trainer/utils.py", line 461, in __call__
    torch.tensor(ex[k], dtype=dtype)
TypeError: 'NoneType' object cannot be interpreted as an integer
  0%|          | 0/25 [00:00<?, ?it/s]  

Also, in LLMs like Qwen2, the bos_token (<|endoftext|>) differs from the first template token (<|im_start|>). Automatically adding this bos_token without the user's awareness causes unexpected behavior when the trained DPO model is later used for inference. The bos_token should only be added when its value in the tokenizer config is not None.
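The guard described above can be sketched as follows; the function name and list-based signature are illustrative of the idea, not necessarily the exact TRL code:

```python
def add_bos_token_if_needed(bos_token_id, input_ids, attention_mask):
    """Prepend bos_token_id only when the tokenizer actually defines one
    and the sequence does not already start with it (sketch of the guard)."""
    if bos_token_id is not None and (not input_ids or input_ids[0] != bos_token_id):
        input_ids = [bos_token_id] + input_ids
        attention_mask = [1] + attention_mask
    return input_ids, attention_mask

# Qwen2-style tokenizer: bos_token_id is None, so nothing is prepended
ids, mask = add_bos_token_if_needed(None, [151644, 8948], [1, 1])
assert ids == [151644, 8948] and mask == [1, 1]

# Llama-style tokenizer: bos_token_id=1 is prepended when missing
ids, mask = add_bos_token_if_needed(1, [42, 7], [1, 1])
assert ids == [1, 42, 7] and mask == [1, 1, 1]
```

With this check, a tokenizer whose config leaves bos_token unset simply skips the prepend instead of pushing None into the batch that later breaks tensor construction in the collator.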

After this PR, the data loads normally:

[screenshot of the tokenized batch]

and the model trains normally:

[screenshot of the training log]

The test command is:

python examples/scripts/dpo.py \
  --dataset_name trl-internal-testing/hh-rlhf-helpful-base-trl-style \
  --model_name_or_path Qwen2/Qwen2-0.5B-Instruct \
  --per_device_train_batch_size 1 \
  --learning_rate 1e-3 \
  --gradient_accumulation_steps 1 \
  --logging_steps 10 \
  --eval_steps 500 \
  --output_dir ./output \
  --warmup_steps 150 \
  --report_to tensorboard \
  --bf16 \
  --logging_first_step \
  --max_length 4096 \
  --sanity_check True

@kashif (Collaborator) commented Jul 30, 2024

@DZ9 since this snippet is used in a number of other trainers, would it be better to add it as a helper function and then use it in the DPO, CPO, and ORPO trainers?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@DZ9 (Contributor, Author) commented Jul 31, 2024

@kashif Sure. I've moved the bos and eos token addition into helper functions in utils.py and applied them in the DPO, ORPO, and CPO trainers.
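The eos side of the shared helper can be sketched the same way as the bos guard; again, the name and signature are illustrative of the refactor's shape, not necessarily the exact utils.py API:

```python
def add_eos_token_if_needed(eos_token_id, input_ids, attention_mask):
    """Append eos_token_id only when the tokenizer defines one and the
    sequence does not already end with it (sketch of the shared helper)."""
    if eos_token_id is not None and (not input_ids or input_ids[-1] != eos_token_id):
        input_ids = input_ids + [eos_token_id]
        attention_mask = attention_mask + [1]
    return input_ids, attention_mask

# eos already present: nothing is appended
ids, mask = add_eos_token_if_needed(2, [5, 9, 2], [1, 1, 1])
assert ids == [5, 9, 2] and mask == [1, 1, 1]

# eos missing: appended exactly once
ids, mask = add_eos_token_if_needed(2, [5, 9], [1, 1])
assert ids == [5, 9, 2] and mask == [1, 1, 1]
```

Centralizing both guards in one module means the DPO, CPO, and ORPO trainers all get the None-safe behavior from a single place instead of three diverging copies.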

@kashif (Collaborator) commented Jul 31, 2024

Thanks @DZ9. Can you also run `pre-commit run --all-files` in the root of the TRL folder to fix the formatting issues?

@DZ9 (Contributor, Author) commented Jul 31, 2024

@kashif Absolutely. I've run the formatting command and committed the changes.

@FlyingDutchman26 commented

I just ran into this problem today, thanks.

@kashif merged commit ddf4c8d into huggingface:main on Jul 31, 2024
4 of 9 checks passed