Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate text encoder cache + add deepspeed arg parsing #1372

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions library/sdxl_train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,20 @@ def _load_target_model(
def load_tokenizers(args: argparse.Namespace):
logger.info("prepare tokenizers")

# load diffusers tokenizers if available
name_or_path = args.pretrained_model_name_or_path
if os.path.isdir(name_or_path):
tokenizer_path = os.path.join(name_or_path, "tokenizer")
tokenizer_2_path = os.path.join(name_or_path, "tokenizer_2")
if os.path.exists(tokenizer_path) \
and os.path.exists(tokenizer_2_path):
logger.info(f"load tokenizers from pretrained_model_name_or_path: {name_or_path}")
tokeniers = [
CLIPTokenizer.from_pretrained(tokenizer_path),
CLIPTokenizer.from_pretrained(tokenizer_2_path),
]
return tokeniers

original_paths = [TOKENIZER1_PATH, TOKENIZER2_PATH]
tokeniers = []
for i, original_path in enumerate(original_paths):
Expand Down
23 changes: 21 additions & 2 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -1035,7 +1035,7 @@ def cache_text_encoder_outputs(
logger.info("caching text encoder outputs.")
image_infos = list(self.image_data.values())

logger.info("checking cache existence...")
logger.info("checking cache validity...")
image_infos_to_cache = []
for info in tqdm(image_infos):
# subset = self.image_to_subset[info.image_key]
Expand All @@ -1046,7 +1046,9 @@ def cache_text_encoder_outputs(
if not is_main_process: # store to info only
continue

if os.path.exists(te_out_npz):
is_cache_valid = is_disk_cached_text_encoder_output_valid(te_out_npz)

if is_cache_valid:
continue

image_infos_to_cache.append(info)
Expand Down Expand Up @@ -2138,6 +2140,23 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool):
return True


def is_disk_cached_text_encoder_output_valid(npz_path: str):
    """Check whether an on-disk text encoder output cache entry is usable.

    A cache file is considered valid only when it exists, loads without
    raising, and all three cached components (hidden_state1, hidden_state2,
    pool2) are present (i.e. none of them is None).

    Args:
        npz_path: Path to the cached ``.npz`` file for one image/caption.

    Returns:
        True if the cache entry can be used as-is, False if it is missing,
        unreadable, or incomplete and must be regenerated.
    """
    # Missing file: nothing to validate, the cache must be rebuilt.
    if not os.path.exists(npz_path):
        logger.debug(f'is_disk_cached_text_encoder_output_valid file not found: {npz_path}')
        return False

    try:
        hidden_state1, hidden_state2, pool2 = load_text_encoder_outputs_from_disk(npz_path)
        # A partially-written or truncated cache can load but hold None entries.
        if any(component is None for component in (hidden_state1, hidden_state2, pool2)):
            logger.debug(f'is_disk_cached_text_encoder_output_valid None value found: hidden_state1 {hidden_state1}, hidden_state2 {hidden_state2}, pool2 {pool2}')
            return False
        return True
    except Exception as e:
        # Any load failure (corrupt npz, key errors, ...) just means "invalid";
        # the caller will re-cache, so log at debug level and report False.
        logger.debug(f"is_disk_cached_text_encoder_output_valid failed to load text encoder outputs from {npz_path}. {e}")
        return False


# 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top)
def load_latents_from_disk(
npz_path,
Expand Down
2 changes: 2 additions & 0 deletions tools/cache_latents.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from library import config_util
from library import train_util
from library import sdxl_train_util
from library import deepspeed_utils
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
Expand Down Expand Up @@ -173,6 +174,7 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_sd_models_arguments(parser)
train_util.add_training_arguments(parser, True)
train_util.add_dataset_arguments(parser, True, True, True)
deepspeed_utils.add_deepspeed_arguments(parser)
config_util.add_config_arguments(parser)
parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する")
parser.add_argument(
Expand Down
2 changes: 2 additions & 0 deletions tools/cache_text_encoder_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from library import config_util
from library import train_util
from library import deepspeed_utils
from library import sdxl_train_util
from library.config_util import (
ConfigSanitizer,
Expand Down Expand Up @@ -175,6 +176,7 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_training_arguments(parser, True)
train_util.add_dataset_arguments(parser, True, True, True)
config_util.add_config_arguments(parser)
deepspeed_utils.add_deepspeed_arguments(parser)
sdxl_train_util.add_sdxl_training_arguments(parser)
parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する")
parser.add_argument(
Expand Down