Add CPOTrainer #1382
Conversation
@fe1ixxu how close is the trainer in terms of code to the DPOTrainer? Can one subclass from it?
@kashif Thanks for the quick response! CPO is an approximation of DPO. The key differences between CPOTrainer and DPOTrainer are:
I'm uncertain whether subclassing CPOTrainer from DPOTrainer is a good idea, as DPOTrainer introduces numerous features related to reference models that are unnecessary for CPOTrainer.
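For context, here is a minimal sketch of the CPO loss as described in the paper: a DPO-style preference term computed from the policy's log-probabilities alone (no reference model), plus an NLL regularizer on the chosen responses. The function name and the `nll_weight` argument are illustrative only, not the PR's actual API.

```python
import torch.nn.functional as F

def cpo_loss(chosen_logps, rejected_logps, beta=0.1, nll_weight=1.0):
    """Illustrative CPO loss on per-sequence policy log-probabilities (tensors of shape [batch])."""
    # DPO-like preference term, but using only the policy log-probs (no reference model)
    preference_loss = -F.logsigmoid(beta * (chosen_logps - rejected_logps))
    # Behaviour-cloning / NLL regularizer on the chosen (preferred) responses
    nll_loss = -chosen_logps
    return (preference_loss + nll_weight * nll_loss).mean()
```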
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Hi @kashif, the CPO docs are finished now! Thanks!
Thank you for this very nice implementation of CPO @fe1ixxu 🔥 ! I left a few small comments and a suggestion to remove a deepspeed function I don't think we need. Apart from that LGTM!
Because CPO does not need to initialize a reference model.
* add CPOTrainer
* add docs
* fix formatting
* removed precompute_ref_log_probs arg
* remove precompute_ref_log_probs
* typos
* finish cpo trainer doc
* remove redundant lines
* typo
* formatting
* compute chosen nll loss also for enc-dec models
* fix gradient error of inplace operation for enc-dec models
* formatting
* use CPOConfig
* formatting
* use model_init_kwargs from CPOConfig
* comments in example
* fix doc string
* fix typo in docstring
* update year
* fixed typo
* use preference dataset
* fix learning rate
* move dataset_num_proc to configs
* Update cpo paper link from HF: cpo_trainer.mdx (Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>)
* update description for CPO: cpo_trainer.mdx (Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>)
* remove _prepare_deepspeed for cpo, because CPO does not need init for reference model
* Add explanation to CPO loss
* format
* fix bug when lengths are given
* add CPOTrainer to README
* fix grammar

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Hi! This PR adds the CPOTrainer proposed in the paper Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation.
The CPO method is one of the algorithms used to build the state-of-the-art LLM-based translation model ALMA.
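For readers who want to try the trainer once this is merged, a rough usage sketch is below. The dataset columns follow the standard TRL preference format (prompt/chosen/rejected); the specific CPOConfig fields and values shown here (e.g. beta, max_length) are assumptions for illustration, so check the PR docs for the final API.

```python
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import CPOConfig, CPOTrainer

model_name = "gpt2"  # small placeholder model, for illustration only
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

# Tiny preference dataset with the prompt/chosen/rejected columns a preference trainer expects
train_dataset = Dataset.from_dict({
    "prompt": ["Translate to German: Hello, world!"],
    "chosen": ["Hallo, Welt!"],
    "rejected": ["Hello, world!"],
})

# Field names and values here are assumptions about CPOConfig
training_args = CPOConfig(
    output_dir="cpo-example",
    beta=0.1,
    max_length=128,
    max_prompt_length=64,
    per_device_train_batch_size=1,
)

trainer = CPOTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
)
trainer.train()
```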