From b1cd9cbada325a0a2494155b8739a93b215128fa Mon Sep 17 00:00:00 2001
From: rockerBOO
Date: Sun, 12 May 2024 20:49:28 -0400
Subject: [PATCH] Support new optimizer Schedule free (#1250)

* init

* use no schedule

* fix typo

* update for eval()

* fix typo
---
 fine_tune.py                         | 35 ++++++++++++++++++++------
 library/train_util.py                | 15 +++++++++++
 sdxl_train.py                        | 24 ++++++++++++++----
 sdxl_train_control_net_lllite.py     | 18 ++++++++++++--
 sdxl_train_control_net_lllite_old.py | 23 ++++++++++++++---
 train_controlnet.py                  | 16 +++++++++---
 train_db.py                          | 35 ++++++++++++++++++++------
 train_network.py                     | 37 ++++++++++++++++++++++------
 train_textual_inversion.py           | 31 ++++++++++++++++++-----
 train_textual_inversion_XTI.py       | 23 ++++++++++++++---
 10 files changed, 209 insertions(+), 48 deletions(-)

diff --git a/fine_tune.py b/fine_tune.py
index d865cd2de..a3d5da922 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -255,18 +255,31 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
             ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder)
         else:
             ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet)
-        ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            ds_model, optimizer, train_dataloader, lr_scheduler
-        )
+        if args.optimizer_type.lower().endswith("schedulefree"):
+            ds_model, optimizer, train_dataloader = accelerator.prepare(
+                ds_model, optimizer, train_dataloader
+            )
+        else:
+            ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+                ds_model, optimizer, train_dataloader, lr_scheduler
+            )
         training_models = [ds_model]

     else:
         # acceleratorがなんかよろしくやってくれるらしい
         if args.train_text_encoder:
-            unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-                unet, text_encoder, optimizer, train_dataloader, lr_scheduler
-            )
+            if args.optimizer_type.lower().endswith("schedulefree"):
+                unet, text_encoder, optimizer, train_dataloader = accelerator.prepare(
+                    unet, text_encoder, optimizer, train_dataloader
+                )
+            else:
+                unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+                    unet, text_encoder, optimizer, train_dataloader, lr_scheduler
+                )
         else:
-            unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
+            if args.optimizer_type.lower().endswith("schedulefree"):
+                unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)
+            else:
+                unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)

     # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
     if args.full_fp16:
@@ -328,6 +341,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
             m.train()

         for step, batch in enumerate(train_dataloader):
+            if (args.optimizer_type.lower().endswith("schedulefree")):
+                optimizer.train()
             current_step.value = global_step
             with accelerator.accumulate(*training_models):
                 with torch.no_grad():
@@ -400,9 +415,13 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                 optimizer.step()
-                lr_scheduler.step()
+                if not args.optimizer_type.lower().endswith("schedulefree"):
+                    lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)

+            if (args.optimizer_type.lower().endswith("schedulefree")):
+                optimizer.eval()
+
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 progress_bar.update(1)
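The same pattern recurs in every trainer touched below: when `--optimizer_type` ends with `schedulefree`, the `lr_scheduler` is left out of `accelerator.prepare(...)`, `lr_scheduler.step()` is skipped, and the optimizer is toggled with `optimizer.train()` before each step and `optimizer.eval()` afterwards. The following standalone Python sketch only illustrates that loop shape with the schedulefree package; `model`, `loader`, and `loss_fn` are placeholder names, not part of this patch.

import torch
import schedulefree

# Placeholder model and data purely for illustration.
model = torch.nn.Linear(8, 1)
optimizer = schedulefree.AdamWScheduleFree(model.parameters(), lr=1e-3)
loss_fn = torch.nn.MSELoss()
loader = [(torch.randn(4, 8), torch.randn(4, 1)) for _ in range(10)]

for x, y in loader:
    optimizer.train()                      # switch to the training parameter state
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()                       # no lr_scheduler.step(): the schedule is built into the optimizer
    optimizer.zero_grad(set_to_none=True)
    optimizer.eval()                       # switch back before any sampling, validation, or saving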
diff --git a/library/train_util.py b/library/train_util.py
index 1f9f3c5df..31201db16 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -4248,6 +4248,21 @@ def get_optimizer(args, trainable_params):
         logger.info(f"use AdamW optimizer | {optimizer_kwargs}")
         optimizer_class = torch.optim.AdamW
         optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+
+    elif optimizer_type.endswith("schedulefree".lower()):
+        try:
+            import schedulefree as sf
+        except ImportError:
+            raise ImportError("No schedulefree / schedulefreeがインストールされていないようです")
+        if optimizer_type == "AdamWScheduleFree".lower():
+            optimizer_class = sf.AdamWScheduleFree
+            logger.info(f"use AdamWScheduleFree optimizer | {optimizer_kwargs}")
+        elif optimizer_type == "SGDScheduleFree".lower():
+            optimizer_class = sf.SGDScheduleFree
+            logger.info(f"use SGDScheduleFree optimizer | {optimizer_kwargs}")
+        else:
+            raise ValueError(f"Unknown optimizer type: {optimizer_type}")
+        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

     if optimizer is None:
         # 任意のoptimizerを使う
diff --git a/sdxl_train.py b/sdxl_train.py
index 9e20c60ca..5dd14e2b7 100644
--- a/sdxl_train.py
+++ b/sdxl_train.py
@@ -466,9 +466,14 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
             text_encoder2=text_encoder2 if train_text_encoder2 else None,
         )
         # most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007
-        ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            ds_model, optimizer, train_dataloader, lr_scheduler
-        )
+        if args.optimizer_type.lower().endswith("schedulefree"):
+            ds_model, optimizer, train_dataloader = accelerator.prepare(
+                ds_model, optimizer, train_dataloader
+            )
+        else:
+            ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+                ds_model, optimizer, train_dataloader, lr_scheduler
+            )
         training_models = [ds_model]

     else:
@@ -479,7 +484,10 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
             text_encoder1 = accelerator.prepare(text_encoder1)
         if train_text_encoder2:
             text_encoder2 = accelerator.prepare(text_encoder2)
-        optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
+        if args.optimizer_type.lower().endswith("schedulefree"):
+            optimizer, train_dataloader = accelerator.prepare(optimizer, train_dataloader)
+        else:
+            optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)

     if args.fused_backward_pass:
         # use fused optimizer for backward pass: other optimizers will be supported in the future
@@ -605,6 +613,8 @@ def optimizer_hook(parameter: torch.Tensor):
             m.train()

         for step, batch in enumerate(train_dataloader):
+            if (args.optimizer_type.lower().endswith("schedulefree")):
+                optimizer.train()
             current_step.value = global_step

             if args.fused_optimizer_groups:
@@ -740,7 +750,8 @@ def optimizer_hook(parameter: torch.Tensor):
                        accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                    optimizer.step()
-                    lr_scheduler.step()
+                    if not args.optimizer_type.lower().endswith("schedulefree"):
+                        lr_scheduler.step()
                    optimizer.zero_grad(set_to_none=True)
                else:
                    # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook
@@ -749,6 +760,9 @@ def optimizer_hook(parameter: torch.Tensor):
                    for i in range(1, len(optimizers)):
                        lr_schedulers[i].step()

+                if (args.optimizer_type.lower().endswith("schedulefree")):
+                    optimizer.eval()
+
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 progress_bar.update(1)
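The `get_optimizer` hunk above dispatches on a lowercased `optimizer_type` and forwards any extra `--optimizer_args` values straight to the constructor. A stripped-down sketch of that dispatch, written as a standalone helper (`make_schedulefree_optimizer` is a hypothetical name, not a function in this repository):

def make_schedulefree_optimizer(optimizer_type, trainable_params, lr, **optimizer_kwargs):
    try:
        import schedulefree as sf
    except ImportError:
        raise ImportError("schedulefree is not installed (pip install schedulefree)")

    optimizer_type = optimizer_type.lower()
    if optimizer_type == "adamwschedulefree":
        optimizer_class = sf.AdamWScheduleFree
    elif optimizer_type == "sgdschedulefree":
        optimizer_class = sf.SGDScheduleFree
    else:
        raise ValueError(f"Unknown schedule-free optimizer type: {optimizer_type}")

    # Extra keyword arguments (e.g. weight_decay) are passed through unchanged.
    return optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)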
diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py
index 301310901..3e3e0b695 100644
--- a/sdxl_train_control_net_lllite.py
+++ b/sdxl_train_control_net_lllite.py
@@ -287,11 +287,19 @@ def train(args):
         unet.to(weight_dtype)

     # acceleratorがなんかよろしくやってくれるらしい
-    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
+    if args.optimizer_type.lower().endswith("schedulefree"):
+        unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)
+    else:
+        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)

     if args.gradient_checkpointing:
+        if (args.optimizer_type.lower().endswith("schedulefree")):
+            optimizer.train()
         unet.train()  # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
+
     else:
+        if (args.optimizer_type.lower().endswith("schedulefree")):
+            optimizer.eval()
         unet.eval()

     # TextEncoderの出力をキャッシュするときにはCPUへ移動する
@@ -391,6 +399,8 @@ def remove_model(old_ckpt_name):
         current_epoch.value = epoch + 1

         for step, batch in enumerate(train_dataloader):
+            if (args.optimizer_type.lower().endswith("schedulefree")):
+                optimizer.train()
             current_step.value = global_step
             with accelerator.accumulate(unet):
                 with torch.no_grad():
@@ -486,9 +496,13 @@ def remove_model(old_ckpt_name):
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                 optimizer.step()
-                lr_scheduler.step()
+                if not args.optimizer_type.lower().endswith("schedulefree"):
+                    lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)

+            if (args.optimizer_type.lower().endswith("schedulefree")):
+                optimizer.eval()
+
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 progress_bar.update(1)
diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py
index 292a0463a..7d1a4978e 100644
--- a/sdxl_train_control_net_lllite_old.py
+++ b/sdxl_train_control_net_lllite_old.py
@@ -254,15 +254,24 @@ def train(args):
         network.to(weight_dtype)

     # acceleratorがなんかよろしくやってくれるらしい
-    unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-        unet, network, optimizer, train_dataloader, lr_scheduler
-    )
+    if args.optimizer_type.lower().endswith("schedulefree"):
+        unet, network, optimizer, train_dataloader = accelerator.prepare(
+            unet, network, optimizer, train_dataloader
+        )
+    else:
+        unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            unet, network, optimizer, train_dataloader, lr_scheduler
+        )
     network: control_net_lllite.ControlNetLLLite

     if args.gradient_checkpointing:
         unet.train()  # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
+        if (args.optimizer_type.lower().endswith("schedulefree")):
+            optimizer.train()
     else:
         unet.eval()
+        if (args.optimizer_type.lower().endswith("schedulefree")):
+            optimizer.eval()

     network.prepare_grad_etc()

@@ -357,6 +366,8 @@ def remove_model(old_ckpt_name):
         network.on_epoch_start()  # train()

         for step, batch in enumerate(train_dataloader):
+            if (args.optimizer_type.lower().endswith("schedulefree")):
+                optimizer.train()
             current_step.value = global_step
             with accelerator.accumulate(network):
                 with torch.no_grad():
@@ -449,9 +460,13 @@ def remove_model(old_ckpt_name):
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                 optimizer.step()
-                lr_scheduler.step()
+                if not args.optimizer_type.lower().endswith("schedulefree"):
+                    lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)

+            if (args.optimizer_type.lower().endswith("schedulefree")):
+                optimizer.eval()
+
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 progress_bar.update(1)
diff --git a/train_controlnet.py b/train_controlnet.py
index c9ac6c5a8..800b82fa7 100644
--- a/train_controlnet.py
+++ b/train_controlnet.py
@@ -298,9 +298,14 @@ def __contains__(self, name):
         controlnet.to(weight_dtype)

     # acceleratorがなんかよろしくやってくれるらしい
-    controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-        controlnet, optimizer, train_dataloader, lr_scheduler
-    )
+    if args.optimizer_type.lower().endswith("schedulefree"):
+        controlnet, optimizer, train_dataloader = accelerator.prepare(
+            controlnet, optimizer, train_dataloader
+        )
+    else:
+        controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            controlnet, optimizer, train_dataloader, lr_scheduler
+        )

     unet.requires_grad_(False)
     text_encoder.requires_grad_(False)
@@ -417,6 +422,8 @@ def remove_model(old_ckpt_name):
         current_epoch.value = epoch + 1

         for step, batch in enumerate(train_dataloader):
+            if (args.optimizer_type.lower().endswith("schedulefree")):
+                optimizer.train()
             current_step.value = global_step
             with accelerator.accumulate(controlnet):
                 with torch.no_grad():
@@ -500,6 +507,9 @@ def remove_model(old_ckpt_name):
                 lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)

+            if (args.optimizer_type.lower().endswith("schedulefree")):
+                optimizer.eval()
+
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 progress_bar.update(1)
diff --git a/train_db.py b/train_db.py
index 39d8ea6ed..ec4c642d0 100644
--- a/train_db.py
+++ b/train_db.py
@@ -229,19 +229,32 @@ def train(args):
             ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder)
         else:
             ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet)
-        ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            ds_model, optimizer, train_dataloader, lr_scheduler
-        )
+        if args.optimizer_type.lower().endswith("schedulefree"):
+            ds_model, optimizer, train_dataloader = accelerator.prepare(
+                ds_model, optimizer, train_dataloader
+            )
+        else:
+            ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+                ds_model, optimizer, train_dataloader, lr_scheduler
+            )
         training_models = [ds_model]

     else:
         if train_text_encoder:
-            unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-                unet, text_encoder, optimizer, train_dataloader, lr_scheduler
-            )
+            if args.optimizer_type.lower().endswith("schedulefree"):
+                unet, text_encoder, optimizer, train_dataloader = accelerator.prepare(
+                    unet, text_encoder, optimizer, train_dataloader
+                )
+            else:
+                unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+                    unet, text_encoder, optimizer, train_dataloader, lr_scheduler
+                )
             training_models = [unet, text_encoder]
         else:
-            unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
+            if args.optimizer_type.lower().endswith("schedulefree"):
+                unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)
+            else:
+                unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
             training_models = [unet]

     if not train_text_encoder:
@@ -307,6 +320,8 @@ def train(args):
             text_encoder.train()

         for step, batch in enumerate(train_dataloader):
+            if (args.optimizer_type.lower().endswith("schedulefree")):
+                optimizer.train()
             current_step.value = global_step
             # 指定したステップ数でText Encoderの学習を止める
             if global_step == args.stop_text_encoder_training:
@@ -384,9 +399,13 @@ def train(args):
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                 optimizer.step()
-                lr_scheduler.step()
+                if not args.optimizer_type.lower().endswith("schedulefree"):
+                    lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)

+            if (args.optimizer_type.lower().endswith("schedulefree")):
+                optimizer.eval()
+
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 progress_bar.update(1)
diff --git a/train_network.py b/train_network.py
index b272a6e1a..2e8be8938 100644
--- a/train_network.py
+++ b/train_network.py
@@ -439,9 +439,14 @@ def train(self, args):
                 text_encoder2=text_encoders[1] if train_text_encoder and len(text_encoders) > 1 else None,
                 network=network,
             )
-            ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-                ds_model, optimizer, train_dataloader, lr_scheduler
-            )
+            if args.optimizer_type.lower().endswith("schedulefree"):
+                ds_model, optimizer, train_dataloader = accelerator.prepare(
+                    ds_model, optimizer, train_dataloader
+                )
+            else:
+                ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+                    ds_model, optimizer, train_dataloader, lr_scheduler
+                )
             training_model = ds_model
         else:
             if train_unet:
@@ -456,15 +461,23 @@ def train(self, args):
                     text_encoders = [text_encoder]
             else:
                 pass  # if text_encoder is not trained, no need to prepare. and device and dtype are already set
-
-            network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-                network, optimizer, train_dataloader, lr_scheduler
-            )
+
+            if args.optimizer_type.lower().endswith("schedulefree"):
+                network, optimizer, train_dataloader = accelerator.prepare(
+                    network, optimizer, train_dataloader
+                )
+            else:
+                network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+                    network, optimizer, train_dataloader, lr_scheduler
+                )
             training_model = network

         if args.gradient_checkpointing:
             # according to TI example in Diffusers, train is required
+            if (args.optimizer_type.lower().endswith("schedulefree")):
+                optimizer.train()
             unet.train()
+
             for t_enc in text_encoders:
                 t_enc.train()
@@ -473,6 +486,8 @@ def train(self, args):
                     t_enc.text_model.embeddings.requires_grad_(True)

         else:
+            if (args.optimizer_type.lower().endswith("schedulefree")):
+                optimizer.eval()
             unet.eval()
             for t_enc in text_encoders:
                 t_enc.eval()
@@ -825,6 +840,8 @@ def remove_model(old_ckpt_name):
             accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet)

             for step, batch in enumerate(train_dataloader):
+                if (args.optimizer_type.lower().endswith("schedulefree")):
+                    optimizer.train()
                 current_step.value = global_step
                 with accelerator.accumulate(training_model):
                     on_step_start(text_encoder, unet)
@@ -930,7 +947,8 @@ def remove_model(old_ckpt_name):
                        accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                    optimizer.step()
-                    lr_scheduler.step()
+                    if not args.optimizer_type.lower().endswith("schedulefree"):
+                        lr_scheduler.step()
                    optimizer.zero_grad(set_to_none=True)

                if args.scale_weight_norms:
@@ -941,6 +959,9 @@ def remove_model(old_ckpt_name):
                else:
                    keys_scaled, mean_norm, maximum_norm = None, None, None

+                if (args.optimizer_type.lower().endswith("schedulefree")):
+                    optimizer.eval()
+
                # Checks if the accelerator has performed an optimization step behind the scenes
                if accelerator.sync_gradients:
                    progress_bar.update(1)
diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index ade077c36..ba8d59a24 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -416,14 +416,24 @@ def train(self, args):
         # acceleratorがなんかよろしくやってくれるらしい
         if len(text_encoders) == 1:
-            text_encoder_or_list, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-                text_encoder_or_list, optimizer, train_dataloader, lr_scheduler
-            )
+            if args.optimizer_type.lower().endswith("schedulefree"):
+                text_encoder_or_list, optimizer, train_dataloader = accelerator.prepare(
+                    text_encoder_or_list, optimizer, train_dataloader
+                )
+            else:
+                text_encoder_or_list, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+                    text_encoder_or_list, optimizer, train_dataloader, lr_scheduler
+                )
         elif len(text_encoders) == 2:
-            text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-                text_encoders[0], text_encoders[1], optimizer, train_dataloader, lr_scheduler
-            )
+            if args.optimizer_type.lower().endswith("schedulefree"):
+                text_encoder1, text_encoder2, optimizer, train_dataloader = accelerator.prepare(
+                    text_encoders[0], text_encoders[1], optimizer, train_dataloader
+                )
+            else:
+                text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+                    text_encoders[0], text_encoders[1], optimizer, train_dataloader, lr_scheduler
+                )

             text_encoder_or_list = text_encoders = [text_encoder1, text_encoder2]
@@ -452,8 +462,12 @@ def train(self, args):
         unet.to(accelerator.device, dtype=weight_dtype)
         if args.gradient_checkpointing:  # according to TI example in Diffusers, train is required
             # TODO U-Netをオリジナルに置き換えたのでいらないはずなので、後で確認して消す
+            if (args.optimizer_type.lower().endswith("schedulefree")):
+                optimizer.train()
             unet.train()
         else:
+            if (args.optimizer_type.lower().endswith("schedulefree")):
+                optimizer.eval()
             unet.eval()

         if not cache_latents:  # キャッシュしない場合はVAEを使うのでVAEを準備する
@@ -557,6 +571,8 @@ def remove_model(old_ckpt_name):
             loss_total = 0

             for step, batch in enumerate(train_dataloader):
+                if (args.optimizer_type.lower().endswith("schedulefree")):
+                    optimizer.train()
                 current_step.value = global_step
                 with accelerator.accumulate(text_encoders[0]):
                     with torch.no_grad():
@@ -627,6 +643,9 @@ def remove_model(old_ckpt_name):
                            index_no_updates
                        ]

+                if (args.optimizer_type.lower().endswith("schedulefree")):
+                    optimizer.eval()
+
                # Checks if the accelerator has performed an optimization step behind the scenes
                if accelerator.sync_gradients:
                    progress_bar.update(1)
diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py
index efb59137b..f9c6c8e15 100644
--- a/train_textual_inversion_XTI.py
+++ b/train_textual_inversion_XTI.py
@@ -335,9 +335,14 @@ def train(args):
     lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)

     # acceleratorがなんかよろしくやってくれるらしい
-    text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-        text_encoder, optimizer, train_dataloader, lr_scheduler
-    )
+    if args.optimizer_type.lower().endswith("schedulefree"):
+        text_encoder, optimizer, train_dataloader = accelerator.prepare(
+            text_encoder, optimizer, train_dataloader
+        )
+    else:
+        text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            text_encoder, optimizer, train_dataloader, lr_scheduler
+        )

     index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0]
     # logger.info(len(index_no_updates), torch.sum(index_no_updates))
@@ -354,8 +359,12 @@ def train(args):
         unet.to(accelerator.device, dtype=weight_dtype)
     if args.gradient_checkpointing:  # according to TI example in Diffusers, train is required
         unet.train()
+        if (args.optimizer_type.lower().endswith("schedulefree")):
+            optimizer.train()
     else:
         unet.eval()
+        if (args.optimizer_type.lower().endswith("schedulefree")):
+            optimizer.eval()

     if not cache_latents:
         vae.requires_grad_(False)
@@ -438,6 +447,8 @@ def remove_model(old_ckpt_name):
         loss_total = 0

         for step, batch in enumerate(train_dataloader):
+            if (args.optimizer_type.lower().endswith("schedulefree")):
+                optimizer.train()
             current_step.value = global_step
             with accelerator.accumulate(text_encoder):
                 with torch.no_grad():
@@ -496,7 +507,8 @@ def remove_model(old_ckpt_name):
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                 optimizer.step()
-                lr_scheduler.step()
+                if not args.optimizer_type.lower().endswith("schedulefree"):
+                    lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)

                 # Let's make sure we don't update any embedding weights besides the newly added token
@@ -505,6 +517,9 @@ def remove_model(old_ckpt_name):
                        index_no_updates
                    ]

+            if (args.optimizer_type.lower().endswith("schedulefree")):
+                optimizer.eval()
+
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 progress_bar.update(1)
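One rationale behind the repeated `optimizer.eval()` calls in this patch: schedule-free optimizers keep separate parameter states for stepping and for evaluation, so the weights should be switched to the evaluation state before sampling or checkpointing and switched back before training resumes. A minimal sketch of that idea, assuming a hypothetical `save_checkpoint` helper that is not part of this patch:

import torch

def save_checkpoint(model, optimizer, path):
    # Expose the evaluation-time weights of the schedule-free optimizer before serializing.
    optimizer.eval()
    torch.save({"model": model.state_dict(), "optimizer": optimizer.state_dict()}, path)
    # Return to the training-time weights so the next optimizer.step() behaves as expected.
    optimizer.train()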