
[FEATURE] add lvwerra/trl reward modelling support to the ArgillaTrainer #3377

Closed
davidberenstein1957 opened this issue Jul 11, 2023 · 0 comments · Fixed by #3467
Labels: area: trainer (Indicates that an issue or pull request is related to the Argilla Trainer), type: enhancement (Indicates new feature requests)
Milestone: 1.15.0


@davidberenstein1957 (Member)

Is your feature request related to a problem? Please describe.
I cannot do reward modelling with the RankingQuestion.

Describe the solution you'd like
We should add support for https://github.com/lvwerra/trl.
This support should include:

  • a .for_reward_modelling() class method on the TrainingTaskMapping
  • top-n (n=2) unification support for the RankingQuestionStrategy
  • support in the prepare_for_training() method of the FeedbackDataset
  • support for the FrameWork(Enum) via trl
  • alignment with our docs and the chosen/rejected convention

Describe alternatives you've considered
Consider taking CarperAI/trlx into account too; see #3324.

Additional context
N.A.

@davidberenstein1957 davidberenstein1957 added the type: enhancement Indicates new feature requests label Jul 11, 2023
@davidberenstein1957 davidberenstein1957 changed the title [FEATURE] add lvwerra/trl support to the ArgillaTrainer [FEATURE] add lvwerra/trl reward modelling support to the ArgillaTrainer Jul 11, 2023
@tomaarsen tomaarsen self-assigned this Jul 27, 2023
@davidberenstein1957 davidberenstein1957 added this to the 1.15.0 milestone Aug 22, 2023
@davidberenstein1957 davidberenstein1957 added the area: trainer Indicates that an issue or pull request is related to the Argilla Trainer label Aug 28, 2023
davidberenstein1957 added a commit that referenced this issue Aug 28, 2023
…), DPO, rename TrainingTaskMapping (#3467)

Resolves #3379, resolves #3377

Hello!

## Pull Request overview
* Prepare data for SFT, RM, DPO in TRL.
* Rename `TrainingTaskMapping` to `TrainingTask` and `task_mapping` to
`task`.

# Description
## Prepare data
```python
from typing import Any, Dict

from argilla.feedback import TrainingTask

def formatting_func(sample: Dict[str, Any]):
    ...
    yield template.format(
        prompt=sample["prompt"],
        response=sample["response"],
    )

task = TrainingTask.for_supervised_fine_tuning(formatting_func=formatting_func)
ds = fds_dataset.prepare_for_training(framework="trl", task=task)
# -> ds has "text" and "id" columns
```
Compatible with
[SFTTrainer](https://huggingface.co/docs/trl/main/en/sft_trainer).
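For reference, a self-contained sketch of what a complete `formatting_func` might look like. The `TEMPLATE` string and the `"prompt"`/`"response"` field names are assumptions for illustration, not the actual schema from this PR:

```python
from typing import Any, Dict, Iterator

# Hypothetical instruction-tuning template; not part of the PR.
TEMPLATE = "### Instruction:\n{prompt}\n\n### Response:\n{response}"

def formatting_func(sample: Dict[str, Any]) -> Iterator[str]:
    # Skip records that were never annotated with a response.
    if not sample.get("response"):
        return
    yield TEMPLATE.format(prompt=sample["prompt"], response=sample["response"])
```

Because the function is a generator, returning early simply yields no training samples for that record.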

```python
task = TrainingTask.for_reward_modelling(chosen_rejected_func=chosen_rejected_func)
ds = fds_dataset.prepare_for_training(framework="trl", task=task)
# -> ds has "chosen" and "rejected" columns
```
Nearly compatible with
[RewardTrainer](https://huggingface.co/docs/trl/main/en/reward_trainer).
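A minimal sketch of what a `chosen_rejected_func` might look like. The field names (`"response-1"`, `"response-2"`) and the ranking representation are assumptions, not the API contract from this PR:

```python
from typing import Any, Dict, Optional, Tuple

def chosen_rejected_func(sample: Dict[str, Any]) -> Optional[Tuple[str, str]]:
    # Assume the ranking annotation lists response field names best-first,
    # e.g. ["response-2", "response-1"].
    ranking = sample.get("ranking")
    if not ranking or len(ranking) < 2:
        return None  # unannotated or incomplete record: skip it
    return sample[ranking[0]], sample[ranking[1]]
```

Returning `None` lets the exporter drop records that cannot produce a preference pair.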

```python
task = TrainingTask.for_direct_preference_optimization(prompt_chosen_rejected_func=prompt_chosen_rejected_func)
ds = fds_dataset.prepare_for_training(framework="trl", task=task)
# -> ds has "prompt", "chosen" and "rejected" columns
```
Compatible with
[DPOTrainer](https://huggingface.co/docs/trl/main/en/dpo_trainer).
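Analogously, a hedged sketch of a `prompt_chosen_rejected_func`; DPO additionally needs the prompt the two responses were written for. Field names are again illustrative:

```python
from typing import Any, Dict, Optional, Tuple

def prompt_chosen_rejected_func(
    sample: Dict[str, Any]
) -> Optional[Tuple[str, str, str]]:
    # Same assumed ranking representation as in the reward-modelling sketch.
    ranking = sample.get("ranking")
    if not ranking or len(ranking) < 2:
        return None
    return sample["prompt"], sample[ranking[0]], sample[ranking[1]]
```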

### Details
I implement this by calling `dataset.format_as("datasets")` and then passing each sample (a simple dictionary) from this dataset to the function that the user provides. This user-provided function can return `None`, one sample, a list of samples, or yield samples. This allows users to export multiple training samples from a single Argilla record, e.g. when multiple annotators provided useful corrections, or when a single annotated record justifies three "chosen"/"rejected" pairs because there is a ranking between three texts.
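The ranking-between-three-texts case can be sketched as follows: every better/worse pair of ranked responses yields one ("chosen", "rejected") sample, so a ranking of three texts produces three pairs. The `"ranking"` field and its best-first representation are assumptions for illustration:

```python
from itertools import combinations
from typing import Any, Dict, Iterator, Tuple

def pairs_from_ranking(sample: Dict[str, Any]) -> Iterator[Tuple[str, str]]:
    # Assume "ranking" lists response field names best-first.
    ranking = sample.get("ranking") or []
    # Every better/worse pair becomes one (chosen, rejected) training sample.
    for better, worse in combinations(ranking, 2):
        yield sample[better], sample[worse]
```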

## Rename
`TrainingTaskMapping` is now `TrainingTask`: the "mapping" part is unintuitive to users. The same goes for renaming `task_mapping` to `task`. **Note:** if people used `task_mapping=...` before, that will now fail. I could make this deprecation softer, but then I would have to make `task` optional, which I would rather not do.

## TODO:
- [ ] Add TRL to `ArgillaTrainer`, allowing:
   ```python
   task = TrainingTask.for_supervised_fine_tuning(
       formatting_func=formatting_func
   ) # or any other task from this PR
   trainer = ArgillaTrainer(
       dataset=fds_dataset,
       task=task,
       framework="trl",
   )
   trainer.train()
   ```
- [ ] Consider renaming `FeedbackDataset.prepare_for_training` to
`FeedbackDataset.export`.
- [ ] New tests
- [ ] Add documentation

**Type of change**

- [x] New feature

**How Has This Been Tested**

Not finished yet.

**Checklist**

- [ ] I added relevant documentation
- [x] My code follows the style guidelines of this project
- [x] I did a self-review of my code
- [ ] I made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK)
(see text above)
- [ ] I have added relevant notes to the CHANGELOG.md file (See
https://keepachangelog.com/)

---

- Tom Aarsen

---------

Co-authored-by: Alvaro Bartolome <alvarobartt@gmail.com>
Co-authored-by: Alvaro Bartolome <alvaro@argilla.io>
Co-authored-by: David Berenstein <david.m.berenstein@gmail.com>
Co-authored-by: Daniel Vila Suero <daniel@argilla.io>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>