fix DPO data collator #932
Conversation
See [this issue](huggingface#907) for more details.
Looks great, thanks! I left one suggestion that can be propagated across some lines below; what do you think? Also, let's wait to hear from @kashif to see whether this fix is the right one.
Updated.
@kashif Hello, I'd be grateful if you could spare some time to review this PR.
@nrailg thanks! Having a look now... So would a user need to create the DPO-type dataset differently because of this?
No. The only thing that needs to change is the prompt: use '\n' instead of ' ' as the separator.
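For concreteness, here is a minimal sketch of what that means when building the dataset. The template and field names are illustrative assumptions, not a required format; only the trailing '\n' separator is the point:

```python
# Hypothetical example: end the prompt with "\n" rather than " ",
# so the tokenizer does not merge the first response token across
# the prompt/response boundary.
def build_dpo_sample(instruction: str, chosen: str, rejected: str) -> dict:
    prompt = f"### Instruction:\n{instruction}\n### Response:\n"  # "\n" separator
    return {"prompt": prompt, "chosen": chosen, "rejected": rejected}

sample = build_dpo_sample("Name a primary color.", "Blue.", "Loud.")
```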
@nrailg yes, this looks great and brings the tokenization more in line with what is done elsewhere, so thank you! I am currently refactoring the DPO collator, so I would potentially need to fold these changes into that. If you want, I can add you as a collaborator so you can add them there, since I have made a helper function for the tokenization, for which we can also add tests. What do you think? See #885
```python
prompt_tokens = self.tokenizer(prompt, add_special_tokens=False)
prompt_input_ids = prompt_tokens["input_ids"]
prompt_attention_mask = prompt_tokens["attention_mask"]
assert len(prompt_input_ids) == len(prompt_attention_mask)
```
Suggested change:

```diff
-assert len(prompt_input_ids) == len(prompt_attention_mask)
+if len(prompt_input_ids) != len(prompt_attention_mask):
+    raise ValueError("Prompt input ids and attention mask should have the same length.")
```
```python
if not isinstance(chosen, str):
    raise ValueError(f"chosen should be an str but got {type(chosen)}")
chosen_tokens = self.tokenizer(prompt + chosen, add_special_tokens=False)
assert prompt_input_ids == chosen_tokens["input_ids"][: len(prompt_input_ids)]
```
Suggested change:

```diff
-assert prompt_input_ids == chosen_tokens["input_ids"][: len(prompt_input_ids)]
+if prompt_input_ids != chosen_tokens["input_ids"][: len(prompt_input_ids)]:
+    raise ValueError("Prompt input ids and chosen input ids should be the same up to the prompt.")
```
```python
if not isinstance(rejected, str):
    raise ValueError(f"rejected should be an str but got {type(rejected)}")
rejected_tokens = self.tokenizer(prompt + rejected, add_special_tokens=False)
assert prompt_input_ids == rejected_tokens["input_ids"][: len(prompt_input_ids)]
```
Suggested change:

```diff
-assert prompt_input_ids == rejected_tokens["input_ids"][: len(prompt_input_ids)]
+if prompt_input_ids != rejected_tokens["input_ids"][: len(prompt_input_ids)]:
+    raise ValueError("Prompt input ids and rejected input ids should be the same up to the prompt.")
```
@nrailg are your tests passing? I think for the GPT-2 tokenizer I am using, the tests will not pass for you.
Could you please tell me which case is breaking? I'll try to fix it.
Glad to help. Give me some time to figure out what you changed.
I am still testing on my side, @nrailg, and I am getting an issue where in my eval dataset I get an error in the embedding layer... still trying to figure out why... Can you kindly test on your side? We fixed some edge cases not covered by your original PR here too.
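One way to see the kind of edge case being discussed is the hedged illustration below: whether the prefix check holds depends on the tokenizer, and SentencePiece-based tokenizers such as Llama's are known to merge tokens across the prompt/response boundary. The prompt and response strings here are made up for demonstration:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

prompt = "Question:"
response = " Answer"  # leading space: some tokenizers merge across the boundary

prompt_ids = tokenizer(prompt, add_special_tokens=False)["input_ids"]
full_ids = tokenizer(prompt + response, add_special_tokens=False)["input_ids"]

# If this prints False, the collator's prefix assertion would trip for
# this tokenizer/prompt combination.
print(full_ids[: len(prompt_ids)] == prompt_ids)
```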
Thank you so much for your edge case fixes. I'll try to figure out how to run them.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
* use logprobs if it exists in the batch
* add features to tokenized batch if in data
* make get_batch_logps a static method
* add tokenize_batch_element dataset mapper
* Remove tokenize_batch method from DPODataCollator
* Initial sketch to precompute reference_logps
* run ref model via pytorch dataloader
* add a padding helper
* clean up the helper
* use logprob item()
* default behaviour
* clean up collator
* add docstring
* copy data back to cpu if needed
* use get_train_dataloader methods
* fix tests
* rename: more explicit variable name precompute_ref_log_probs
* improve comment
* update comment
* Update trl/trainer/dpo_trainer.py (Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>)
* refactor models into setup parameters
* parametrize precompute_ref_log_probs flag
* remove useless test
* Update trl/trainer/dpo_trainer.py (Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>)
* Update tests/test_dpo_trainer.py (Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>)
* Update tests/test_dpo_trainer.py (Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>)
* Update trl/trainer/dpo_trainer.py (Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>)
* Update trl/trainer/dpo_trainer.py (Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>)
* update function arg name
* distinguish between pad token_id and mask values
* fix tokenization #932 by @nrailg
* fix test
* undo test refactor
* new line
* undo breaking change
* Update token counter condition to allow Llama tokenizer
* Account for merged tokens on certain tokenizers such as the Llama-2 tokenizer
* Update variable name to match list value when truncating response
* map function on multi-gpu and gather
* Add test cases for DPOTrainer tokenization step
* revert since we need the prepared model
* Use gather_with_metrics on ref_logps precomputation to keep original dataset size
* Add flag to keep track of when ref_logps are precomputed
* make variable names private
* formatting
* if precompute_ref_log_probs is true one can use non-peft to populate log-probs
* Use tokenizer padding token unless padding_value is set
* Move dataset.map(tokenize_batch) outside dataloader to avoid serialization errors
* eval can be none
* move to cpu to avoid gpu oom
* remove unneeded cast to float32
* remove unneeded
* fix merge
* fix merge
* fix merge
* add precompute log-prob status via tqdm
* Truncate answer if too long once prompt has been truncated
* Add prompt_input_ids to batch to enable generation
* formatting and add lora example
* fix formatting
* Tokenize row now expects sample to have space on chosen/rejected for llama
* Revert "Tokenize row now expects sample to have space on chosen/rejected for llama" (This reverts commit dd07a10.)
* raise error when using zero-3 with precompute_ref_log_probs

Co-authored-by: Pablo Vicente Juan <p.vicente.juan@gmail.com>
Co-authored-by: Shoaib Burq <saburq@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
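Several of the commits above revolve around precomputing the frozen reference model's log-probs once, via an ordinary PyTorch dataloader, instead of running the reference model at every training step. The sketch below is a rough, hedged rendering of that idea; the function name and batch layout are illustrative assumptions, not the trainer's real API:

```python
import torch

@torch.no_grad()
def precompute_ref_log_probs(ref_model, dataloader, device="cuda"):
    # Run the frozen reference model once over the dataset and cache a
    # per-example sequence log-prob; training steps then skip this pass.
    ref_model.eval()
    cached = []
    for batch in dataloader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        logits = ref_model(input_ids=input_ids, attention_mask=attention_mask).logits
        log_probs = torch.log_softmax(logits[:, :-1], dim=-1)
        # Log-prob of each realised next token, masked to real (non-pad) positions.
        token_logps = torch.gather(
            log_probs, 2, input_ids[:, 1:].unsqueeze(-1)
        ).squeeze(-1)
        token_logps = token_logps * attention_mask[:, 1:]
        cached.extend(token_logps.sum(dim=-1).cpu().tolist())  # to CPU to avoid GPU OOM
    return cached
```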
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Fixes the problems described in issue #907.
Testing code: https://github.com/nrailg/trl-dpo-alpaca-farm-demo
@kashif @lvwerra