Re-enable checkpointing
alanwaketan committed Feb 14, 2024
1 parent 6a009a0 commit b4bdf0a
Showing 1 changed file with 5 additions and 4 deletions.
src/transformers/trainer.py (5 additions, 4 deletions)
@@ -2004,6 +2004,11 @@ def _inner_training_loop(
                     self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
 
                 if self.control.should_epoch_stop or self.control.should_training_stop:
+                    # PyTorch/XLA relies on the data loader to insert the mark_step for
+                    # each step. Since we are breaking the loop early, we need to manually
+                    # insert the mark_step here.
+                    if is_torch_tpu_available():
+                        xm.mark_step()
                     break
             if step < 0:
                 logger.warning(
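For context on the first hunk: PyTorch/XLA traces operations lazily and only compiles and executes the pending graph at an xm.mark_step() barrier, which the parallel data loader normally inserts each time it yields a batch. Breaking out of the step loop early skips that implicit barrier, so the hunk flushes it manually. Below is a minimal, self-contained sketch of the same pattern outside the Trainer; the run_epoch helper and its arguments are illustrative placeholders, while xm.mark_step, xm.optimizer_step, and MpDeviceLoader are real torch_xla APIs.

import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

def run_epoch(model, optimizer, dataloader, device, max_steps):
    # MpDeviceLoader calls xm.mark_step() each time it yields a batch, so the
    # lazy graph is normally flushed once per iteration by the loader itself.
    loader = pl.MpDeviceLoader(dataloader, device)
    for step, batch in enumerate(loader):
        loss = model(**batch).loss      # placeholder model with a HF-style output
        loss.backward()
        xm.optimizer_step(optimizer)    # all-reduce gradients, then optimizer.step()
        optimizer.zero_grad()
        if step + 1 >= max_steps:
            # Leaving the loop before requesting the next batch skips the
            # loader's mark_step, so flush the pending graph explicitly.
            xm.mark_step()
            break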
@@ -2987,10 +2992,6 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
 
     def _save_tpu(self, output_dir: Optional[str] = None):
         output_dir = output_dir if output_dir is not None else self.args.output_dir
-        # TODO: Enable distributed checkpointing with SPMD.
-        if self.is_fsdp_xla_v2_enabled:
-            logger.info("Skip saving model for now before the TPU SPMD distributed checkpointing is available")
-            return
 
         logger.info(f"Saving model checkpoint to {output_dir}")
         model = self.model
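The second hunk removes the guard that made _save_tpu a no-op when FSDPv2 (SPMD) was enabled, so checkpoints are written again on TPU. For orientation only, here is a minimal sketch of what a TPU-side save typically involves; it is not the Trainer's _save_tpu body, and save_tpu_checkpoint plus its weights_name default are hypothetical, while xm.rendezvous and xm.save are real torch_xla calls.

import os
import torch_xla.core.xla_model as xm

def save_tpu_checkpoint(model, output_dir, weights_name="pytorch_model.bin"):
    # Hypothetical helper, not Trainer code: keep all processes in step
    # before touching the filesystem.
    os.makedirs(output_dir, exist_ok=True)
    xm.rendezvous("save_tpu_checkpoint")
    # xm.save moves tensors to CPU and writes only on the master ordinal,
    # so every process can call it safely.
    xm.save(model.state_dict(), os.path.join(output_dir, weights_name))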
