Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix ORT Trainer #2123

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions optimum/onnxruntime/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""
The ORTTrainer class, to easily train a 🤗 Transformers from scratch or finetune it on a new task with ONNX Runtime.
"""

import functools
import math
import os
Expand Down Expand Up @@ -131,11 +132,11 @@ def __init__(self, model, args, label_smoother):
# Label smoothing
self.label_smoother = label_smoother

def forward(self, inputs: Dict[str, Union[torch.Tensor, Any]], return_outputs):
def forward(self, inputs: Dict[str, Union[torch.Tensor, Any]], return_outputs, num_items_in_batch):
# The compute_model_plus_loss_internal is assigned once the class is instantiated.
# It should have same signature as Trainer.compute_loss().
# We do this to avoid potential un-synced states if we duplicated compute loss codes .
return self.compute_model_plus_loss_internal(self._original_model, inputs, return_outputs)
return self.compute_model_plus_loss_internal(self._original_model, inputs, return_outputs, num_items_in_batch)

@property
def module(self):
Expand Down Expand Up @@ -291,14 +292,14 @@ def _set_signature_columns_if_needed(self):
# Labels may be named label or label_ids, the default data collator handles that.
self._signature_columns += list(set(["label", "label_ids"] + self.label_names))

def compute_loss(self, model_with_loss, inputs, return_outputs=False):
def compute_loss(self, model_with_loss, inputs, return_outputs=False, num_items_in_batch=None):
# Run model forward + loss compute.
if isinstance(self.model, ModuleWithLoss):
# ORTModule Does not support the BatchEncoding Type so we have to convert to a dict.
dict_inputs = dict(inputs.items())
return model_with_loss(dict_inputs, return_outputs)
return model_with_loss(dict_inputs, return_outputs, num_items_in_batch)
else:
return super().compute_loss(model_with_loss, inputs, return_outputs)
return super().compute_loss(model_with_loss, inputs, return_outputs, num_items_in_batch)

def train(
self,
Expand Down Expand Up @@ -803,7 +804,9 @@ def get_dataloader_sampler(dataloader):
self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
self.control = self.callback_handler.on_step_end(args, self.state, self.control)

self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
self._maybe_log_save_evaluate(
tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time
)
else:
self.control = self.callback_handler.on_substep_end(args, self.state, self.control)

Expand All @@ -818,7 +821,7 @@ def get_dataloader_sampler(dataloader):
self.control.should_training_stop = True

self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

start_time was introduced in transformers v4.47 https://github.com/huggingface/transformers/blob/v4.47.0/src/transformers/trainer.py#L3021 so maybe we should add a check for the transformers version (min / max supported version) + tests to make sure we keep compatibility with the min/max versions, wdyt @IlyasMoutawwakil ?


if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
logger.warning(
Expand Down
Loading