Integrate OrpoTrainer with PyTorchXLA for faster step time on TPUs #2001
Conversation
trl/trainer/orpo_trainer.py
Outdated
        pad_value = self.padding_value
    elif k.endswith("_attention_mask"):
        pad_value = 0
    batch[k] = pad_list_to_length(batch[k], self.max_length, pad_value=pad_value)
Suggested change:
-    batch[k] = pad_list_to_length(batch[k], self.max_length, pad_value=pad_value)
+    batch[k] = batch[k] + [pad_value] * (self.max_length - len(batch[k]))
way faster and does not require a new helper func
thanks for the suggestion!
@@ -533,7 +536,17 @@ def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module
             batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
                 labels=torch.tensor(batch["chosen_labels"])
             )

+        if is_torch_xla_available():
Why do you need this only when is_torch_xla_available?
PyTorch XLA doesn't support dynamic shape compilation, so we pad all sequences to a fixed maximum length. Dynamic shapes may not be a problem on GPU, so I kept the original algorithm there, which pads to the longest sequence length in the batch.
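For illustration, a rough sketch of that idea (not the exact PR code; the helper name and argument names here are hypothetical):

```python
# Rough sketch of the idea, not the exact PR code: under XLA, pad every
# list-valued feature of a tokenized row to a fixed max_length so compiled
# graphs always see static shapes; on other backends, keep padding to the
# longest sequence in the batch at collation time.
from transformers.utils import is_torch_xla_available


def pad_row_for_xla(batch, max_length, padding_value, label_pad_token_id):
    """Hypothetical helper: pad list-valued features in a tokenized row to max_length."""
    if not is_torch_xla_available():
        return batch  # non-XLA backends can pad dynamically later
    for k in batch:
        if k.endswith("_input_ids"):
            pad_value = padding_value
        elif k.endswith("_labels"):
            pad_value = label_pad_token_id
        elif k.endswith("_attention_mask"):
            pad_value = 0
        else:
            continue
        batch[k] = batch[k] + [pad_value] * (max_length - len(batch[k]))
    return batch
```

In the PR itself, the corresponding padding happens inside tokenize_row, behind the is_torch_xla_available() check shown in the diff above.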
@@ -35,7 +35,7 @@
 from transformers import AutoModelForCausalLM, DataCollator, PreTrainedModel, PreTrainedTokenizerBase, Trainer
 from transformers.trainer_callback import TrainerCallback
 from transformers.trainer_utils import EvalLoopOutput
-from transformers.utils import is_torch_fx_proxy
+from transformers.utils import is_torch_fx_proxy, is_torch_xla_available
This has been added with transformers v4.39. We should probably set this version as the new minimum requirement.
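As an aside, if one did not want to bump the minimum version, a defensive import along these lines would also work (just a sketch of an alternative; the PR instead raises the minimum transformers requirement, as discussed below):

```python
# Sketch of an alternative to bumping the minimum transformers version:
# fall back to a no-op check when is_torch_xla_available is missing (< v4.39).
try:
    from transformers.utils import is_torch_xla_available
except ImportError:
    def is_torch_xla_available() -> bool:
        # Older transformers: assume no XLA support and keep the original code path.
        return False
```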
thanks for catching this!
@@ -659,7 +672,7 @@ def get_batch_logps(
         loss_mask = labels != label_pad_token_id

         # dummy token; we'll ignore the losses on these tokens later
-        labels[labels == label_pad_token_id] = 0
+        labels = torch.where(labels == label_pad_token_id, 0, labels)
Is this change necessary? Personal opinion: I find it a bit less intuitive to read.
this is necessary, because the previous code calls torch.nonzero under the hood, which produces a dynamically shaped output and triggers graph recompilation.
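A minimal standalone illustration of the two variants (hypothetical tensors, outside the trainer):

```python
# Minimal standalone illustration (hypothetical tensors, outside the trainer).
# Masked in-place assignment relies on a nonzero-style index computation whose
# size depends on the data, which is a dynamic shape under XLA; torch.where
# produces the same result while keeping all shapes static.
import torch

label_pad_token_id = -100
labels = torch.tensor([[5, 7, -100, -100],
                       [3, -100, -100, -100]])

# variant that can trigger recompilation under XLA
labels_inplace = labels.clone()
labels_inplace[labels_inplace == label_pad_token_id] = 0

# XLA-friendly variant used in the PR
labels_where = torch.where(labels == label_pad_token_id, 0, labels)

assert torch.equal(labels_inplace, labels_where)
```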
Thank you very much for this addition @wenxindongwork! Unfortunately we can't test this with GitHub CI, so I'm relying on you for the fact that it works and runs faster.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
addressed comments, thanks for the quick review!
Hello @wenxindongwork can you please fix the code quality issues with
should work now!
Can you also set the min transformers version in
just did, thanks for pointing this out!
Thanks for iterating @wenxindongwork - LGTM!
OrpoTrainer currently runs very slowly on TPU because the code is not integrated with PyTorch XLA. There are too many dynamic shapes and device data transfers in the code, which trigger graph recompilation and slow down step time. This PR improves the step time of OrpoTrainer on TPUs by more than 300x. Tested on Llama3-8b with LoRA on all linear modules, the step time is now 2s, compared to the 10 minutes we started with.
The changes should not impact performance on other backends since we have guarded the changes with is_torch_xla_available.