From 19e38dfab03a3a9947ee52a58b5b9712e5037f5c Mon Sep 17 00:00:00 2001 From: alex choi Date: Sat, 15 Jun 2024 16:39:35 +0000 Subject: [PATCH] add load tokenizers from pretrained_model_name_or_path if available --- library/sdxl_train_util.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index a29013e34..be6bb1f68 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -133,6 +133,20 @@ def _load_target_model( def load_tokenizers(args: argparse.Namespace): logger.info("prepare tokenizers") + # load diffusers tokenizers if available + name_or_path = args.pretrained_model_name_or_path + if os.path.isdir(name_or_path): + tokenizer_path = os.path.join(name_or_path, "tokenizer") + tokenizer_2_path = os.path.join(name_or_path, "tokenizer_2") + if os.path.exists(tokenizer_path) \ + and os.path.exists(tokenizer_2_path): + logger.info(f"load tokenizers from pretrained_model_name_or_path: {name_or_path}") + tokeniers = [ + CLIPTokenizer.from_pretrained(tokenizer_path), + CLIPTokenizer.from_pretrained(tokenizer_2_path), + ] + return tokeniers + original_paths = [TOKENIZER1_PATH, TOKENIZER2_PATH] tokeniers = [] for i, original_path in enumerate(original_paths):