[FEATURE] add lvwerra/trl reward modelling support to the ArgillaTrainer #3377
Labels

- area: trainer (Indicates that an issue or pull request is related to the Argilla Trainer)
- type: enhancement (Indicates new feature requests)

Comments
This was referenced Jul 11, 2023
davidberenstein1957 changed the title from "[FEATURE] add lvwerra/trl support to the ArgillaTrainer" to "[FEATURE] add lvwerra/trl reward modelling support to the ArgillaTrainer" on Jul 11, 2023
davidberenstein1957 added the area: trainer label on Aug 28, 2023
davidberenstein1957 added a commit that referenced this issue on Aug 28, 2023:
…), DPO, rename TrainingTaskMapping (#3467)

Resolves #3379, resolves #3377

Hello!

## Pull Request overview

* Prepare data for SFT, RM and DPO in TRL.
* Rename `TrainingTaskMapping` to `TrainingTask` and `task_mapping` to `task`.

# Description

## Prepare data

```python
from typing import Any, Dict

from argilla.feedback import TrainingTask

def formatting_func(sample: Dict[str, Any]):
    ...  # `template` is a user-defined format string with {prompt} and {response} placeholders
    yield template.format(
        prompt=sample["prompt"],
        response=sample["response"],
    )

task = TrainingTask.for_supervised_fine_tuning(formatting_func=formatting_func)
ds = fds_dataset.prepare_for_training(framework="trl", task=task)
# -> ds has "text" and "id" columns
```

Compatible with [SFTTrainer](https://huggingface.co/docs/trl/main/en/sft_trainer).

```python
task = TrainingTask.for_reward_modelling(chosen_rejected_func=chosen_rejected_func)
ds = fds_dataset.prepare_for_training(framework="trl", task=task)
# -> ds has "chosen" and "rejected" columns
```

Nearly compatible with [RewardTrainer](https://huggingface.co/docs/trl/main/en/reward_trainer).

```python
task = TrainingTask.for_direct_preference_optimization(prompt_chosen_rejected_func=prompt_chosen_rejected_func)
ds = fds_dataset.prepare_for_training(framework="trl", task=task)
# -> ds has "prompt", "chosen" and "rejected" columns
```

Compatible with [DPOTrainer](https://huggingface.co/docs/trl/main/en/dpo_trainer).

### Details

I implement this by calling `dataset.format_as("datasets")` and then passing each sample (a plain dictionary) from this dataset to the function that the user provides. This user-provided function can return `None`, one sample, a list of samples, or yield samples. This allows users to export multiple training samples from a single Argilla record, e.g. when multiple annotators provided useful corrections, or when the annotated record justifies three "chosen"/"rejected" pairs because there is a ranking between three texts (a sketch of such a function follows this commit description).

## Rename

`TrainingTaskMapping` is now `TrainingTask`: the "mapping" part is simply unintuitive to the user. The same applies to renaming `task_mapping` to `task`.

**Note:** code that used `task_mapping=...` before will now fail. I could make this deprecation softer, but then I would have to make `task` optional, which I would rather not do.

## TODO:

- [ ] Add TRL to `ArgillaTrainer`, allowing:

  ```python
  task = TrainingTask.for_supervised_fine_tuning(
      formatting_func=formatting_func
  )  # or any other task from this PR

  trainer = ArgillaTrainer(
      dataset=fds_dataset,
      task=task,
      framework="trl",
  )
  trainer.train()
  ```

- [ ] Consider renaming `FeedbackDataset.prepare_for_training` to `FeedbackDataset.export`.
- [ ] New tests
- [ ] Add documentation

**Type of change**

- [x] New feature

**How Has This Been Tested**

Not finished yet.

**Checklist**

- [ ] I added relevant documentation
- [x] My code follows the style guidelines of this project
- [x] I did a self-review of my code
- [ ] I made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my feature works
- [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK) (see text above)
- [ ] I have added relevant notes to the CHANGELOG.md file (see https://keepachangelog.com/)

---

- Tom Aarsen

---------

Co-authored-by: Alvaro Bartolome <alvarobartt@gmail.com>
Co-authored-by: Alvaro Bartolome <alvaro@argilla.io>
Co-authored-by: David Berenstein <david.m.berenstein@gmail.com>
Co-authored-by: Daniel Vila Suero <daniel@argilla.io>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
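For illustration, here is a minimal sketch of such a user-provided function, not part of the PR itself: the field name `response_ranking` and the pairing logic are assumptions. It yields several chosen/rejected pairs from one record carrying a three-way ranking:

```python
from typing import Any, Dict, Iterator, Tuple

def ranking_to_pairs(sample: Dict[str, Any]) -> Iterator[Tuple[str, str]]:
    # Hypothetical: assume the record stores its ranked responses best-first.
    ranked = sample["response_ranking"]
    # A ranking over three texts justifies three chosen/rejected pairs:
    # (1st, 2nd), (1st, 3rd) and (2nd, 3rd).
    for i, chosen in enumerate(ranked):
        for rejected in ranked[i + 1:]:
            yield chosen, rejected
```

Because the exporter accepts generator functions, each yielded pair would become one row of the resulting "chosen"/"rejected" dataset.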
**Is your feature request related to a problem? Please describe.**

I cannot do reward modelling with the `RankingQuestion`.
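For context, a `RankingQuestion` might be defined as follows (a minimal sketch, assuming the argilla feedback API of the time; all names are illustrative):

```python
import argilla as rg

# A FeedbackDataset whose annotators rank two candidate responses per prompt.
dataset = rg.FeedbackDataset(
    fields=[rg.TextField(name="prompt")],
    questions=[
        rg.RankingQuestion(
            name="response_ranking",
            title="Rank the responses from best to worst",
            values=["response-1", "response-2"],
        )
    ],
)
```

There is currently no path from these ranked annotations to the chosen/rejected pairs that reward modelling needs.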
**Describe the solution you'd like**

We should add support for https://github.com/lvwerra/trl. This support should include:

- a `.for_reward_modelling()` class-method for the `TrainingTaskMapping`, with a `RankingQuestionStrategy`;
- support for `trl` in the `Framework(Enum)` used by the `prepare_for_training()` method of the `FeedbackDataset`;
- preparation of the data in the chosen-rejected format (see the sketch after this list).
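A hedged sketch of how the requested API could look; the signatures below are hypothetical, and the PR referenced above ultimately renamed `TrainingTaskMapping` to `TrainingTask` and settled on callback-based tasks instead:

```python
from argilla.feedback import FeedbackDataset, TrainingTaskMapping

dataset = FeedbackDataset.from_argilla("my-dataset", workspace="my-workspace")

# Hypothetical: derive a reward-modelling task from the RankingQuestion,
# treating higher-ranked responses as "chosen" and lower-ranked as "rejected".
task_mapping = TrainingTaskMapping.for_reward_modelling(
    ranking_question="response_ranking",  # assumed parameter name
)

ds = dataset.prepare_for_training(framework="trl", task_mapping=task_mapping)
# -> expected: a dataset with "chosen" and "rejected" columns for TRL's RewardTrainer
```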
**Describe alternatives you've considered**

Consider taking carperAI/trlx into account too. #3324

**Additional context**

N.A.