Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move BCO to separate BCOTrainer with fixes #1869

Merged
merged 22 commits into from
Jul 28, 2024
Merged

Conversation

claralp
Copy link
Contributor

@claralp claralp commented Jul 24, 2024

This is a further improvement of #1803:
As requested, the BCO training is now fully moved to a new BCOTrainer with respective BCOConfig.

This PR includes the following changes from #1803:

  • When using the BCO loss type there is no need for the KL dataset anymore. By skipping the creation of the KL dataset we save time and memory. -> Fixed with separate trainer now
  • Do not assert both desired and undesired data in each per-device mini batch. With this it was impossible to train large models that only allow a per-device batch size of 1. Also KTO and BCO are supposed to work with unpaired preference data and different amounts of desired or undesired examples. In my experiments BCO proofed to work well without this assumption and also I cannot find it mentioned anywhere in the paper.
  • When checkpointing, also save the RunningMoments object which is used in BCO to calculate the $\delta$ mean reward value. If this is not saved, it will corrupt the whole resumed training process when restarting from a checkpoint.

Additional issue that is fixed only in this PR:

  • UDM supports non interleaved datasets now. The previously implemented underlying distribution matching assumed the same shape of prompt_embeddings across all devices in a distributed setting. The beginning of the train_dataset was interleaved, but the eval_dataset and the end of the training dataset not, causing this to crash in the current main version. As for KTO, I assume that perfectly interleaved datasets should not be a requirement of BCO to work.

@kashif, @lewtun and @seanexp as original author of BCO

@kashif
Copy link
Collaborator

kashif commented Jul 24, 2024

very cool @claralp checking!

@kashif
Copy link
Collaborator

kashif commented Jul 24, 2024

@claralp the failing tests are not related to this PR

docs/source/bco_trainer.mdx Outdated Show resolved Hide resolved
claralp and others added 2 commits July 25, 2024 10:43
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
trl/trainer/bco_config.py Outdated Show resolved Hide resolved
trl/trainer/bco_config.py Outdated Show resolved Hide resolved
trl/trainer/bco_config.py Outdated Show resolved Hide resolved
trl/trainer/bco_config.py Outdated Show resolved Hide resolved
trl/trainer/bco_config.py Outdated Show resolved Hide resolved
@seanexp
Copy link
Contributor

seanexp commented Jul 25, 2024

Sorry for late reply @claralp

The code looks mostly good!

kashif and others added 2 commits July 25, 2024 11:14
Co-authored-by: Seungjae Jung <seanexplode@gmail.com>
trl/trainer/utils.py Outdated Show resolved Hide resolved
Co-authored-by: Seungjae Jung <seanexplode@gmail.com>
trl/trainer/bco_trainer.py Outdated Show resolved Hide resolved
Co-authored-by: Seungjae Jung <seanexplode@gmail.com>
Copy link
Contributor

@seanexp seanexp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @claralp ! All looks good to me :)

trl/trainer/bco_config.py Outdated Show resolved Hide resolved
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@claralp
Copy link
Contributor Author

claralp commented Jul 25, 2024

@kashif @seanexp thanks for reviewing and for the additional fixes/improvements 👍 just checked them, they all look reasonable to me

trl/trainer/bco_trainer.py Outdated Show resolved Hide resolved
@claralp
Copy link
Contributor Author

claralp commented Jul 26, 2024

@kashif @seanexp fixed a few more issues when running on multiple GPUs after the lastest changes.
Is this ready to merge from your side or any more open points that you discovered?

