Skip to content

Commit

Permalink
Merge branch 'main' into dreambooth-lora-flux-exploration
Browse files Browse the repository at this point in the history
  • Loading branch information
linoytsaban authored Oct 1, 2024
2 parents 57fb65b + 33fafe3 commit 4faa6cf
Show file tree
Hide file tree
Showing 23 changed files with 2,426 additions and 31 deletions.
3 changes: 3 additions & 0 deletions docs/source/en/api/pipelines/pag.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ Since RegEx is supported as a way for matching layer identifiers, it is crucial

## StableDiffusionControlNetPAGPipeline
[[autodoc]] StableDiffusionControlNetPAGPipeline

## StableDiffusionControlNetPAGInpaintPipeline
[[autodoc]] StableDiffusionControlNetPAGInpaintPipeline
- all
- __call__

Expand Down
1 change: 1 addition & 0 deletions docs/source/en/api/schedulers/overview.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ Many schedulers are implemented from the [k-diffusion](https://github.com/crowso
| sgm_uniform | init with `timestep_spacing="trailing"` |
| simple | init with `timestep_spacing="trailing"` |
| exponential | init with `timestep_spacing="linspace"`, `use_exponential_sigmas=True` |
| beta | init with `timestep_spacing="linspace"`, `use_beta_sigmas=True` |

All schedulers are built from the base [`SchedulerMixin`] class which implements low level utilities shared by all schedulers.

Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,7 @@
"StableDiffusionAttendAndExcitePipeline",
"StableDiffusionControlNetImg2ImgPipeline",
"StableDiffusionControlNetInpaintPipeline",
"StableDiffusionControlNetPAGInpaintPipeline",
"StableDiffusionControlNetPAGPipeline",
"StableDiffusionControlNetPipeline",
"StableDiffusionControlNetXSPipeline",
Expand Down Expand Up @@ -778,6 +779,7 @@
StableDiffusionAttendAndExcitePipeline,
StableDiffusionControlNetImg2ImgPipeline,
StableDiffusionControlNetInpaintPipeline,
StableDiffusionControlNetPAGInpaintPipeline,
StableDiffusionControlNetPAGPipeline,
StableDiffusionControlNetPipeline,
StableDiffusionControlNetXSPipeline,
Expand Down
41 changes: 39 additions & 2 deletions src/diffusers/loaders/lora_conversion_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,10 +516,47 @@ def _convert_sd_scripts_to_ai_toolkit(sds_sd):
f"transformer.single_transformer_blocks.{i}.norm.linear",
)

remaining_keys = list(sds_sd.keys())
te_state_dict = {}
if remaining_keys:
if not all(k.startswith("lora_te1") for k in remaining_keys):
raise ValueError(f"Incompatible keys detected: \n\n {', '.join(remaining_keys)}")
for key in remaining_keys:
if not key.endswith("lora_down.weight"):
continue

lora_name = key.split(".")[0]
lora_name_up = f"{lora_name}.lora_up.weight"
lora_name_alpha = f"{lora_name}.alpha"
diffusers_name = _convert_text_encoder_lora_key(key, lora_name)

if lora_name.startswith(("lora_te_", "lora_te1_")):
down_weight = sds_sd.pop(key)
sd_lora_rank = down_weight.shape[0]
te_state_dict[diffusers_name] = down_weight
te_state_dict[diffusers_name.replace(".down.", ".up.")] = sds_sd.pop(lora_name_up)

if lora_name_alpha in sds_sd:
alpha = sds_sd.pop(lora_name_alpha).item()
scale = alpha / sd_lora_rank

scale_down = scale
scale_up = 1.0
while scale_down * 2 < scale_up:
scale_down *= 2
scale_up /= 2

te_state_dict[diffusers_name] *= scale_down
te_state_dict[diffusers_name.replace(".down.", ".up.")] *= scale_up

if len(sds_sd) > 0:
logger.warning(f"Unsuppored keys for ai-toolkit: {sds_sd.keys()}")
logger.warning(f"Unsupported keys for ai-toolkit: {sds_sd.keys()}")

if te_state_dict:
te_state_dict = {f"text_encoder.{module_name}": params for module_name, params in te_state_dict.items()}

return ait_sd
new_state_dict = {**ait_sd, **te_state_dict}
return new_state_dict

return _convert_sd_scripts_to_ai_toolkit(state_dict)

Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@
)
_import_structure["pag"].extend(
[
"StableDiffusionControlNetPAGInpaintPipeline",
"AnimateDiffPAGPipeline",
"KolorsPAGPipeline",
"HunyuanDiTPAGPipeline",
Expand Down Expand Up @@ -566,6 +567,7 @@
KolorsPAGPipeline,
PixArtSigmaPAGPipeline,
StableDiffusion3PAGPipeline,
StableDiffusionControlNetPAGInpaintPipeline,
StableDiffusionControlNetPAGPipeline,
StableDiffusionPAGPipeline,
StableDiffusionXLControlNetPAGImg2ImgPipeline,
Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/pipelines/auto_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
HunyuanDiTPAGPipeline,
PixArtSigmaPAGPipeline,
StableDiffusion3PAGPipeline,
StableDiffusionControlNetPAGInpaintPipeline,
StableDiffusionControlNetPAGPipeline,
StableDiffusionPAGPipeline,
StableDiffusionXLControlNetPAGImg2ImgPipeline,
Expand Down Expand Up @@ -148,6 +149,7 @@
("kandinsky", KandinskyInpaintCombinedPipeline),
("kandinsky22", KandinskyV22InpaintCombinedPipeline),
("stable-diffusion-controlnet", StableDiffusionControlNetInpaintPipeline),
("stable-diffusion-controlnet-pag", StableDiffusionControlNetPAGInpaintPipeline),
("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetInpaintPipeline),
("stable-diffusion-xl-pag", StableDiffusionXLPAGInpaintPipeline),
("flux", FluxInpaintPipeline),
Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/pipelines/pag/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_pag_controlnet_sd"] = ["StableDiffusionControlNetPAGPipeline"]
_import_structure["pipeline_pag_controlnet_sd_inpaint"] = ["StableDiffusionControlNetPAGInpaintPipeline"]
_import_structure["pipeline_pag_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPAGPipeline"]
_import_structure["pipeline_pag_controlnet_sd_xl_img2img"] = ["StableDiffusionXLControlNetPAGImg2ImgPipeline"]
_import_structure["pipeline_pag_hunyuandit"] = ["HunyuanDiTPAGPipeline"]
Expand All @@ -44,6 +45,7 @@
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_pag_controlnet_sd import StableDiffusionControlNetPAGPipeline
from .pipeline_pag_controlnet_sd_inpaint import StableDiffusionControlNetPAGInpaintPipeline
from .pipeline_pag_controlnet_sd_xl import StableDiffusionXLControlNetPAGPipeline
from .pipeline_pag_controlnet_sd_xl_img2img import StableDiffusionXLControlNetPAGImg2ImgPipeline
from .pipeline_pag_hunyuandit import HunyuanDiTPAGPipeline
Expand Down
Loading

0 comments on commit 4faa6cf

Please sign in to comment.