diff --git a/train_network.py b/train_network.py index 76e6cd8a1..d1f02d530 100644 --- a/train_network.py +++ b/train_network.py @@ -897,6 +897,10 @@ def remove_model(old_ckpt_name): for step, batch in enumerate(skipped_dataloader or train_dataloader): current_step.value = global_step + if initial_step > 0: + initial_step -= 1 + continue + with accelerator.accumulate(training_model): on_step_start(text_encoder, unet)