From 500a202d682cc23d2f3342e5fad4f552986c6ef2 Mon Sep 17 00:00:00 2001
From: alex choi
Date: Fri, 14 Jun 2024 00:27:14 +0000
Subject: [PATCH 1/4] fix deepspeed args in cache_latents

---
 tools/cache_latents.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/tools/cache_latents.py b/tools/cache_latents.py
index 347db27f7..f391684e3 100644
--- a/tools/cache_latents.py
+++ b/tools/cache_latents.py
@@ -12,6 +12,7 @@
 from library import config_util
 from library import train_util
 from library import sdxl_train_util
+from library import deepspeed_utils
 from library.config_util import (
     ConfigSanitizer,
     BlueprintGenerator,
@@ -173,6 +174,7 @@ def setup_parser() -> argparse.ArgumentParser:
     train_util.add_sd_models_arguments(parser)
     train_util.add_training_arguments(parser, True)
     train_util.add_dataset_arguments(parser, True, True, True)
+    deepspeed_utils.add_deepspeed_arguments(parser)
     config_util.add_config_arguments(parser)
     parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する")
     parser.add_argument(

From 99c4371df92b1cf965d8d76bf1296c33b91ef6d4 Mon Sep 17 00:00:00 2001
From: alex choi
Date: Fri, 14 Jun 2024 00:36:26 +0000
Subject: [PATCH 2/4] add deepspeed args to cache_text_encoder_outputs

---
 tools/cache_text_encoder_outputs.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/tools/cache_text_encoder_outputs.py b/tools/cache_text_encoder_outputs.py
index 5f1d6d201..b7d959262 100644
--- a/tools/cache_text_encoder_outputs.py
+++ b/tools/cache_text_encoder_outputs.py
@@ -11,6 +11,7 @@
 
 from library import config_util
 from library import train_util
+from library import deepspeed_utils
 from library import sdxl_train_util
 from library.config_util import (
     ConfigSanitizer,
@@ -175,6 +176,7 @@ def setup_parser() -> argparse.ArgumentParser:
     train_util.add_training_arguments(parser, True)
     train_util.add_dataset_arguments(parser, True, True, True)
     config_util.add_config_arguments(parser)
+    deepspeed_utils.add_deepspeed_arguments(parser)
     sdxl_train_util.add_sdxl_training_arguments(parser)
     parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する")
     parser.add_argument(

From c72a9d0fcae02d604a40e71a0839a0fbf4d78f36 Mon Sep 17 00:00:00 2001
From: alex choi
Date: Fri, 14 Jun 2024 19:53:00 +0000
Subject: [PATCH 3/4] validate text encoder cache instead of just checking file exists.

---
 library/train_util.py | 23 +++++++++++++++++++++--
 1 file changed, 21 insertions(+), 2 deletions(-)

diff --git a/library/train_util.py b/library/train_util.py
index 15c23f3cc..f0e73743b 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -1035,7 +1035,7 @@ def cache_text_encoder_outputs(
         logger.info("caching text encoder outputs.")
         image_infos = list(self.image_data.values())
 
-        logger.info("checking cache existence...")
+        logger.info("checking cache validity...")
         image_infos_to_cache = []
         for info in tqdm(image_infos):
             # subset = self.image_to_subset[info.image_key]
@@ -1046,7 +1046,9 @@
                 if not is_main_process:  # store to info only
                     continue
 
-                if os.path.exists(te_out_npz):
+                is_cache_valid = is_disk_cached_text_encoder_output_valid(te_out_npz)
+
+                if is_cache_valid:
                     continue
 
             image_infos_to_cache.append(info)
@@ -2138,6 +2140,23 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool):
     return True
 
 
+def is_disk_cached_text_encoder_output_valid(npz_path: str):
+
+    if not os.path.exists(npz_path):
+        logger.debug(f'is_disk_cached_text_encoder_output_valid file not found: {npz_path}')
+        return False
+
+    try:
+        hidden_state1, hidden_state2, pool2 = load_text_encoder_outputs_from_disk(npz_path)
+        if hidden_state1 is None or hidden_state2 is None or pool2 is None:
+            logger.debug(f'is_disk_cached_text_encoder_output_valid None value found: hidden_state1 {hidden_state1}, hidden_state2 {hidden_state2}, pool2 {pool2}')
+            return False
+        return True
+    except Exception as e:
+        logger.debug(f"is_disk_cached_text_encoder_output_valid failed to load text encoder outputs from {npz_path}. {e}")
+        return False
+
+
 # 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top)
 def load_latents_from_disk(
     npz_path,

From 6789561e0c81714eaf74acad443e54c5c456611b Mon Sep 17 00:00:00 2001
From: alex choi
Date: Sat, 15 Jun 2024 16:39:35 +0000
Subject: [PATCH 4/4] add load tokenizers from pretrained_model_name_or_path if available

---
 library/sdxl_train_util.py | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py
index a29013e34..be6bb1f68 100644
--- a/library/sdxl_train_util.py
+++ b/library/sdxl_train_util.py
@@ -133,6 +133,20 @@ def _load_target_model(
 def load_tokenizers(args: argparse.Namespace):
     logger.info("prepare tokenizers")
 
+    # load diffusers tokenizers if available
+    name_or_path = args.pretrained_model_name_or_path
+    if os.path.isdir(name_or_path):
+        tokenizer_path = os.path.join(name_or_path, "tokenizer")
+        tokenizer_2_path = os.path.join(name_or_path, "tokenizer_2")
+        if os.path.exists(tokenizer_path) \
+            and os.path.exists(tokenizer_2_path):
+            logger.info(f"load tokenizers from pretrained_model_name_or_path: {name_or_path}")
+            tokeniers = [
+                CLIPTokenizer.from_pretrained(tokenizer_path),
+                CLIPTokenizer.from_pretrained(tokenizer_2_path),
+            ]
+            return tokeniers
+
     original_paths = [TOKENIZER1_PATH, TOKENIZER2_PATH]
     tokeniers = []
     for i, original_path in enumerate(original_paths):