Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core] enable lora for sdxl adapters too and add slow tests. #5555

Merged
merged 6 commits into from
Nov 1, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import inspect
import os
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
Expand Down Expand Up @@ -1067,3 +1068,77 @@ def __call__(
return (image,)

return StableDiffusionXLPipelineOutput(images=image)


# Overrride to properly handle the loading and unloading of the additional text encoder.
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.load_lora_weights
def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
# We could have accessed the unet config from `lora_state_dict()` too. We pass
# it here explicitly to be able to tell that it's coming from an SDXL
# pipeline.
state_dict, network_alphas = self.lora_state_dict(
pretrained_model_name_or_path_or_dict,
unet_config=self.unet.config,
**kwargs,
)
self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)

text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
if len(text_encoder_state_dict) > 0:
self.load_lora_into_text_encoder(
text_encoder_state_dict,
network_alphas=network_alphas,
text_encoder=self.text_encoder,
prefix="text_encoder",
lora_scale=self.lora_scale,
)

text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
if len(text_encoder_2_state_dict) > 0:
self.load_lora_into_text_encoder(
text_encoder_2_state_dict,
network_alphas=network_alphas,
text_encoder=self.text_encoder_2,
prefix="text_encoder_2",
lora_scale=self.lora_scale,
)

@classmethod
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights
def save_lora_weights(
self,
save_directory: Union[str, os.PathLike],
unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
):
state_dict = {}

def pack_weights(layers, prefix):
layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
return layers_state_dict

state_dict.update(pack_weights(unet_lora_layers, "unet"))

if text_encoder_lora_layers and text_encoder_2_lora_layers:
state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))

self.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
)

# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._remove_text_encoder_monkey_patch
def _remove_text_encoder_monkey_patch(self):
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.

import random
import gc
import unittest

import numpy as np
Expand All @@ -29,10 +30,13 @@
StableDiffusionXLAdapterPipeline,
T2IAdapter,
UNet2DConditionModel,
EulerAncestralDiscreteScheduler,

)
from diffusers.utils import logging
from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, torch_device

from diffusers.utils import load_image, randn_tensor, torch_device
from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, slow
ilisparrow marked this conversation as resolved.
Show resolved Hide resolved
from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
from ..test_pipelines_common import (
PipelineTesterMixin,
Expand Down Expand Up @@ -560,3 +564,64 @@ def test_inference_batch_single_identical(

if test_mean_pixel_difference:
assert_mean_pixel_difference(output_batch[0][0], output[0][0])


@slow
@require_torch_gpu
class AdapterSDXLPipelineSlowTests(unittest.TestCase):
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()

def test_canny(self):
adapter = T2IAdapter.from_pretrained(
"TencentARC/t2i-adapter-lineart-sdxl-1.0", torch_dtype=torch.float16
).to("cpu")
pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
'stabilityai/stable-diffusion-xl-base-1.0', adapter=adapter, torch_dtype=torch.float16, variant="fp16",
)
pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors")
pipe.enable_sequential_cpu_offload()
pipe.set_progress_bar_config(disable=None)

generator = torch.Generator(device="cpu").manual_seed(0)
prompt = "toy"
image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/toy_canny.png"
)

images = pipe(prompt, image=image, generator=generator, output_type="np", num_inference_steps=3).images

assert images[0].shape == (768, 512, 3)

original_image = images[0, -3:, -3:, -1].flatten()
assert numpy_cosine_similarity_distance(original_image, expected_image) < 1e-4
assert np.allclose(original_image, expected_image, atol=1e-04)


def test_canny_lora(self):
adapter = T2IAdapter.from_pretrained(
"TencentARC/t2i-adapter-lineart-sdxl-1.0", torch_dtype=torch.float16
).to("cpu")
pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
'stabilityai/stable-diffusion-xl-base-1.0', adapter=adapter, torch_dtype=torch.float16, variant="fp16",
)
pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors")
pipe.enable_sequential_cpu_offload()
pipe.set_progress_bar_config(disable=None)

generator = torch.Generator(device="cpu").manual_seed(0)
prompt = "toy"
image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/toy_canny.png"
)

images = pipe(prompt, image=image, generator=generator, output_type="np", num_inference_steps=3).images

assert images[0].shape == (768, 512, 3)

original_image = images[0, -3:, -3:, -1].flatten()
expected_image = np.array([0.50346327, 0.50708383, 0.50719553, 0.5135172, 0.5155377, 0.5066059, 0.49680984, 0.5005894, 0.48509413])
assert numpy_cosine_similarity_distance(original_image, expected_image) < 1e-4

Loading