From b2363f1021955c049c98e65676efca130690c40f Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Fri, 31 May 2024 12:20:20 +0800 Subject: [PATCH 01/11] Final implementation --- library/train_util.py | 11 ++++- train_network.py | 104 +++++++++++++++++++++++++++++++++++++++--- 2 files changed, 106 insertions(+), 9 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 1f9f3c5df..beb33bf82 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -657,8 +657,15 @@ def set_caching_mode(self, mode): def set_current_epoch(self, epoch): if not self.current_epoch == epoch: # epochが切り替わったらバケツをシャッフルする - self.shuffle_buckets() - self.current_epoch = epoch + if epoch > self.current_epoch: + logger.info("epoch is incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch)) + num_epochs = epoch - self.current_epoch + for _ in range(num_epochs): + self.current_epoch += 1 + self.shuffle_buckets() + else: + logger.warning("epoch is not incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch)) + self.current_epoch = epoch def set_current_step(self, step): self.current_step = step diff --git a/train_network.py b/train_network.py index b272a6e1a..76e6cd8a1 100644 --- a/train_network.py +++ b/train_network.py @@ -493,17 +493,24 @@ def train(self, args): # before resuming make hook for saving/loading to save/load the network weights only def save_model_hook(models, weights, output_dir): # pop weights of other models than network to save only network weights - # only main process or deepspeed https://github.com/huggingface/diffusers/issues/2606 - if accelerator.is_main_process or args.deepspeed: + if accelerator.is_main_process: remove_indices = [] for i, model in enumerate(models): if not isinstance(model, type(accelerator.unwrap_model(network))): remove_indices.append(i) for i in reversed(remove_indices): - if len(weights) > i: - weights.pop(i) + weights.pop(i) # print(f"save model hook: {len(weights)} weights will be saved") + # save current ecpoch and step + train_state_file = os.path.join(output_dir, "train_state.json") + # +1 is needed because the state is saved before current_step is set from global_step + logger.info(f"save train state to {train_state_file} at epoch {current_epoch.value} step {current_step.value+1}") + with open(train_state_file, "w", encoding="utf-8") as f: + json.dump({"current_epoch": current_epoch.value, "current_step": current_step.value + 1}, f) + + steps_from_state = None + def load_model_hook(models, input_dir): # remove models except network remove_indices = [] @@ -514,6 +521,15 @@ def load_model_hook(models, input_dir): models.pop(i) # print(f"load model hook: {len(models)} models will be loaded") + # load current epoch and step to + nonlocal steps_from_state + train_state_file = os.path.join(input_dir, "train_state.json") + if os.path.exists(train_state_file): + with open(train_state_file, "r", encoding="utf-8") as f: + data = json.load(f) + steps_from_state = data["current_step"] + logger.info(f"load train state from {train_state_file}: {data}") + accelerator.register_save_state_pre_hook(save_model_hook) accelerator.register_load_state_pre_hook(load_model_hook) @@ -757,7 +773,53 @@ def load_model_hook(models, input_dir): if key in metadata: minimum_metadata[key] = metadata[key] - progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + # calculate steps to skip when resuming or starting 
from a specific step + initial_step = 0 + if args.initial_epoch is not None or args.initial_step is not None: + # if initial_epoch or initial_step is specified, steps_from_state is ignored even when resuming + if steps_from_state is not None: + logger.warning( + "steps from the state is ignored because initial_step is specified / initial_stepが指定されているため、stateからのステップ数は無視されます" + ) + if args.initial_step is not None: + initial_step = args.initial_step + else: + # num steps per epoch is calculated by num_processes and gradient_accumulation_steps + initial_step = (args.initial_epoch - 1) * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + else: + # if initial_epoch and initial_step are not specified, steps_from_state is used when resuming + if steps_from_state is not None: + initial_step = steps_from_state + steps_from_state = None + + if initial_step > 0: + assert ( + args.max_train_steps > initial_step + ), f"max_train_steps should be greater than initial step / max_train_stepsは初期ステップより大きい必要があります: {args.max_train_steps} vs {initial_step}" + + progress_bar = tqdm( + range(args.max_train_steps - initial_step), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps" + ) + + epoch_to_start = 0 + if initial_step > 0: + if args.skip_until_initial_step: + # if skip_until_initial_step is specified, load data and discard it to ensure the same data is used + if not args.resume: + logger.info( + f"initial_step is specified but not resuming. lr scheduler will be started from the beginning / initial_stepが指定されていますがresumeしていないため、lr schedulerは最初から始まります" + ) + logger.info(f"skipping {initial_step} steps / {initial_step}ステップをスキップします") + initial_step *= args.gradient_accumulation_steps + else: + # if not, only epoch no is skipped for informative purpose + epoch_to_start = initial_step // math.ceil( + len(train_dataloader) / args.gradient_accumulation_steps + ) + initial_step = 0 # do not skip + global_step = 0 noise_scheduler = DDPMScheduler( @@ -816,7 +878,11 @@ def remove_model(old_ckpt_name): self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) # training loop - for epoch in range(num_train_epochs): + for skip_epoch in range(epoch_to_start): # skip epochs + logger.info(f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}") + initial_step -= len(train_dataloader) + + for epoch in range(epoch_to_start, num_train_epochs): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 @@ -824,7 +890,12 @@ def remove_model(old_ckpt_name): accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet) - for step, batch in enumerate(train_dataloader): + skipped_dataloader = None + if initial_step > 0: + skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step-1) + initial_step = 1 + + for step, batch in enumerate(skipped_dataloader or train_dataloader): current_step.value = global_step with accelerator.accumulate(training_model): on_step_start(text_encoder, unet) @@ -1126,6 +1197,25 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", ) + parser.add_argument( + "--skip_until_initial_step", + action="store_true", + help="skip training until initial_step is reached / initial_stepに到達するまで学習をスキップする", + ) + parser.add_argument( + "--initial_epoch", + type=int, + 
default=None, + help="initial epoch number, 1 means first epoch (same as not specifying). NOTE: initial_epoch/step doesn't affect to lr scheduler. Which means lr scheduler will start from 0 without `--resume`." + + " / 初期エポック数、1で最初のエポック(未指定時と同じ)。注意:initial_epoch/stepはlr schedulerに影響しないため、`--resume`しない場合はlr schedulerは0から始まる", + ) + parser.add_argument( + "--initial_step", + type=int, + default=None, + help="initial step number including all epochs, 0 means first step (same as not specifying). overwrites initial_epoch." + + " / 初期ステップ数、全エポックを含むステップ数、0で最初のステップ(未指定時と同じ)。initial_epochを上書きする", + ) # parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio") # parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio") # parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio") From 3eb27ced52e8bf522c7e490c3dacba1f8597f5b1 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Fri, 31 May 2024 12:24:15 +0800 Subject: [PATCH 02/11] Skip the final 1 step --- train_network.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/train_network.py b/train_network.py index 76e6cd8a1..d1f02d530 100644 --- a/train_network.py +++ b/train_network.py @@ -897,6 +897,10 @@ def remove_model(old_ckpt_name): for step, batch in enumerate(skipped_dataloader or train_dataloader): current_step.value = global_step + if initial_step > 0: + initial_step -= 1 + continue + with accelerator.accumulate(training_model): on_step_start(text_encoder, unet) From e5bab69e3a8f3dc4afb1badba65b6c50ca2f36d8 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 2 Jun 2024 21:11:40 +0900 Subject: [PATCH 03/11] fix alpha mask without disk cache closes #1351, ref #1339 --- library/train_util.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 1f9f3c5df..566f59279 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1265,7 +1265,8 @@ def __getitem__(self, index): if subset.alpha_mask: if img.shape[2] == 4: alpha_mask = img[:, :, 3] # [H,W] - alpha_mask = transforms.ToTensor()(alpha_mask) # 0-255 -> 0-1 + alpha_mask = alpha_mask.astype(np.float32) / 255.0 # 0.0~1.0 + alpha_mask = torch.FloatTensor(alpha_mask) else: alpha_mask = torch.ones((img.shape[0], img.shape[1]), dtype=torch.float32) else: @@ -2211,7 +2212,7 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alph # 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top) def load_latents_from_disk( npz_path, -) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: +) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: npz = np.load(npz_path) if "latents" not in npz: raise ValueError(f"error: npz is old format. 
please re-generate {npz_path}") @@ -2229,7 +2230,7 @@ def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, fli if flipped_latents_tensor is not None: kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy() if alpha_mask is not None: - kwargs["alpha_mask"] = alpha_mask # ndarray + kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy() np.savez( npz_path, latents=latents_tensor.float().cpu().numpy(), @@ -2496,8 +2497,9 @@ def cache_batch_latents( if image.shape[2] == 4: alpha_mask = image[:, :, 3] # [H,W] alpha_mask = alpha_mask.astype(np.float32) / 255.0 + alpha_mask = torch.FloatTensor(alpha_mask) # [H,W] else: - alpha_mask = np.ones_like(image[:, :, 0], dtype=np.float32) + alpha_mask = torch.ones_like(image[:, :, 0], dtype=torch.float32) # [H,W] else: alpha_mask = None alpha_masks.append(alpha_mask) From 4dbcef429b744d0cc101494802448b8c15f4f674 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 4 Jun 2024 21:26:55 +0900 Subject: [PATCH 04/11] update for corner cases --- library/train_util.py | 3 +++ train_network.py | 23 ++++++++++++++--------- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 102f9f03b..4736ff4ff 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -663,6 +663,7 @@ def set_current_epoch(self, epoch): for _ in range(num_epochs): self.current_epoch += 1 self.shuffle_buckets() + # self.current_epoch seem to be set to 0 again in the next epoch. it may be caused by skipped_dataloader? else: logger.warning("epoch is not incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch)) self.current_epoch = epoch @@ -5560,6 +5561,8 @@ def add(self, *, epoch: int, step: int, loss: float) -> None: if epoch == 0: self.loss_list.append(loss) else: + while len(self.loss_list) <= step: + self.loss_list.append(0.0) self.loss_total -= self.loss_list[step] self.loss_list[step] = loss self.loss_total += loss diff --git a/train_network.py b/train_network.py index d1f02d530..7ba073855 100644 --- a/train_network.py +++ b/train_network.py @@ -493,13 +493,15 @@ def train(self, args): # before resuming make hook for saving/loading to save/load the network weights only def save_model_hook(models, weights, output_dir): # pop weights of other models than network to save only network weights - if accelerator.is_main_process: + # only main process or deepspeed https://github.com/huggingface/diffusers/issues/2606 + if accelerator.is_main_process or args.deepspeed: remove_indices = [] for i, model in enumerate(models): if not isinstance(model, type(accelerator.unwrap_model(network))): remove_indices.append(i) for i in reversed(remove_indices): - weights.pop(i) + if len(weights) > i: + weights.pop(i) # print(f"save model hook: {len(weights)} weights will be saved") # save current ecpoch and step @@ -813,11 +815,12 @@ def load_model_hook(models, input_dir): ) logger.info(f"skipping {initial_step} steps / {initial_step}ステップをスキップします") initial_step *= args.gradient_accumulation_steps + + # set epoch to start to make initial_step less than len(train_dataloader) + epoch_to_start = initial_step // math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) else: # if not, only epoch no is skipped for informative purpose - epoch_to_start = initial_step // math.ceil( - len(train_dataloader) / args.gradient_accumulation_steps - ) + epoch_to_start = initial_step // math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) initial_step = 0 # do not skip global_step = 
0 @@ -878,9 +881,11 @@ def remove_model(old_ckpt_name): self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) # training loop - for skip_epoch in range(epoch_to_start): # skip epochs - logger.info(f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}") - initial_step -= len(train_dataloader) + if initial_step > 0: # only if skip_until_initial_step is specified + for skip_epoch in range(epoch_to_start): # skip epochs + logger.info(f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}") + initial_step -= len(train_dataloader) + global_step = initial_step for epoch in range(epoch_to_start, num_train_epochs): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") @@ -892,7 +897,7 @@ def remove_model(old_ckpt_name): skipped_dataloader = None if initial_step > 0: - skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step-1) + skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step - 1) initial_step = 1 for step, batch in enumerate(skipped_dataloader or train_dataloader): From 4ecbac131aba3d121f9708b3ac2a1f4726b17dc0 Mon Sep 17 00:00:00 2001 From: Yuta Hayashibe Date: Wed, 5 Jun 2024 16:31:44 +0900 Subject: [PATCH 05/11] Bump crate-ci/typos from 1.19.0 to 1.21.0, fix typos, and updated _typos.toml (Close #1307) --- .github/workflows/typos.yml | 2 +- _typos.toml | 2 ++ library/ipex/attention.py | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/typos.yml b/.github/workflows/typos.yml index e8b06483f..c81ff3210 100644 --- a/.github/workflows/typos.yml +++ b/.github/workflows/typos.yml @@ -18,4 +18,4 @@ jobs: - uses: actions/checkout@v4 - name: typos-action - uses: crate-ci/typos@v1.19.0 + uses: crate-ci/typos@v1.21.0 diff --git a/_typos.toml b/_typos.toml index ae9e06b18..bbf7728f4 100644 --- a/_typos.toml +++ b/_typos.toml @@ -2,6 +2,7 @@ # Instruction: https://github.com/marketplace/actions/typos-action#getting-started [default.extend-identifiers] +ddPn08="ddPn08" [default.extend-words] NIN="NIN" @@ -27,6 +28,7 @@ rik="rik" koo="koo" yos="yos" wn="wn" +hime="hime" [files] diff --git a/library/ipex/attention.py b/library/ipex/attention.py index d989ad53d..2bc62f65c 100644 --- a/library/ipex/attention.py +++ b/library/ipex/attention.py @@ -5,7 +5,7 @@ # pylint: disable=protected-access, missing-function-docstring, line-too-long -# ARC GPUs can't allocate more than 4GB to a single block so we slice the attetion layers +# ARC GPUs can't allocate more than 4GB to a single block so we slice the attention layers sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 4)) attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4)) From 58fb64819ab117e2b7bca6e87bae28901b616860 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 9 Jun 2024 19:26:09 +0900 Subject: [PATCH 06/11] set static graph flag when DDP ref #1363 --- sdxl_train_control_net_lllite.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 301310901..5ff060a9f 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -289,6 +289,9 @@ def train(args): # acceleratorがなんかよろしくやってくれるらしい unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) + if isinstance(unet, DDP): + unet._set_static_graph() # avoid error for multiple use of the 
parameter + if args.gradient_checkpointing: unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる else: From 1a104dc75ee5733af8ba17cc9778b39e26673734 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 9 Jun 2024 19:26:36 +0900 Subject: [PATCH 07/11] make forward/backward pathes same ref #1363 --- networks/control_net_lllite_for_train.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/networks/control_net_lllite_for_train.py b/networks/control_net_lllite_for_train.py index 65b3520cf..366451b7f 100644 --- a/networks/control_net_lllite_for_train.py +++ b/networks/control_net_lllite_for_train.py @@ -7,8 +7,10 @@ import torch from library import sdxl_original_unet from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) # input_blocksに適用するかどうか / if True, input_blocks are not applied @@ -103,19 +105,15 @@ def set_lllite(self, depth, cond_emb_dim, name, mlp_dim, dropout=None, multiplie add_lllite_modules(self, in_dim, depth, cond_emb_dim, mlp_dim) self.cond_image = None - self.cond_emb = None def set_cond_image(self, cond_image): self.cond_image = cond_image - self.cond_emb = None def forward(self, x): if not self.enabled: return super().forward(x) - if self.cond_emb is None: - self.cond_emb = self.lllite_conditioning1(self.cond_image) - cx = self.cond_emb + cx = self.lllite_conditioning1(self.cond_image) # make forward and backward compatible # reshape / b,c,h,w -> b,h*w,c n, c, h, w = cx.shape @@ -159,9 +157,7 @@ def forward(self, x): # , cond_image=None): if not self.enabled: return super().forward(x) - if self.cond_emb is None: - self.cond_emb = self.lllite_conditioning1(self.cond_image) - cx = self.cond_emb + cx = self.lllite_conditioning1(self.cond_image) cx = torch.cat([cx, self.down(x)], dim=1) cx = self.mid(cx) From 18d7597b0b39cc2204dfbdfdcbf0fead97414be1 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 11 Jun 2024 19:51:30 +0900 Subject: [PATCH 08/11] update README --- README.md | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/README.md b/README.md index 52c963392..25aba6397 100644 --- a/README.md +++ b/README.md @@ -178,6 +178,12 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - The ControlNet training script `train_controlnet.py` for SD1.5/2.x was not working, but it has been fixed. PR [#1284](https://github.com/kohya-ss/sd-scripts/pull/1284) Thanks to sdbds! +- `train_network.py` and `sdxl_train_network.py` now restore the order/position of data loading from DataSet when resuming training. PR [#1353](https://github.com/kohya-ss/sd-scripts/pull/1353) [#1359](https://github.com/kohya-ss/sd-scripts/pull/1359) Thanks to KohakuBlueleaf! + - This resolves the issue where the order of data loading from DataSet changes when resuming training. + - Specify the `--skip_until_initial_step` option to skip data loading until the specified step. If not specified, data loading starts from the beginning of the DataSet (same as before). + - If `--resume` is specified, the step saved in the state is used. + - Specify the `--initial_step` or `--initial_epoch` option to skip data loading until the specified step or epoch. Use these options in conjunction with `--skip_until_initial_step`. These options can be used without `--resume` (use them when resuming training with `--network_weights`). + - An option `--disable_mmap_load_safetensors` is added to disable memory mapping when loading the model's .safetensors in SDXL. 
PR [#1266](https://github.com/kohya-ss/sd-scripts/pull/1266) Thanks to Zovjsra! - It seems that the model file loading is faster in the WSL environment etc. - Available in `sdxl_train.py`, `sdxl_train_network.py`, `sdxl_train_textual_inversion.py`, and `sdxl_train_control_net_lllite.py`. @@ -235,6 +241,12 @@ https://github.com/kohya-ss/sd-scripts/pull/1290) Thanks to frodo821! - SD1.5/2.x 用の ControlNet 学習スクリプト `train_controlnet.py` が動作しなくなっていたのが修正されました。PR [#1284](https://github.com/kohya-ss/sd-scripts/pull/1284) sdbds 氏に感謝します。 +- `train_network.py` および `sdxl_train_network.py` で、学習再開時に DataSet の読み込み順についても復元できるようになりました。PR [#1353](https://github.com/kohya-ss/sd-scripts/pull/1353) [#1359](https://github.com/kohya-ss/sd-scripts/pull/1359) KohakuBlueleaf 氏に感謝します。 + - これにより、学習再開時に DataSet の読み込み順が変わってしまう問題が解消されます。 + - `--skip_until_initial_step` オプションを指定すると、指定したステップまで DataSet 読み込みをスキップします。指定しない場合の動作は変わりません(DataSet の最初から読み込みます) + - `--resume` オプションを指定すると、state に保存されたステップ数が使用されます。 + - `--initial_step` または `--initial_epoch` オプションを指定すると、指定したステップまたはエポックまで DataSet 読み込みをスキップします。これらのオプションは `--skip_until_initial_step` と併用してください。またこれらのオプションは `--resume` と併用しなくても使えます(`--network_weights` を用いた学習再開時などにお使いください )。 + - SDXL でモデルの .safetensors を読み込む際にメモリマッピングを無効化するオプション `--disable_mmap_load_safetensors` が追加されました。PR [#1266](https://github.com/kohya-ss/sd-scripts/pull/1266) Zovjsra 氏に感謝します。 - WSL 環境等でモデルファイルの読み込みが高速化されるようです。 - `sdxl_train.py`、`sdxl_train_network.py`、`sdxl_train_textual_inversion.py`、`sdxl_train_control_net_lllite.py` で使用可能です。 From 56bb81c9e6483b8b4d5b83639548855b8359f4b4 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 12 Jun 2024 21:39:35 +0900 Subject: [PATCH 09/11] add grad_hook after restore state closes #1344 --- sdxl_train.py | 46 +++++++++++++++++++++++++--------------------- 1 file changed, 25 insertions(+), 21 deletions(-) diff --git a/sdxl_train.py b/sdxl_train.py index 9e20c60ca..ae92d6a3d 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -481,6 +481,26 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): text_encoder2 = accelerator.prepare(text_encoder2) optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) + # TextEncoderの出力をキャッシュするときにはCPUへ移動する + if args.cache_text_encoder_outputs: + # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16 + text_encoder1.to("cpu", dtype=torch.float32) + text_encoder2.to("cpu", dtype=torch.float32) + clean_memory_on_device(accelerator.device) + else: + # make sure Text Encoders are on GPU + text_encoder1.to(accelerator.device) + text_encoder2.to(accelerator.device) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + # During deepseed training, accelerate not handles fp16/bf16|mixed precision directly via scaler. Let deepspeed engine do. + # -> But we think it's ok to patch accelerator even if deepspeed is enabled. + train_util.patch_accelerator_for_fp16_training(accelerator) + + # resumeする + train_util.resume_from_local_or_hf_if_specified(accelerator, args) + if args.fused_backward_pass: # use fused optimizer for backward pass: other optimizers will be supported in the future import library.adafactor_fused @@ -532,26 +552,6 @@ def optimizer_hook(parameter: torch.Tensor): parameter_optimizer_map[parameter] = opt_idx num_parameters_per_group[opt_idx] += 1 - # TextEncoderの出力をキャッシュするときにはCPUへ移動する - if args.cache_text_encoder_outputs: - # move Text Encoders for sampling images. 
Text Encoder doesn't work on CPU with fp16 - text_encoder1.to("cpu", dtype=torch.float32) - text_encoder2.to("cpu", dtype=torch.float32) - clean_memory_on_device(accelerator.device) - else: - # make sure Text Encoders are on GPU - text_encoder1.to(accelerator.device) - text_encoder2.to(accelerator.device) - - # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする - if args.full_fp16: - # During deepseed training, accelerate not handles fp16/bf16|mixed precision directly via scaler. Let deepspeed engine do. - # -> But we think it's ok to patch accelerator even if deepspeed is enabled. - train_util.patch_accelerator_for_fp16_training(accelerator) - - # resumeする - train_util.resume_from_local_or_hf_if_specified(accelerator, args) - # epoch数を計算する num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) @@ -589,7 +589,11 @@ def optimizer_hook(parameter: torch.Tensor): init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) - accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs) + accelerator.init_trackers( + "finetuning" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) # For --sample_at_first sdxl_train_util.sample_images( From 25f961bc779bc79aef440813e3e8e92244ac5739 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 23 Jun 2024 13:24:30 +0900 Subject: [PATCH 10/11] fix to work cache_latents/text_encoder_outputs --- README.md | 6 ++++++ tools/cache_latents.py | 5 ++++- tools/cache_text_encoder_outputs.py | 5 ++++- 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index a7047a360..fd81a781f 100644 --- a/README.md +++ b/README.md @@ -137,6 +137,12 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ## Change History +### Jun 23, 2024 / 2024-06-23: + +- Fixed `cache_latents.py` and `cache_text_encoder_outputs.py` not working. (Will be included in the next release.) + +- `cache_latents.py` および `cache_text_encoder_outputs.py` が動作しなくなっていたのを修正しました。(次回リリースに含まれます。) + ### Apr 7, 2024 / 2024-04-07: v0.8.7 - The default value of `huber_schedule` in Scheduled Huber Loss is changed from `exponential` to `snr`, which is expected to give better results. 
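Before the caching-tool diffs below, a note on the grad-hook reordering in the previous patch (sdxl_train.py): `--fused_backward_pass` attaches a hook to each trainable parameter so that its optimizer update runs as soon as that parameter's gradient has finished accumulating, and the patch moves that registration until after `accelerator.prepare(...)`, the Text Encoder device placement, and the `--resume` state restore. The toy below is a self-contained illustration of the hook mechanism only (plain per-parameter SGD, PyTorch >= 2.1), not sd-scripts' Adafactor-fused implementation:

```python
import torch

model = torch.nn.Linear(8, 8)

# one tiny optimizer per parameter, stepped as soon as that parameter's gradient is ready
optimizer_dict = {p: torch.optim.SGD([p], lr=0.1) for p in model.parameters() if p.requires_grad}

def optimizer_hook(parameter: torch.Tensor):
    optimizer_dict[parameter].step()
    optimizer_dict[parameter].zero_grad(set_to_none=True)  # free the gradient right away

for p in model.parameters():
    if p.requires_grad:
        p.register_post_accumulate_grad_hook(optimizer_hook)

loss = model(torch.randn(2, 8)).sum()
loss.backward()  # parameter updates happen here; there is no separate optimizer.step() call
```

Because these hooks are bound to concrete parameter objects, registering them before the model and optimizer are prepared and before the saved state is reloaded is fragile, which is presumably what #1344 ran into.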
diff --git a/tools/cache_latents.py b/tools/cache_latents.py index 347db27f7..32101de3f 100644 --- a/tools/cache_latents.py +++ b/tools/cache_latents.py @@ -16,12 +16,13 @@ ConfigSanitizer, BlueprintGenerator, ) -from library.utils import setup_logging +from library.utils import setup_logging, add_logging_arguments setup_logging() import logging logger = logging.getLogger(__name__) def cache_to_disk(args: argparse.Namespace) -> None: + setup_logging(args, reset=True) train_util.prepare_dataset_args(args, True) # check cache latents arg @@ -94,6 +95,7 @@ def cache_to_disk(args: argparse.Namespace) -> None: # acceleratorを準備する logger.info("prepare accelerator") + args.deepspeed = False accelerator = train_util.prepare_accelerator(args) # mixed precisionに対応した型を用意しておき適宜castする @@ -170,6 +172,7 @@ def cache_to_disk(args: argparse.Namespace) -> None: def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() + add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) train_util.add_training_arguments(parser, True) train_util.add_dataset_arguments(parser, True, True, True) diff --git a/tools/cache_text_encoder_outputs.py b/tools/cache_text_encoder_outputs.py index 5f1d6d201..a75d9da74 100644 --- a/tools/cache_text_encoder_outputs.py +++ b/tools/cache_text_encoder_outputs.py @@ -16,12 +16,13 @@ ConfigSanitizer, BlueprintGenerator, ) -from library.utils import setup_logging +from library.utils import setup_logging, add_logging_arguments setup_logging() import logging logger = logging.getLogger(__name__) def cache_to_disk(args: argparse.Namespace) -> None: + setup_logging(args, reset=True) train_util.prepare_dataset_args(args, True) # check cache arg @@ -99,6 +100,7 @@ def cache_to_disk(args: argparse.Namespace) -> None: # acceleratorを準備する logger.info("prepare accelerator") + args.deepspeed = False accelerator = train_util.prepare_accelerator(args) # mixed precisionに対応した型を用意しておき適宜castする @@ -171,6 +173,7 @@ def cache_to_disk(args: argparse.Namespace) -> None: def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() + add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) train_util.add_training_arguments(parser, True) train_util.add_dataset_arguments(parser, True, True, True) From 0b3e4f7ab62b7c93e66972b7bd2774b8fe679792 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 25 Jun 2024 20:03:09 +0900 Subject: [PATCH 11/11] show file name if error in load_image ref #1385 --- library/train_util.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 4736ff4ff..760be33eb 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2434,16 +2434,20 @@ def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset: return train_dataset_group -def load_image(image_path, alpha=False): - image = Image.open(image_path) - if alpha: - if not image.mode == "RGBA": - image = image.convert("RGBA") - else: - if not image.mode == "RGB": - image = image.convert("RGB") - img = np.array(image, np.uint8) - return img +def load_image(image_path, alpha=False): + try: + with Image.open(image_path) as image: + if alpha: + if not image.mode == "RGBA": + image = image.convert("RGBA") + else: + if not image.mode == "RGB": + image = image.convert("RGB") + img = np.array(image, np.uint8) + return img + except (IOError, OSError) as e: + logger.error(f"Error loading file: {image_path}") + raise e # 画像を読み込む。戻り値はnumpy.ndarray,(original width, original 
height),(crop left, crop top, crop right, crop bottom)
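Two short usage sketches follow as an appendix to the series. First, the save/load-state hook mechanism that patch 01 uses to persist `train_state.json` next to Accelerate's own checkpoint files. This is a minimal, self-contained sketch, not the actual `train_network.py` hooks, which additionally prune the non-network models from `models`/`weights` and store `current_epoch`:

```python
import json
import os

import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
model, optimizer = accelerator.prepare(model, optimizer)

current_step = {"value": 0}  # stand-in for the shared current_step value in train_network.py

def save_state_hook(models, weights, output_dir):
    # called from accelerator.save_state(); extra metadata is written next to the state files
    with open(os.path.join(output_dir, "train_state.json"), "w", encoding="utf-8") as f:
        json.dump({"current_step": current_step["value"] + 1}, f)

def load_state_hook(models, input_dir):
    train_state_file = os.path.join(input_dir, "train_state.json")
    if os.path.exists(train_state_file):
        with open(train_state_file, "r", encoding="utf-8") as f:
            current_step["value"] = json.load(f)["current_step"]

accelerator.register_save_state_pre_hook(save_state_hook)
accelerator.register_load_state_pre_hook(load_state_hook)

accelerator.save_state("state_dir")  # also writes state_dir/train_state.json
accelerator.load_state("state_dir")  # restores current_step["value"] before the models load
```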
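Second, the dataloader-skipping primitive that patches 01 and 02 build on: `Accelerator.skip_first_batches()` returns a dataloader that silently drops the first N batches, so a resumed run continues on the same sample order it stopped at. A toy example follows; in `train_network.py` the number of batches to skip is derived from the saved step count, `gradient_accumulation_steps`, and the epoch length, and whole epochs are skipped by advancing the epoch counter instead of iterating them.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator()
dataset = TensorDataset(torch.arange(10).float().unsqueeze(1))
dataloader = accelerator.prepare(DataLoader(dataset, batch_size=2, shuffle=False))

batches_already_seen = 3  # e.g. restored from train_state.json
resumed_dataloader = accelerator.skip_first_batches(dataloader, batches_already_seen)

for batch in resumed_dataloader:
    print(batch)  # starts at the 4th batch: samples 6/7, then 8/9
```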