diff --git a/library/lpw_stable_diffusion.py b/library/lpw_stable_diffusion.py
index 9dce91a76..3963e9b15 100644
--- a/library/lpw_stable_diffusion.py
+++ b/library/lpw_stable_diffusion.py
@@ -9,7 +9,7 @@
 import PIL.Image
 import torch
 from packaging import version
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
 
 import diffusers
 from diffusers import SchedulerMixin, StableDiffusionPipeline
@@ -520,6 +520,7 @@ def __init__(
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPFeatureExtractor,
         requires_safety_checker: bool = True,
+        image_encoder: CLIPVisionModelWithProjection = None,
         clip_skip: int = 1,
     ):
         super().__init__(
@@ -531,32 +532,11 @@ def __init__(
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
             requires_safety_checker=requires_safety_checker,
+            image_encoder=image_encoder,
         )
-        self.clip_skip = clip_skip
+        self.custom_clip_skip = clip_skip
         self.__init__additional__()
 
-    # else:
-    #     def __init__(
-    #         self,
-    #         vae: AutoencoderKL,
-    #         text_encoder: CLIPTextModel,
-    #         tokenizer: CLIPTokenizer,
-    #         unet: UNet2DConditionModel,
-    #         scheduler: SchedulerMixin,
-    #         safety_checker: StableDiffusionSafetyChecker,
-    #         feature_extractor: CLIPFeatureExtractor,
-    #     ):
-    #         super().__init__(
-    #             vae=vae,
-    #             text_encoder=text_encoder,
-    #             tokenizer=tokenizer,
-    #             unet=unet,
-    #             scheduler=scheduler,
-    #             safety_checker=safety_checker,
-    #             feature_extractor=feature_extractor,
-    #         )
-    #         self.__init__additional__()
-
     def __init__additional__(self):
         if not hasattr(self, "vae_scale_factor"):
             setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1))
@@ -624,7 +604,7 @@ def _encode_prompt(
             prompt=prompt,
             uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
             max_embeddings_multiples=max_embeddings_multiples,
-            clip_skip=self.clip_skip,
+            clip_skip=self.custom_clip_skip,
         )
         bs_embed, seq_len, _ = text_embeddings.shape
         text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
diff --git a/library/model_util.py b/library/model_util.py
index a577b97d4..6102d0a18 100644
--- a/library/model_util.py
+++ b/library/model_util.py
@@ -4,10 +4,13 @@
 import math
 import os
 import torch
+
 try:
     import intel_extension_for_pytorch as ipex
+
     if torch.xpu.is_available():
         from library.ipex import ipex_init
+
         ipex_init()
 except Exception:
     pass
@@ -571,9 +574,9 @@ def convert_ldm_clip_checkpoint_v1(checkpoint):
         if key.startswith("cond_stage_model.transformer"):
             text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
 
-    # support checkpoint without position_ids (invalid checkpoint)
-    if "text_model.embeddings.position_ids" not in text_model_dict:
-        text_model_dict["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0)  # 77 is the max length of the text
+    # remove position_ids for newer transformer, which causes error :(
+    if "text_model.embeddings.position_ids" in text_model_dict:
+        text_model_dict.pop("text_model.embeddings.position_ids")
 
     return text_model_dict
diff --git a/library/sdxl_model_util.py b/library/sdxl_model_util.py
index a844927cd..08b90c393 100644
--- a/library/sdxl_model_util.py
+++ b/library/sdxl_model_util.py
@@ -100,7 +100,7 @@ def convert_key(key):
             key = key.replace(".ln_final", ".final_layer_norm")
         # ckpt from comfy has this key: text_model.encoder.text_model.embeddings.position_ids
         elif ".embeddings.position_ids" in key:
-            key = None  # remove this key: make position_ids by ourselves
+            key = None  # remove this key: position_ids is not used in newer transformers
         return key
 
     keys = list(checkpoint.keys())
@@ -126,10 +126,6 @@ def convert_key(key):
             new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
             new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
 
-    # original SD にはないので、position_idsを追加
-    position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
-    new_sd["text_model.embeddings.position_ids"] = position_ids
-
     # logit_scale はDiffusersには含まれないが、保存時に戻したいので別途返す
     logit_scale = checkpoint.get(SDXL_KEY_PREFIX + "logit_scale", None)
@@ -265,9 +261,9 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty
         elif k.startswith("conditioner.embedders.1.model."):
             te2_sd[k] = state_dict.pop(k)
 
-    # 一部のposition_idsがないモデルへの対応 / add position_ids for some models
-    if "text_model.embeddings.position_ids" not in te1_sd:
-        te1_sd["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0)
+    # 最新の transformers では position_ids を含むとエラーになるので削除 / remove position_ids for latest transformers
+    if "text_model.embeddings.position_ids" in te1_sd:
+        te1_sd.pop("text_model.embeddings.position_ids")
 
     info1 = _load_state_dict_on_device(text_model1, te1_sd, device=map_location)  # remain fp32
     print("text encoder 1:", info1)
diff --git a/requirements.txt b/requirements.txt
index 0a80d70d7..8517d95ac 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,6 @@
-accelerate==0.23.0
-transformers==4.30.2
-diffusers[torch]==0.21.2
+accelerate==0.25.0
+transformers==4.36.2
+diffusers[torch]==0.25.0
 ftfy==6.1.1
 # albumentations==1.3.0
 opencv-python==4.7.0.68
@@ -14,7 +14,7 @@ altair==4.2.2
 easygui==0.98.3
 toml==0.10.2
 voluptuous==0.13.1
-huggingface-hub==0.15.1
+huggingface-hub==0.20.1
 # for BLIP captioning
 # requests==2.28.2
 # timm==0.6.12