
Integrate OrpoTrainer with PyTorchXLA for faster step time on TPUs #2001

Merged: 17 commits into huggingface:main, Sep 11, 2024

Conversation

wenxindongwork (Contributor):

OrpoTrainer currently runs very slowly on TPU because the code is not integrated with PyTorch/XLA. There are too many dynamic shapes and device data transfers in the code, which trigger graph recompilation and slow down step time. This PR improves the step time of OrpoTrainer on TPUs by more than 300x. Tested on Llama3-8b with LoRA on all linear modules, the step time is now 2 s, compared with the 10 min we started with.

The changes should not impact performance on other backends, since we have guarded them with is_torch_xla_available().
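For a rough feel for the guarding pattern described above (an illustrative sketch, not the actual diff; the flag name is made up, while is_torch_xla_available is the real transformers utility):

    from transformers.utils import is_torch_xla_available

    # Decide once whether to use static, XLA-friendly shapes. On GPU/CPU this
    # flag is False, so the original dynamic-shape behavior is kept unchanged.
    use_static_shapes = is_torch_xla_available()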

pad_value = self.padding_value
elif k.endswith("_attention_mask"):
pad_value = 0
batch[k] = pad_list_to_length(batch[k], self.max_length, pad_value=pad_value)
@qgallouedec (Member), Sep 3, 2024:

Suggested change:
-batch[k] = pad_list_to_length(batch[k], self.max_length, pad_value=pad_value)
+batch[k] = batch[k] + [pad_value] * (self.max_length - len(batch[k]))

way faster and does not require a new helper func
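As a toy illustration of the suggested one-liner (values made up, not from the PR):

    batch_k = [101, 2023, 102]          # toy token ids
    max_length, pad_value = 6, 0
    padded = batch_k + [pad_value] * (max_length - len(batch_k))
    # padded == [101, 2023, 102, 0, 0, 0]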

wenxindongwork (Contributor, Author):

thanks for the suggestion!

@@ -533,7 +536,17 @@ def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module
batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
labels=torch.tensor(batch["chosen_labels"])
)


if is_torch_xla_available():
Reviewer (Member):

Why do you need this only when is_torch_xla_available?

wenxindongwork (Contributor, Author):

PyTorch XLA doesn't support dynamic shape compilation, so we pad all sequences to the global maximum length (self.max_length). Dynamic shapes are not a problem on GPU, so I kept the original algorithm there, which pads to the longest sequence length in the batch.
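A minimal sketch of the two padding strategies being contrasted here (function and variable names are illustrative, not the trainer's actual code):

    import torch

    def collate(sequences, max_length, pad_value, use_static_shapes):
        if use_static_shapes:
            # XLA/TPU: pad every example to the same global max_length so that
            # tensor shapes are identical on every step and the compiled graph
            # is reused instead of being rebuilt.
            padded = [s + [pad_value] * (max_length - len(s)) for s in sequences]
        else:
            # GPU/CPU: dynamic shapes are fine, so pad only to the longest
            # sequence in this batch to save memory and compute.
            longest = max(len(s) for s in sequences)
            padded = [s + [pad_value] * (longest - len(s)) for s in sequences]
        return torch.tensor(padded)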

@@ -35,7 +35,7 @@
from transformers import AutoModelForCausalLM, DataCollator, PreTrainedModel, PreTrainedTokenizerBase, Trainer
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalLoopOutput
-from transformers.utils import is_torch_fx_proxy
+from transformers.utils import is_torch_fx_proxy, is_torch_xla_available
Reviewer (Member):

This has been added with transformers v4.39. We should probably set this version as the new minimal requirement.

wenxindongwork (Contributor, Author):

thanks for catching this!

@@ -659,7 +672,7 @@ def get_batch_logps(
loss_mask = labels != label_pad_token_id

# dummy token; we'll ignore the losses on these tokens later
-labels[labels == label_pad_token_id] = 0
+labels = torch.where(labels == label_pad_token_id, 0, labels)
Reviewer (Member):

Is this change necessary? Personal opinion: I find it a bit less intuitive to read.

wenxindongwork (Contributor, Author):

This is necessary: the previous code calls torch.nonzero under the hood, which produces a dynamically shaped result and triggers graph recompilation.
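A small self-contained sketch of the difference (toy tensors; a label_pad_token_id of -100 follows the usual convention): in-place boolean-mask assignment relies on data-dependent indexing, while torch.where is an elementwise select whose output shape always equals the input shape.

    import torch

    label_pad_token_id = -100
    labels = torch.tensor([[5, 7, -100, -100], [3, -100, -100, -100]])

    # Original: masked in-place assignment. The set of masked positions depends
    # on the data, which is what forces XLA to recompile the graph.
    a = labels.clone()
    a[a == label_pad_token_id] = 0

    # New: torch.where keeps a static output shape, so the traced graph can be
    # compiled once and reused across steps.
    b = torch.where(labels == label_pad_token_id, 0, labels)

    assert torch.equal(a, b)  # same result, different lowering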

@qgallouedec (Member):

Thank you very much for this addition @wenxindongwork! Unfortunately we can't test this with GitHub CI, but I'm relying on you for the fact that it works and runs faster.
Can you just address the questions/comments? Then we're good to merge.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@wenxindongwork (Contributor, Author):
addressed comments, thanks for the quick review!

@lewtun (Member) commented on Sep 6, 2024:

Hello @wenxindongwork, can you please fix the code quality issues with make precommit 🙏?

@wenxindongwork (Contributor, Author):
should work now!

@qgallouedec (Member):

Can you also set the min transformers version in setup.py as well?

@wenxindongwork (Contributor, Author):

just did, thanks for pointing this out!
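For illustration, the kind of setup.py change being discussed might look like the following (a hypothetical excerpt; the version pin follows the v4.39 note above, but the variable name and surrounding contents are assumptions):

    # setup.py (hypothetical excerpt)
    REQUIRED_PKGS = [
        # is_torch_xla_available was added to transformers.utils in v4.39
        "transformers>=4.39.0",
    ]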

@lewtun (Member) left a comment:

Thanks for iterating @wenxindongwork - LGTM!

@lewtun lewtun merged commit e2966c8 into huggingface:main Sep 11, 2024
9 checks passed