
fix DPO data collator #932

Closed · wants to merge 2 commits

Conversation

@nrailg (Contributor) commented Oct 31, 2023

See [this issue](huggingface#907) for more details.
@younesbelkada (Contributor) left a comment


Looks great, thanks! I left one suggestion that can be propagated across some lines below, what do you think? Also let's wait to hear from @kashif to see if this fix is the right fix

@nrailg (Contributor, Author) commented Nov 1, 2023

> Looks great, thanks! I left one suggestion that can be propagated across some lines below, what do you think? Also let's wait to hear from @kashif to see if this fix is the right fix

Updated.

@nrailg (Contributor, Author) commented Nov 3, 2023

@kashif Hello, I'd be grateful if you could spare some time to review this PR.

@kashif (Collaborator) commented Nov 3, 2023

@nrailg thanks! having a look now... so would a user need to create the DPO type dataset differently due to this?

@nrailg (Contributor, Author) commented Nov 3, 2023

> @nrailg thanks! having a look now... so would a user need to create the DPO type dataset differently due to this?

No. The only thing that needs to change is the prompt: use '\n' instead of ' ' as the separator.
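A minimal sketch of why the separator matters, using a toy BPE-style tokenizer (the merge table below is invented for illustration; real vocabularies like GPT-2's behave analogously): a merge that spans the separator destroys the prompt prefix, which is exactly what the collator's prefix check guards against.

```python
def toy_tokenize(text):
    # Pretend the vocabulary contains the merged token "? A" but not "?\nA".
    merges = {"? A": 100}
    tokens, i = [], 0
    while i < len(text):
        for merge, tid in merges.items():
            if text.startswith(merge, i):
                tokens.append(tid)
                i += len(merge)
                break
        else:
            tokens.append(ord(text[i]))  # fall back to one token per character
            i += 1
    return tokens

prompt = "Q: ok?"
prompt_ids = toy_tokenize(prompt)

# Space separator: the "? " boundary merges with the answer's first character,
# so the prompt tokens no longer survive as a prefix of the concatenation.
space_full = toy_tokenize(prompt + " A")
print(prompt_ids == space_full[: len(prompt_ids)])    # False

# Newline separator: no merge crosses the boundary, the prefix is intact.
newline_full = toy_tokenize(prompt + "\nA")
print(prompt_ids == newline_full[: len(prompt_ids)])  # True
```

With a real tokenizer, which merges fire depends on the vocabulary, which is why '\n' is the safer separator here.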

@kashif (Collaborator) commented Nov 3, 2023

@nrailg yes, this looks great and brings the tokenization more in line with what is done, so thank you! I am currently refactoring the DPO collator, so I would potentially need to add these changes there. If you want, I can add you as a collaborator so you can add it there, since I have made a helper function for the tokenization, for which we can also add tests. What do you think? See #885

```python
prompt_tokens = self.tokenizer(prompt, add_special_tokens=False)
prompt_input_ids = prompt_tokens["input_ids"]
prompt_attention_mask = prompt_tokens["attention_mask"]
assert len(prompt_input_ids) == len(prompt_attention_mask)
```

Suggested change:

```diff
-assert len(prompt_input_ids) == len(prompt_attention_mask)
+if len(prompt_input_ids) != len(prompt_attention_mask):
+    raise ValueError("Prompt input ids and attention mask should have the same length.")
```

```python
if not isinstance(chosen, str):
    raise ValueError(f"chosen should be an str but got {type(chosen)}")
chosen_tokens = self.tokenizer(prompt + chosen, add_special_tokens=False)
assert prompt_input_ids == chosen_tokens["input_ids"][: len(prompt_input_ids)]
```

Suggested change:

```diff
-assert prompt_input_ids == chosen_tokens["input_ids"][: len(prompt_input_ids)]
+if prompt_input_ids != chosen_tokens["input_ids"][: len(prompt_input_ids)]:
+    raise ValueError("Prompt input ids and chosen input ids should be the same up to the prompt.")
```

```python
if not isinstance(rejected, str):
    raise ValueError(f"rejected should be an str but got {type(rejected)}")
rejected_tokens = self.tokenizer(prompt + rejected, add_special_tokens=False)
assert prompt_input_ids == rejected_tokens["input_ids"][: len(prompt_input_ids)]
```

Suggested change:

```diff
-assert prompt_input_ids == rejected_tokens["input_ids"][: len(prompt_input_ids)]
+if prompt_input_ids != rejected_tokens["input_ids"][: len(prompt_input_ids)]:
+    raise ValueError("Prompt input ids and rejected input ids should be the same up to the prompt.")
```
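The three suggested changes above share one pattern: replace `assert` (which is stripped under `python -O`) with an explicit `ValueError`. A consolidated, dependency-free sketch (the helper names are hypothetical, not from this PR):

```python
def check_same_length(prompt_input_ids, prompt_attention_mask):
    # Raise instead of assert so the check also runs with optimizations enabled.
    if len(prompt_input_ids) != len(prompt_attention_mask):
        raise ValueError("Prompt input ids and attention mask should have the same length.")


def check_prompt_prefix(prompt_input_ids, full_input_ids, name):
    # The prompt tokens must survive unchanged as a prefix of the
    # tokenized prompt + answer concatenation.
    if prompt_input_ids != full_input_ids[: len(prompt_input_ids)]:
        raise ValueError(f"Prompt input ids and {name} input ids should be the same up to the prompt.")


# Usage mirroring the three review suggestions:
prompt_ids = [1, 2, 3]
check_same_length(prompt_ids, [1, 1, 1])
check_prompt_prefix(prompt_ids, [1, 2, 3, 9, 9], "chosen")
try:
    check_prompt_prefix(prompt_ids, [1, 2, 9, 9], "rejected")
except ValueError as e:
    print(e)  # Prompt input ids and rejected input ids should be the same up to the prompt.
```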

@kashif (Collaborator) commented Nov 3, 2023

@nrailg are your tests passing? i think for the gpt2 tokenizer i am using the tests will not pass for you

kashif added a commit to kashif/trl that referenced this pull request Nov 3, 2023
@nrailg (Contributor, Author) commented Nov 4, 2023

> @nrailg are your tests passing? i think for the gpt2 tokenizer i am using the tests will not pass for you

  1. I ran the stackllama2 demo and finished training my own demo.
  2. I didn't find any test case related to the DPO data collator by grepping for "DPODataCollator", so I didn't run `make test`.

Could you please tell me which case is breaking? I'll try to fix it.
Sorry, but on our internal GPU cluster it's difficult to run `make test` due to some network-related issues.

@nrailg (Contributor, Author) commented Nov 4, 2023

> @nrailg yes this looks great and brings the tokenization more in line with what is done so thank you! I am currently refactoring the DPO collator so i would potentially need to add these changes to that, if you want, I can add you as a collaborator if you want to add it there since I have made helper function for the tokenization, for which we can also add tests? what do you think? see #885

Glad to help. Give me some time to figure out what you changed.

@nrailg (Contributor, Author) commented Nov 16, 2023

> @nrailg are your tests passing? i think for the gpt2 tokenizer i am using the tests will not pass for you

@kashif Hello, do you still have problems with testing? Sorry, I've been very busy recently and didn't have time to check the progress of this issue.

@kashif (Collaborator) commented Nov 16, 2023

I am still testing on my side, @nrailg, and I am getting an issue where in my eval dataset I get an error in the embedding layer... still trying to figure out why... can you kindly test on your side? we fixed some edge cases not covered by your original PR here too

@nrailg (Contributor, Author) commented Nov 16, 2023

> I am still testing on my side, @nrailg, and I am getting an issue where in my eval dataset I get an error in the embedding layer... still trying to figure out why... can you kindly test on your side? we fixed some edge cases not covered by your original PR here too

Thank you so much for your edge case fixes.

I'll try to figure out how to run `make test` on my company's somewhat unusual GPU cluster.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@kashif kashif added the 🏋 DPO Related to DPO label Dec 2, 2023
kashif added a commit that referenced this pull request Dec 12, 2023
* use logprobs if it exists in the batch

* add features to tokenized batch if in data

* make get_batch_logps a static method

* add tokenize_batch_element dataset mapper

* Remove tokenize_batch method from DPODataCollator

* Initial sketch to precompute reference_logps

* run ref model via pytorch dataloader

* add a padding helper

* clean up the helper

* use logprob item()

* default behaviour

* clean up collator

* add docstring

* copy data back to cpu if needed

* use get_train_dataloader methods

* fix tests

* rename: more explicit variable name precompute_ref_log_probs

* improve comment

* update comment

* Update trl/trainer/dpo_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* refactor models into setup parameters

* parametrize precompute_ref_log_probs flag

* remove useless test

* Update trl/trainer/dpo_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update tests/test_dpo_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update tests/test_dpo_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/dpo_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/dpo_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* update function arg name

* distinguish between pad token_id and mask values

* fix tokenization #932 by @nrailg

* fix test

* undo test refactor

* new line

* undo breaking change

* Update token counter condition to allow Llama tokenizer

* Acount for merged tokens on certain tokenizers such Llama-2 tokenizer

* Update variable name to match list value when truncating response

* map function on multi-gpu and gather

* Add test cases for DPOTrainer tokenization step

* revert since we need the prepeared model

* Use gather_with_metrics on ref_logps precomputation to keep original dataset size

* Add flag to keep track of when ref_logps are precomputed

* make variable names private

* formatting

* if precompute_ref_log_probs is true one can use non-peft to populate log-probs

* Use tokenizer padding token unless padding_value is set

* Move dataset.map(tokenize_batch) outside dataloader to avoid serialization errors

* eval can be none

* move to cpu to avoid gpu oom

* remove unneeded cast to float32

* remove unneeded

* fix merge

* fix merge

* fix merge

* add precompute log-prob status via tqdm

* Truncate answer if too longer once prompt has been truncated

* Add prompt_input_ids to batch to enable generation

* formatting and add lora example

* fix formatting

* Tokenize row now expects sample to have space on chosen/rejected for llama

* Revert "Tokenize row now expects sample to have space on chosen/rejected for llama"

This reverts commit dd07a10.

* raise error when using zero-3 with precompute_ref_log_probs

---------

Co-authored-by: Pablo Vicente Juan <p.vicente.juan@gmail.com>
Co-authored-by: Shoaib Burq <saburq@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
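Several of the commit messages above (making `get_batch_logps` a static method, precomputing reference log-probs) revolve around summing label-token log-probabilities. A dependency-free sketch of that computation, with illustrative shapes and names rather than the exact TRL signature:

```python
import math

def get_batch_logps(logits, labels, loss_mask):
    """Sum log-probabilities of the label tokens over unmasked positions.

    logits: (seq_len, vocab) nested lists; labels, loss_mask: (seq_len,) lists.
    A plain-Python sketch of the tensor version described in the commits above.
    """
    total = 0.0
    for row, label, mask in zip(logits, labels, loss_mask):
        log_z = math.log(sum(math.exp(x) for x in row))  # log-sum-exp normalizer
        total += mask * (row[label] - log_z)             # log-softmax at the label id
    return total

# Uniform logits over a vocab of 4: every label token has log-prob log(1/4),
# so three unmasked positions sum to 3 * log(1/4).
logits = [[0.0] * 4 for _ in range(3)]
print(round(get_batch_logps(logits, [0, 1, 2], [1, 1, 1]), 4))  # -4.1589
```

Precomputing reference log-probs then amounts to running the (frozen) reference model over the dataset once, storing these sums, and reading them back during training instead of a second forward pass.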

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

@github-actions github-actions bot closed this Jan 3, 2024
lapp0 pushed a commit to lapp0/trl that referenced this pull request May 10, 2024