diff --git a/helpers/training/save_hooks.py b/helpers/training/save_hooks.py index fa45f2c1..a0b73677 100644 --- a/helpers/training/save_hooks.py +++ b/helpers/training/save_hooks.py @@ -316,6 +316,8 @@ def _save_full_model(self, models, weights, output_dir): def save_model_hook(self, models, weights, output_dir): # Write "training_state.json" to the output directory containing the training state + if not self.accelerator.is_main_process: + return StateTracker.save_training_state( os.path.join(output_dir, "training_state.json") )