Batched generation #228
Conversation
Added some tests and ran the following benchmark (cc @edbeeching @natolambert) on the generation part:
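A minimal sketch of the kind of timing harness such a benchmark might use (function and variable names are illustrative, not the PR's actual script):

```python
import time

def time_generation(ppo_trainer, query_tensors, batch_size=None, **gen_kwargs):
    """Time one-by-one generation (batch_size=None) vs. a single batched call."""
    start = time.perf_counter()
    if batch_size is None:
        # pre-PR pattern: one generate() call per query
        _ = [ppo_trainer.generate(q, **gen_kwargs) for q in query_tensors]
    else:
        # post-PR pattern: one batched call over all queries
        _ = ppo_trainer.generate(query_tensors, batch_size=batch_size, **gen_kwargs)
    return time.perf_counter() - start
```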
Awesome speedup! Looks great in general 🚀
I left some comments; I think we need a safety checker to verify that `batch_size` is effectively smaller than `len(queries)`.
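A minimal sketch of such a check (the helper name and its placement are hypothetical; only `batch_size` and `queries` come from the discussion above):

```python
def _check_batch_size(batch_size: int, queries: list) -> int:
    """Hypothetical guard: never request a batch larger than the number of queries."""
    if batch_size <= 0:
        raise ValueError("`batch_size` must be a positive integer.")
    return min(batch_size, len(queries))  # clamp so small query lists still work
```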
trl/trainer/ppo_trainer.py (Outdated)

```python
if not return_prompt:
    output = output[(mask).sum() :]  # remove prompt
```
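For context, a toy illustration (not the PR's code) of what this slice does: the attention mask sums to the prompt length, so slicing at `mask.sum()` drops the prompt tokens.

```python
import torch

mask = torch.tensor([1, 1, 1])                # 3 prompt tokens
output = torch.tensor([10, 11, 12, 20, 21])   # prompt followed by 2 generated tokens
response = output[mask.sum():]                # tensor([20, 21]), prompt removed
```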
Why is this needed? I.e., in which case do we need to return the prompt?
We already have the query, and in all examples we spend 1-2 lines removing the queries from the generations. With this it's done automatically :)
Ahh, good point, ok. You mean here, right?
In this case, for API consistency, I think we also need to add it in the block that does not call batched generation.
Also note that for seq2seq models, the model already returns the response without the query: https://github.com/lvwerra/trl/blob/0610711ddab3ba1d8b5d41d31423c213b433472e/examples/sentiment/scripts/t5-sentiment.py#L153 so it might be worth adding another safety checker: if `self.is_encoder_decoder`, ignore `return_prompt`.
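A sketch of what that extra clause could look like inside the generation helper (a suggestion under the assumptions above, not the merged code; `self.is_encoder_decoder`, `return_prompt`, `output`, and `mask` are the names from the diff and discussion):

```python
if not self.is_encoder_decoder and not return_prompt:
    # decoder-only models prepend the query, so strip it off
    output = output[(mask).sum():]
# encoder-decoder models already return only the response,
# so `return_prompt` is effectively ignored for them
```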
> In this case, for API consistency, I think we also need to add it in the block that does not call batched generation.

We do it here, no? https://github.com/lvwerra/trl/blob/ee04bada9d4607c8273b41119d9adde98c7c9528/trl/trainer/ppo_trainer.py#L398
Good point about enc-dec, will add an extra clause.
Ah yes, thanks, indeed we do it there.
Addressed the comments @younesbelkada!
Awesome work! Thanks a lot for this!
This enables generating in batches with a custom batch size, rather than one by one. This simplifies the generation code and makes generation significantly faster! The changes are backwards compatible.
Code before:
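A representative sketch of the pre-PR pattern (variable names are illustrative; per the discussion above, the output included the query, so the prompt had to be stripped by hand):

```python
# Before: one generate() call per query, slicing the prompt off manually.
response_tensors = []
for query in query_tensors:
    output = ppo_trainer.generate(query)
    response_tensors.append(output.squeeze()[len(query):])  # drop prompt tokens
```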
Code after:
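And a sketch of the post-PR usage, using the `batch_size` and `return_prompt` arguments discussed in this thread (the specific values shown are assumptions):

```python
# After: a single batched call; prompts are removed automatically.
response_tensors = ppo_trainer.generate(
    query_tensors,        # list of query tensors
    batch_size=16,        # generate in batches of a custom size
    return_prompt=False,  # responses come back without the query
)
```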
Todo: Add some tests.