[Core] enable lora for sdxl adapters too and add slow tests. (huggingface#5555)

* Enable lora for sdxl adapters too.

Issue huggingface#5516
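
For context, a minimal sketch (not part of the commit) of the workflow this enables, reusing the checkpoints from the new slow tests:

import torch
from diffusers import StableDiffusionXLAdapterPipeline, T2IAdapter

# Adapter and base model taken from the slow tests added below.
adapter = T2IAdapter.from_pretrained("TencentARC/t2i-adapter-lineart-sdxl-1.0", torch_dtype=torch.float16)
pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", adapter=adapter, torch_dtype=torch.float16, variant="fp16"
)
# The call this commit makes work for the adapter pipeline: LoRA layers are routed
# to the UNet and to both SDXL text encoders.
pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors")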

* fix: assertion values.

* Use numpy_cosine_similarity_distance on the arrays

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* Use numpy_cosine_similarity_distance on the arrays

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
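
For reference, the idea behind such a helper is to compare two flattened image slices by direction rather than element-wise tolerance; a sketch (the real numpy_cosine_similarity_distance lives in diffusers.utils.testing_utils and may differ in detail):

import numpy as np

def cosine_similarity_distance(a: np.ndarray, b: np.ndarray) -> float:
    # 0.0 means the two slices are perfectly aligned; small values tolerate minor numerical noise.
    similarity = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
    return float(1.0 - similarity)

print(cosine_similarity_distance(np.ones(9), np.ones(9)))  # -> 0.0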

* Changed imports orders to pass tests

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

---------

Co-authored-by: Ilias A <iliasamri00@gmail.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
4 people authored and kashif committed Nov 11, 2023
1 parent 96c1b9e commit 86dd16c
Showing 2 changed files with 142 additions and 1 deletion.
@@ -13,6 +13,7 @@
# limitations under the License.

import inspect
import os
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
@@ -1066,3 +1067,77 @@ def __call__(
            return (image,)

        return StableDiffusionXLPipelineOutput(images=image)


    # Override to properly handle the loading and unloading of the additional text encoder.
    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.load_lora_weights
    def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
        # We could have accessed the unet config from `lora_state_dict()` too. We pass
        # it here explicitly to be able to tell that it's coming from an SDXL
        # pipeline.
        state_dict, network_alphas = self.lora_state_dict(
            pretrained_model_name_or_path_or_dict,
            unet_config=self.unet.config,
            **kwargs,
        )
        self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)

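        # SDXL has two text encoders; split the remaining LoRA layers by their
        # "text_encoder." / "text_encoder_2." prefixes and load each group separately.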
        text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
        if len(text_encoder_state_dict) > 0:
            self.load_lora_into_text_encoder(
                text_encoder_state_dict,
                network_alphas=network_alphas,
                text_encoder=self.text_encoder,
                prefix="text_encoder",
                lora_scale=self.lora_scale,
            )

        text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
        if len(text_encoder_2_state_dict) > 0:
            self.load_lora_into_text_encoder(
                text_encoder_2_state_dict,
                network_alphas=network_alphas,
                text_encoder=self.text_encoder_2,
                prefix="text_encoder_2",
                lora_scale=self.lora_scale,
            )

    @classmethod
    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights
    def save_lora_weights(
        self,
        save_directory: Union[str, os.PathLike],
        unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        is_main_process: bool = True,
        weight_name: str = None,
        save_function: Callable = None,
        safe_serialization: bool = True,
    ):
        state_dict = {}

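        # Prefix every parameter name (e.g. "unet.", "text_encoder.") so the LoRA layers of the
        # different sub-models can live together in one flat state dict.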
        def pack_weights(layers, prefix):
            layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
            layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
            return layers_state_dict

        state_dict.update(pack_weights(unet_lora_layers, "unet"))

        if text_encoder_lora_layers and text_encoder_2_lora_layers:
            state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
            state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))

        self.write_lora_layers(
            state_dict=state_dict,
            save_directory=save_directory,
            is_main_process=is_main_process,
            weight_name=weight_name,
            save_function=save_function,
            safe_serialization=safe_serialization,
        )

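    # Unloading LoRA must undo the monkey patch on both text encoders, not just the first one.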
    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._remove_text_encoder_monkey_patch
    def _remove_text_encoder_monkey_patch(self):
        self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
        self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
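
The override above exists because SDXL carries a second text encoder (`text_encoder_2`): `load_lora_weights` splits the LoRA state dict by the `text_encoder.` / `text_encoder_2.` prefixes and routes each group to its encoder alongside the UNet weights, while `save_lora_weights` packs the layers back under the same prefixes so a single LoRA file round-trips cleanly.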
@@ -14,6 +14,7 @@
# limitations under the License.

import gc
import random
import unittest

import numpy as np
@@ -29,10 +30,14 @@
    EulerAncestralDiscreteScheduler,
    StableDiffusionXLAdapterPipeline,
    T2IAdapter,
    UNet2DConditionModel,
)
from diffusers.utils import load_image, logging
from diffusers.utils.testing_utils import (
    enable_full_determinism,
    floats_tensor,
    numpy_cosine_similarity_distance,
    require_torch_gpu,
    slow,
    torch_device,
)
from diffusers.utils.torch_utils import randn_tensor

from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
from ..test_pipelines_common import (
PipelineTesterMixin,
@@ -560,3 +565,64 @@ def test_inference_batch_single_identical(

        if test_mean_pixel_difference:
            assert_mean_pixel_difference(output_batch[0][0], output[0][0])


@slow
@require_torch_gpu
class AdapterSDXLPipelineSlowTests(unittest.TestCase):
    def tearDown(self):
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def test_canny(self):
        adapter = T2IAdapter.from_pretrained(
            "TencentARC/t2i-adapter-lineart-sdxl-1.0", torch_dtype=torch.float16
        ).to("cpu")
        pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            adapter=adapter,
            torch_dtype=torch.float16,
            variant="fp16",
        )
        pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors")
        pipe.enable_sequential_cpu_offload()
        pipe.set_progress_bar_config(disable=None)

        generator = torch.Generator(device="cpu").manual_seed(0)
        prompt = "toy"
        image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/toy_canny.png"
        )

        images = pipe(prompt, image=image, generator=generator, output_type="np", num_inference_steps=3).images

        assert images[0].shape == (768, 512, 3)

        original_image = images[0, -3:, -3:, -1].flatten()
        assert numpy_cosine_similarity_distance(original_image, expected_image) < 1e-4

    def test_canny_lora(self):
        adapter = T2IAdapter.from_pretrained(
            "TencentARC/t2i-adapter-lineart-sdxl-1.0", torch_dtype=torch.float16
        ).to("cpu")
        pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            adapter=adapter,
            torch_dtype=torch.float16,
            variant="fp16",
        )
        pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors")
        pipe.enable_sequential_cpu_offload()
        pipe.set_progress_bar_config(disable=None)

        generator = torch.Generator(device="cpu").manual_seed(0)
        prompt = "toy"
        image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/toy_canny.png"
        )

        images = pipe(prompt, image=image, generator=generator, output_type="np", num_inference_steps=3).images

        assert images[0].shape == (768, 512, 3)

        original_image = images[0, -3:, -3:, -1].flatten()
        expected_image = np.array(
            [0.50346327, 0.50708383, 0.50719553, 0.5135172, 0.5155377, 0.5066059, 0.49680984, 0.5005894, 0.48509413]
        )
        assert numpy_cosine_similarity_distance(original_image, expected_image) < 1e-4

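Both new test cases are gated behind @slow and @require_torch_gpu, so they stay out of the default CI run; as is usual in the diffusers test suite, they only execute when the RUN_SLOW environment variable is set and a CUDA device is available.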