From deb5e930ae1d873a4648c36ce54a4d72194a3529 Mon Sep 17 00:00:00 2001 From: Pierre Chapuis Date: Mon, 29 Jan 2024 10:49:39 +0100 Subject: [PATCH] fix exclusions for Downsample and Upsample (also simplify list comprehension for exclusion list) --- src/refiners/foundationals/latent_diffusion/lora.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/refiners/foundationals/latent_diffusion/lora.py b/src/refiners/foundationals/latent_diffusion/lora.py index a8812847e..5953a287c 100644 --- a/src/refiners/foundationals/latent_diffusion/lora.py +++ b/src/refiners/foundationals/latent_diffusion/lora.py @@ -67,11 +67,8 @@ def add_loras_to_text_encoder(self, loras: dict[str, Lora], /) -> None: def add_loras_to_unet(self, loras: dict[str, Lora], /) -> None: unet_loras = {key: loras[key] for key in loras.keys() if "unet" in key} - exclude: list[str] = [] exclude = [ - self.unet_exclusions[exclusion] - for exclusion in self.unet_exclusions - if all([exclusion not in key for key in unet_loras.keys()]) + block for s, block in self.unet_exclusions.items() if all([s not in key for key in unet_loras.keys()]) ] SDLoraManager.auto_attach(unet_loras, self.unet, exclude=exclude) @@ -121,8 +118,8 @@ def unet_exclusions(self) -> dict[str, str]: return { "time": "TimestepEncoder", "res": "ResidualBlock", - "downsample": "DownsampleBlock", - "upsample": "UpsampleBlock", + "downsample": "Downsample", + "upsample": "Upsample", } @property