
Batched generation #228

Merged: 10 commits, merged into main on Mar 21, 2023

Conversation

lvwerra (Member) commented Mar 17, 2023

This enables generating in batches with a custom batch size rather than one query at a time. It simplifies the generation code and makes generation significantly faster! The changes are backwards compatible.

Code before:

# One-by-one generation: sample a length per query, generate, and strip the prompt manually.
response_tensors = []
for question in tqdm(question_tensors):
    gen_len = output_length_sampler()
    generation_kwargs["max_new_tokens"] = gen_len
    response = ppo_trainer.generate(question, **generation_kwargs)
    response_tensors.append(response.squeeze()[-gen_len:])
batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors]

Code after:

# Batched generation: prompt removal and length sampling are handled internally.
response_tensors = ppo_trainer.generate(
    question_tensors,
    return_prompt=False,
    length_sampler=output_length_sampler,
    **generation_kwargs,
)
batch["response"] = tokenizer.batch_decode(response_tensors)

Todo: Add some tests.

HuggingFaceDocBuilderDev commented Mar 17, 2023

The documentation is not available anymore as the PR was closed or merged.

lvwerra (Member, Author) commented Mar 21, 2023

Added some tests and ran the following benchmark (cc @edbeeching @natolambert) on the generation part:

| Model    | New tokens | Batch size | Samples | Time   |
|----------|------------|------------|---------|--------|
| GPT-2    | 64         | 1          | 256     | 176.08 |
| GPT-2    | 64         | 2          | 256     | 91.54  |
| GPT-2    | 64         | 4          | 256     | 37.29  |
| GPT-2    | 64         | 8          | 256     | 23.93  |
| GPT-2    | 64         | 16         | 256     | 11.58  |
| GPT-2 XL | 64         | 1          | 256     | 568.64 |
| GPT-2 XL | 64         | 2          | 256     | 379.76 |
| GPT-2 XL | 64         | 4          | 256     | 234.63 |
| GPT-2 XL | 64         | 8          | 256     | 176.02 |
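For reference, here is a rough sketch of how such a timing comparison could be reproduced. It is not the script used for the numbers above; it assumes ppo_trainer, question_tensors, output_length_sampler, and generation_kwargs are set up as in the PR description, and that generate exposes a batch_size argument as discussed in the review below.

# Hypothetical timing loop, not the actual benchmark script.
import time

for batch_size in [1, 2, 4, 8, 16]:
    start = time.time()
    responses = ppo_trainer.generate(
        question_tensors,
        return_prompt=False,
        length_sampler=output_length_sampler,
        batch_size=batch_size,  # assumed keyword, see review discussion
        **generation_kwargs,
    )
    print(f"batch_size={batch_size}: {time.time() - start:.2f}s for {len(responses)} responses")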

younesbelkada (Contributor) left a comment:

Awesome speedup! Looks great in general 🚀
I left some comments. I think we need a safety check to verify that batch_size is effectively smaller than len(queries).
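A minimal sketch of the kind of guard being suggested follows; it is hypothetical and not the merged code, and run_generation is a made-up stand-in for the per-batch model call.

# Hypothetical guard: clamp batch_size so it never exceeds the number of queries.
def generate_batched(queries, batch_size, **generation_kwargs):
    batch_size = min(batch_size, len(queries))  # suggested safety check
    outputs = []
    for i in range(0, len(queries), batch_size):
        minibatch = queries[i : i + batch_size]
        # run_generation stands in for the underlying model.generate call
        outputs.extend(run_generation(minibatch, **generation_kwargs))
    return outputs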

Comment on lines 436 to 437:

    if not return_prompt:
        output = output[(mask).sum() :]  # remove prompt
younesbelkada (Contributor):

Why is this needed? I.e., in which case do we need to return the prompt?

lvwerra (Member, Author):

We already have the query, and in all the examples we spend 1-2 lines removing the queries from the generations. With this it's done automatically :)

younesbelkada (Contributor) commented Mar 21, 2023:

Ah, good point, you mean here, right?
In that case, for API consistency, I think we also need to add it in the block that does not call batched generation.
Also note that for seq2seq models, the model already returns the response without the query: https://github.com/lvwerra/trl/blob/0610711ddab3ba1d8b5d41d31423c213b433472e/examples/sentiment/scripts/t5-sentiment.py#L153, so it might be worth adding another safety check: if self.is_encoder_decoder, ignore return_prompt.

lvwerra (Member, Author):

> In that case, for API consistency, I think we also need to add it in the block that does not call batched generation.

We do here, no?: https://github.com/lvwerra/trl/blob/ee04bada9d4607c8273b41119d9adde98c7c9528/trl/trainer/ppo_trainer.py#L398

Good point about enc-dec, will add an extra clause.
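Roughly, the extra clause mentioned here could look like the following; this is a sketch of the idea, not the merged diff.

# Sketch: skip prompt removal for encoder-decoder models, which already
# return the response without the query.
if not return_prompt and not self.is_encoder_decoder:
    output = output[(mask).sum() :]  # remove prompt (decoder-only models)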

younesbelkada (Contributor):

Ah yes, thanks, indeed we do.

lvwerra (Member, Author) commented Mar 21, 2023

Addressed the comments @younesbelkada!

younesbelkada (Contributor) left a comment:

Awesome work! Thanks a lot for this!

lvwerra merged commit 9c3e9e4 into main on Mar 21, 2023
lvwerra deleted the batched-generation branch on March 21, 2023 at 15:48