@kashif kashif merged commit 9929370 into huggingface:main Jul 28, 2024
5 of 18 checks passed
qgallouedec added a commit that referenced this pull request Jul 30, 2024
commit 890232f
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Tue Jul 30 14:29:47 2024 +0200

    update example overview (#1883)

    Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>

commit 9929370
Author: Clara Pohland <54847419+claralp@users.noreply.github.com>
Date:   Sun Jul 28 21:10:08 2024 +0200

    Move BCO to separate BCOTrainer with fixes (#1869)

    * kto_trainer: skip KL data for BCO

    * kto_trainer: BCO allow no positives or no negatives in batch

    * kto_trainer: make RunningMoments object serializable

    * add BCOTrainer

    * fix BCO UDM for not interleaved data

    * kto_trainer: remove unused UDM part

    * bco_trainer: add tests and docs, minor fixes

    * code style fixes

    * Update docs/source/bco_trainer.mdx

    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

    * fix BCO UDM for bfloat16

    * Update trl/trainer/bco_config.py

    * Update trl/trainer/bco_config.py

    Co-authored-by: Seungjae Jung <seanexplode@gmail.com>

    * Update trl/trainer/utils.py

    Co-authored-by: Seungjae Jung <seanexplode@gmail.com>

    * Update trl/trainer/bco_trainer.py

    Co-authored-by: Seungjae Jung <seanexplode@gmail.com>

    * Update trl/trainer/bco_config.py

    * Update _toctree.yml

    * Update trl/trainer/bco_config.py

    * Update trl/trainer/bco_trainer.py

    * RunningMoments, fix multi GPU serialization

    * fix tests

    ---------

    Co-authored-by: Clara Luise Pohland <clara-luise.pohland@telekom.de>
    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
    Co-authored-by: Seungjae Jung <seanexplode@gmail.com>

commit 6171cdd
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Sun Jul 28 15:51:38 2024 +0200

    Re-add BigBird Pegasus save/load test (#1882)

    Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>

commit 33d2151
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Sun Jul 28 15:07:10 2024 +0200

    Re-add BigBird Pegasus save/load test (#1876)

    * skip bigbird in ci

    * readd big bird test

    * pytest parametrize

    * dont check the version

    * rm model name

    * re add big bird

    * Merge branch 'main' into readd-bigbird-save-load-test

    ---------

    Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>

commit 8bd2ab8
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Sun Jul 28 14:06:19 2024 +0200

    Refactor judges (#1856)

    * BaseJudge -> BasePairwiseJudge

    * hf judge asyncio

    * refactor judges

    * doc

    * doc

    * doc

    * memeber judge

    * :inherited-members:

    * :inherited-members:

    * doc

    * give up

    * judge tldr with judge class

    * fix rank in multithread

    * format

    * improve doc

    * update doc

    * typo doc

    * doc online dpo

    * Update judge_tldr.py

    ---------

    Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>

commit 82b07d6
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Fri Jul 26 11:43:48 2024 +0200

    Llama in modelling value head tests (#1878)

commit 72bf6c2
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Fri Jul 26 11:33:07 2024 +0200

    Skip BigBird save and load test until next transformers version (#1874)

commit 74e54b5
Author: Edward Beeching <edbeeching@users.noreply.github.com>
Date:   Fri Jul 26 09:36:25 2024 +0200

    fix online dpo example (#1879)

commit 3930973
Author: Rishav Dash <57321948+Rishav-hub@users.noreply.github.com>
Date:   Thu Jul 25 14:17:37 2024 +0530

    Bug Fix while training using SFTTrainer with DataCollatorForCompletionOnlyLM (#1861)

    * Bug Fix while training using SFTTrainer with DataCollatorForCompletionOnlyLM

    Added ```dataset_text_field``` in the SFTConfig while training

    * Update docs/source/sft_trainer.mdx

    ---------

    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

commit db8e09e
Author: Rishav Dash <57321948+Rishav-hub@users.noreply.github.com>
Date:   Thu Jul 25 14:06:57 2024 +0530

    Import missing ```setup_chat_format``` (#1862)

commit 1dae55f
Author: elie <97572401+eliebak@users.noreply.github.com>
Date:   Thu Jul 25 10:27:34 2024 +0200

    add fsdp_qlora config and bnb_4bit_quant_storage (#1863)

commit c8cef79
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Wed Jul 24 21:06:57 2024 +0200

    arXiv to HF Papers (#1870)

commit 7dcf437
Author: Kashif Rasul <kashif.rasul@gmail.com>
Date:   Wed Jul 24 12:27:50 2024 +0200

    [online-DPO] online dpo cleanups (#1864)

    * online dpo cleanups

    * remove unused self.policy

    * add OnlineDPOTrainer and config to __init__.py

    * import from trainer

    * online dpo test

    * rename policy to model and ref_policy to ref_model

    * renamed internally

    * formatting

commit 4e85bd7
Author: Costa Huang <costa.huang@outlook.com>
Date:   Thu Jul 18 14:35:31 2024 -0400

    Online DPO and Online trainer refactor (#1809)

    * online dpo trainer based on rloo trainer

    * push changes

    * refactor

    * use `batch_generation` method

    * precommit

    * remove breakpoint()

    * quick refactor

    * push the current changes

    * quick change

    * refactor

    * use the config name as the experiment name

    * fix logging

    * update online DPO docs

    * push docs

    * increment global step so tensorboard works again.

    * precommit

    * remove unused common online trainer

    * add online DPO docs

    * quick refactor

    * push changes

    * Update docs/source/online_dpo_trainer.md

    Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

    ---------

    Co-authored-by: Michael Noukhovitch <mnoukhov@gmail.com>
    Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

commit c9d5636
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Thu Jul 18 18:28:49 2024 +0200

    rm token (#1852)
qgallouedec added a commit that referenced this pull request Aug 2, 2024
* fix vsft example commands

* fix use_cache and get tokenizer from processor

* rm unused AutoTokenizer

* Squashed commit of the following:

commit 8bd2ab8
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Sun Jul 28 14:06:19 2024 +0200

    Refactor judges (#1856)

    * BaseJudge -> BasePairwiseJudge

    * hf judge asyncio

    * refactor judges

    * doc

    * doc

    * doc

    * memeber judge

    * :inherited-members:

    * :inherited-members:

    * doc

    * give up

    * judge tldr with judge class

    * fix rank in multithread

    * format

    * improve doc

    * update doc

    * typo doc

    * doc online dpo

    * Update judge_tldr.py

    ---------

    Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>

commit 82b07d6
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Fri Jul 26 11:43:48 2024 +0200

    Llama in modelling value head tests (#1878)

commit 72bf6c2
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Fri Jul 26 11:33:07 2024 +0200

    Skip BigBird save and load test until next transformers version (#1874)

commit 74e54b5
Author: Edward Beeching <edbeeching@users.noreply.github.com>
Date:   Fri Jul 26 09:36:25 2024 +0200

    fix online dpo example (#1879)

commit 3930973
Author: Rishav Dash <57321948+Rishav-hub@users.noreply.github.com>
Date:   Thu Jul 25 14:17:37 2024 +0530

    Bug Fix while training using SFTTrainer with DataCollatorForCompletionOnlyLM (#1861)

    * Bug Fix while training using SFTTrainer with DataCollatorForCompletionOnlyLM

    Added ```dataset_text_field``` in the SFTConfig while training

    * Update docs/source/sft_trainer.mdx

    ---------

    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

commit db8e09e
Author: Rishav Dash <57321948+Rishav-hub@users.noreply.github.com>
Date:   Thu Jul 25 14:06:57 2024 +0530

    Import missing ```setup_chat_format``` (#1862)

commit 1dae55f
Author: elie <97572401+eliebak@users.noreply.github.com>
Date:   Thu Jul 25 10:27:34 2024 +0200

    add fsdp_qlora config and bnb_4bit_quant_storage (#1863)

commit c8cef79
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Wed Jul 24 21:06:57 2024 +0200

    arXiv to HF Papers (#1870)

commit 7dcf437
Author: Kashif Rasul <kashif.rasul@gmail.com>
Date:   Wed Jul 24 12:27:50 2024 +0200

    [online-DPO] online dpo cleanups (#1864)

    * online dpo cleanups

    * remove unused self.policy

    * add OnlineDPOTrainer and config to __init__.py

    * import from trainer

    * online dpo test

    * rename policy to model and ref_policy to ref_model

    * renamed internally

    * formatting

commit 4e85bd7
Author: Costa Huang <costa.huang@outlook.com>
Date:   Thu Jul 18 14:35:31 2024 -0400

    Online DPO and Online trainer refactor (#1809)

    * online dpo trainer based on rloo trainer

    * push changes

    * refactor

    * use `batch_generation` method

    * precommit

    * remove breakpoint()

    * quick refactor

    * push the current changes

    * quick change

    * refactor

    * use the config name as the experiment name

    * fix logging

    * update online DPO docs

    * push docs

    * increment global step so tensorboard works again.

    * precommit

    * remove unused common online trainer

    * add online DPO docs

    * quick refactor

    * push changes

    * Update docs/source/online_dpo_trainer.md

    Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

    ---------

    Co-authored-by: Michael Noukhovitch <mnoukhov@gmail.com>
    Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

commit c9d5636
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Thu Jul 18 18:28:49 2024 +0200

    rm token (#1852)

* add section in doc

* Squashed commit of the following:

commit 890232f
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Tue Jul 30 14:29:47 2024 +0200

    update example overview (#1883)

    Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>

commit 9929370
Author: Clara Pohland <54847419+claralp@users.noreply.github.com>
Date:   Sun Jul 28 21:10:08 2024 +0200

    Move BCO to separate BCOTrainer with fixes (#1869)

    * kto_trainer: skip KL data for BCO

    * kto_trainer: BCO allow no positives or no negatives in batch

    * kto_trainer: make RunningMoments object serializable

    * add BCOTrainer

    * fix BCO UDM for not interleaved data

    * kto_trainer: remove unused UDM part

    * bco_trainer: add tests and docs, minor fixes

    * code style fixes

    * Update docs/source/bco_trainer.mdx

    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

    * fix BCO UDM for bfloat16

    * Update trl/trainer/bco_config.py

    * Update trl/trainer/bco_config.py

    Co-authored-by: Seungjae Jung <seanexplode@gmail.com>

    * Update trl/trainer/utils.py

    Co-authored-by: Seungjae Jung <seanexplode@gmail.com>

    * Update trl/trainer/bco_trainer.py

    Co-authored-by: Seungjae Jung <seanexplode@gmail.com>

    * Update trl/trainer/bco_config.py

    * Update _toctree.yml

    * Update trl/trainer/bco_config.py

    * Update trl/trainer/bco_trainer.py

    * RunningMoments, fix multi GPU serialization

    * fix tests

    ---------

    Co-authored-by: Clara Luise Pohland <clara-luise.pohland@telekom.de>
    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
    Co-authored-by: Seungjae Jung <seanexplode@gmail.com>

commit 6171cdd
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Sun Jul 28 15:51:38 2024 +0200

    Re-add BigBird Pegasus save/load test (#1882)

    Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>

commit 33d2151
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Sun Jul 28 15:07:10 2024 +0200

    Re-add BigBird Pegasus save/load test (#1876)

    * skip bigbird in ci

    * readd big bird test

    * pytest parametrize

    * dont check the version

    * rm model name

    * re add big bird

    * Merge branch 'main' into readd-bigbird-save-load-test

    ---------

    Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>

commit 8bd2ab8
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Sun Jul 28 14:06:19 2024 +0200

    Refactor judges (#1856)

    * BaseJudge -> BasePairwiseJudge

    * hf judge asyncio

    * refactor judges

    * doc

    * doc

    * doc

    * memeber judge

    * :inherited-members:

    * :inherited-members:

    * doc

    * give up

    * judge tldr with judge class

    * fix rank in multithread

    * format

    * improve doc

    * update doc

    * typo doc

    * doc online dpo

    * Update judge_tldr.py

    ---------

    Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>

commit 82b07d6
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Fri Jul 26 11:43:48 2024 +0200

    Llama in modelling value head tests (#1878)

commit 72bf6c2
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Fri Jul 26 11:33:07 2024 +0200

    Skip BigBird save and load test until next transformers version (#1874)

commit 74e54b5
Author: Edward Beeching <edbeeching@users.noreply.github.com>
Date:   Fri Jul 26 09:36:25 2024 +0200

    fix online dpo example (#1879)

commit 3930973
Author: Rishav Dash <57321948+Rishav-hub@users.noreply.github.com>
Date:   Thu Jul 25 14:17:37 2024 +0530

    Bug Fix while training using SFTTrainer with DataCollatorForCompletionOnlyLM (#1861)

    * Bug Fix while training using SFTTrainer with DataCollatorForCompletionOnlyLM

    Added ```dataset_text_field``` in the SFTConfig while training

    * Update docs/source/sft_trainer.mdx

    ---------

    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

commit db8e09e
Author: Rishav Dash <57321948+Rishav-hub@users.noreply.github.com>
Date:   Thu Jul 25 14:06:57 2024 +0530

    Import missing ```setup_chat_format``` (#1862)

commit 1dae55f
Author: elie <97572401+eliebak@users.noreply.github.com>
Date:   Thu Jul 25 10:27:34 2024 +0200

    add fsdp_qlora config and bnb_4bit_quant_storage (#1863)

commit c8cef79
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Wed Jul 24 21:06:57 2024 +0200

    arXiv to HF Papers (#1870)

commit 7dcf437
Author: Kashif Rasul <kashif.rasul@gmail.com>
Date:   Wed Jul 24 12:27:50 2024 +0200

    [online-DPO] online dpo cleanups (#1864)

    * online dpo cleanups

    * remove unused self.policy

    * add OnlineDPOTrainer and config to __init__.py

    * import from trainer

    * online dpo test

    * rename policy to model and ref_policy to ref_model

    * renamed internally

    * formatting

commit 4e85bd7
Author: Costa Huang <costa.huang@outlook.com>
Date:   Thu Jul 18 14:35:31 2024 -0400

    Online DPO and Online trainer refactor (#1809)

    * online dpo trainer based on rloo trainer

    * push changes

    * refactor

    * use `batch_generation` method

    * precommit

    * remove breakpoint()

    * quick refactor

    * push the current changes

    * quick change

    * refactor

    * use the config name as the experiment name

    * fix logging

    * update online DPO docs

    * push docs

    * increment global step so tensorboard works again.

    * precommit

    * remove unused common online trainer

    * add online DPO docs

    * quick refactor

    * push changes

    * Update docs/source/online_dpo_trainer.md

    Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

    ---------

    Co-authored-by: Michael Noukhovitch <mnoukhov@gmail.com>
    Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

commit c9d5636
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Thu Jul 18 18:28:49 2024 +0200

    rm token (#1852)

* simplify script

* doc

* use traning args

* args instead of trianing args

* fix doc

* drop eval

* rm eval section

* re-add bigbirg

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
@claralp claralp mentioned this pull request Aug 13, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants