
Commit

Merge branch 'main' into inpaint-docs
sayakpaul authored Sep 29, 2023
2 parents 389dee5 + 78a7851 commit dbf63b1
Showing 25 changed files with 97 additions and 51 deletions.
2 changes: 1 addition & 1 deletion examples/t2i_adapter/README_sdxl.md
@@ -20,7 +20,7 @@ pip install -e .

Then cd in the `examples/t2i_adapter` folder and run
```bash
-pip install -r requirements_sdxl.txt
+pip install -r requirements.txt
```

And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
2 changes: 1 addition & 1 deletion src/diffusers/loaders.py
@@ -2165,7 +2165,7 @@ def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True
self.unet.unfuse_lora()

if self.use_peft_backend:
-from peft.tuners.tuner_utils import BaseTunerLayer
+from peft.tuners.tuners_utils import BaseTunerLayer

def unfuse_text_encoder_lora(text_encoder):
for module in text_encoder.modules():
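The import fixed here sits on the PEFT-backend path of `unfuse_lora`. For orientation, a minimal sketch of the fuse/unfuse workflow that exercises it — the checkpoint and LoRA repository ids are placeholders, not part of this diff:

```python
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("some-user/some-sdxl-lora")  # placeholder LoRA repository

pipe.fuse_lora()    # merge LoRA weights into the base modules for faster inference
image = pipe("an astronaut riding a horse").images[0]
pipe.unfuse_lora()  # restore the original weights; with the PEFT backend this runs the import above
```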
25 changes: 23 additions & 2 deletions src/diffusers/models/attention_flax.py
@@ -258,6 +258,9 @@ class FlaxBasicTransformerBlock(nn.Module):
Parameters `dtype`
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
enable memory efficient attention https://arxiv.org/abs/2112.05682
split_head_dim (`bool`, *optional*, defaults to `False`):
Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
"""
dim: int
n_heads: int
@@ -266,15 +269,28 @@ class FlaxBasicTransformerBlock(nn.Module):
only_cross_attention: bool = False
dtype: jnp.dtype = jnp.float32
use_memory_efficient_attention: bool = False
split_head_dim: bool = False

def setup(self):
# self attention (or cross_attention if only_cross_attention is True)
self.attn1 = FlaxAttention(
-self.dim, self.n_heads, self.d_head, self.dropout, self.use_memory_efficient_attention, dtype=self.dtype
+self.dim,
+self.n_heads,
+self.d_head,
+self.dropout,
+self.use_memory_efficient_attention,
+self.split_head_dim,
+dtype=self.dtype,
)
# cross attention
self.attn2 = FlaxAttention(
-self.dim, self.n_heads, self.d_head, self.dropout, self.use_memory_efficient_attention, dtype=self.dtype
+self.dim,
+self.n_heads,
+self.d_head,
+self.dropout,
+self.use_memory_efficient_attention,
+self.split_head_dim,
+dtype=self.dtype,
)
self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
@@ -327,6 +343,9 @@ class FlaxTransformer2DModel(nn.Module):
Parameters `dtype`
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
enable memory efficient attention https://arxiv.org/abs/2112.05682
split_head_dim (`bool`, *optional*, defaults to `False`):
Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
"""
in_channels: int
n_heads: int
@@ -337,6 +356,7 @@ class FlaxTransformer2DModel(nn.Module):
only_cross_attention: bool = False
dtype: jnp.dtype = jnp.float32
use_memory_efficient_attention: bool = False
split_head_dim: bool = False

def setup(self):
self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)
@@ -362,6 +382,7 @@ def setup(self):
only_cross_attention=self.only_cross_attention,
dtype=self.dtype,
use_memory_efficient_attention=self.use_memory_efficient_attention,
split_head_dim=self.split_head_dim,
)
for _ in range(self.depth)
]
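For context on the new flag: instead of folding the attention heads into the batch axis, `split_head_dim` keeps them as their own axis of the query/key/value tensors. The snippet below only illustrates the two layouts and is not the library's actual attention kernel:

```python
import jax
import jax.numpy as jnp

batch, seq, heads, head_dim = 2, 64, 8, 40
x = jax.random.normal(jax.random.PRNGKey(0), (batch, seq, heads * head_dim))
x4 = x.reshape(batch, seq, heads, head_dim)

# Layout A: heads folded into the batch axis -> (batch * heads, seq, head_dim)
folded = jnp.transpose(x4, (0, 2, 1, 3)).reshape(batch * heads, seq, head_dim)

# Layout B ("split head dim"): heads kept as a separate axis -> (batch, seq, heads, head_dim)
scores = jnp.einsum("bqhd,bkhd->bhqk", x4, x4)  # per-head attention logits in a single einsum

print(folded.shape, x4.shape, scores.shape)
```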
15 changes: 15 additions & 0 deletions src/diffusers/models/unet_2d_blocks_flax.py
@@ -39,6 +39,9 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
Whether to add downsampling layer before each final output
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
enable memory efficient attention https://arxiv.org/abs/2112.05682
split_head_dim (`bool`, *optional*, defaults to `False`):
Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
@@ -51,6 +54,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
use_linear_projection: bool = False
only_cross_attention: bool = False
use_memory_efficient_attention: bool = False
split_head_dim: bool = False
dtype: jnp.dtype = jnp.float32
transformer_layers_per_block: int = 1

@@ -77,6 +81,7 @@ def setup(self):
use_linear_projection=self.use_linear_projection,
only_cross_attention=self.only_cross_attention,
use_memory_efficient_attention=self.use_memory_efficient_attention,
split_head_dim=self.split_head_dim,
dtype=self.dtype,
)
attentions.append(attn_block)
@@ -179,6 +184,9 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
Whether to add upsampling layer before each final output
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
enable memory efficient attention https://arxiv.org/abs/2112.05682
split_head_dim (`bool`, *optional*, defaults to `False`):
Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
@@ -192,6 +200,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
use_linear_projection: bool = False
only_cross_attention: bool = False
use_memory_efficient_attention: bool = False
split_head_dim: bool = False
dtype: jnp.dtype = jnp.float32
transformer_layers_per_block: int = 1

@@ -219,6 +228,7 @@ def setup(self):
use_linear_projection=self.use_linear_projection,
only_cross_attention=self.only_cross_attention,
use_memory_efficient_attention=self.use_memory_efficient_attention,
split_head_dim=self.split_head_dim,
dtype=self.dtype,
)
attentions.append(attn_block)
@@ -323,6 +333,9 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
Number of attention heads of each spatial transformer block
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
enable memory efficient attention https://arxiv.org/abs/2112.05682
split_head_dim (`bool`, *optional*, defaults to `False`):
Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
@@ -332,6 +345,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
num_attention_heads: int = 1
use_linear_projection: bool = False
use_memory_efficient_attention: bool = False
split_head_dim: bool = False
dtype: jnp.dtype = jnp.float32
transformer_layers_per_block: int = 1

@@ -356,6 +370,7 @@ def setup(self):
depth=self.transformer_layers_per_block,
use_linear_projection=self.use_linear_projection,
use_memory_efficient_attention=self.use_memory_efficient_attention,
split_head_dim=self.split_head_dim,
dtype=self.dtype,
)
attentions.append(attn_block)
7 changes: 7 additions & 0 deletions src/diffusers/models/unet_2d_condition_flax.py
@@ -92,6 +92,9 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
Enable memory efficient attention as described [here](https://arxiv.org/abs/2112.05682).
split_head_dim (`bool`, *optional*, defaults to `False`):
Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
"""

sample_size: int = 32
@@ -116,6 +119,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
flip_sin_to_cos: bool = True
freq_shift: int = 0
use_memory_efficient_attention: bool = False
split_head_dim: bool = False
transformer_layers_per_block: Union[int, Tuple[int]] = 1
addition_embed_type: Optional[str] = None
addition_time_embed_dim: Optional[int] = None
@@ -231,6 +235,7 @@ def setup(self):
use_linear_projection=self.use_linear_projection,
only_cross_attention=only_cross_attention[i],
use_memory_efficient_attention=self.use_memory_efficient_attention,
split_head_dim=self.split_head_dim,
dtype=self.dtype,
)
else:
@@ -254,6 +259,7 @@ def setup(self):
transformer_layers_per_block=transformer_layers_per_block[-1],
use_linear_projection=self.use_linear_projection,
use_memory_efficient_attention=self.use_memory_efficient_attention,
split_head_dim=self.split_head_dim,
dtype=self.dtype,
)

@@ -284,6 +290,7 @@ def setup(self):
use_linear_projection=self.use_linear_projection,
only_cross_attention=only_cross_attention[i],
use_memory_efficient_attention=self.use_memory_efficient_attention,
split_head_dim=self.split_head_dim,
dtype=self.dtype,
)
else:
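Because `split_head_dim` is now a regular field of `FlaxUNet2DConditionModel`, it should also be settable when constructing the module directly; a configuration-level sketch (no pretrained weights are loaded here):

```python
import jax.numpy as jnp
from diffusers import FlaxUNet2DConditionModel

# Module definition only; parameters are initialized separately in Flax.
unet = FlaxUNet2DConditionModel(
    sample_size=64,
    split_head_dim=True,  # the flag threaded through the blocks above
    dtype=jnp.bfloat16,
)
print(unet.split_head_dim)
```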
@@ -114,6 +114,8 @@ def __call__(
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_prediction, t, latents, **extra_kwargs).prev_sample

# adjust latents with inverse of vae scale
latents = latents / self.vqvae.config.scaling_factor
# decode the image latents with the VAE
image = self.vqvae.decode(latents).sample

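The two added lines follow the usual latent-scaling convention: latents are multiplied by the VAE scaling factor after encoding, so the inverse is applied just before decoding. A standalone sketch of that convention with random latents and a default `VQModel` (illustration only, not the pipeline itself):

```python
import torch
from diffusers import VQModel

vqvae = VQModel()  # default config; scaling_factor is part of the model config, as used in the hunk above
latents = torch.randn(1, 3, 32, 32)

latents = latents / vqvae.config.scaling_factor  # undo the scale applied at encode time
with torch.no_grad():
    image = vqvae.decode(latents).sample
print(image.shape)
```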
6 changes: 3 additions & 3 deletions src/diffusers/pipelines/pipeline_flax_utils.py
@@ -280,9 +280,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
<Tip>
To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
-`huggingface-cli login`. You can also activate the special
-[“offline-mode”](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
-firewalled environment.
+`huggingface-cli login`.
</Tip>
@@ -323,6 +321,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
revision = kwargs.pop("revision", None)
from_pt = kwargs.pop("from_pt", False)
use_memory_efficient_attention = kwargs.pop("use_memory_efficient_attention", False)
split_head_dim = kwargs.pop("split_head_dim", False)
dtype = kwargs.pop("dtype", None)

# 1. Download the checkpoints and configs
@@ -501,6 +500,7 @@ def load_module(name, value):
loadable_folder,
from_pt=from_pt,
use_memory_efficient_attention=use_memory_efficient_attention,
split_head_dim=split_head_dim,
dtype=dtype,
)
params[name] = loaded_params
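With the kwarg threaded through `from_pretrained`, the flag can be chosen at pipeline-load time. A hedged sketch — the repo id and `revision` are illustrative, and the checkpoint must actually ship Flax weights:

```python
import jax.numpy as jnp
from diffusers import FlaxStableDiffusionPipeline

pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # illustrative checkpoint
    revision="bf16",                   # illustrative revision with Flax/bfloat16 weights
    dtype=jnp.bfloat16,
    split_head_dim=True,               # new kwarg forwarded to the Flax sub-models
)
```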
8 changes: 4 additions & 4 deletions src/diffusers/schedulers/scheduling_ddim_inverse.py
@@ -288,9 +288,6 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'."
)

-# Roll timesteps array by one to reflect reversed origin and destination semantics for each step
-timesteps = np.roll(timesteps, 1)
-timesteps[0] = int(timesteps[1] - step_ratio)
self.timesteps = torch.from_numpy(timesteps).to(device)

def step(
@@ -335,7 +332,10 @@ def step(
"""
# 1. get previous step value (=t+1)
-prev_timestep = timestep + self.config.num_train_timesteps // self.num_inference_steps
+prev_timestep = timestep
+timestep = min(
+    timestep - self.config.num_train_timesteps // self.num_inference_steps, self.num_train_timesteps - 1
+)

# 2. compute alphas, betas
# change original implementation to exactly match noise levels for analogous forward process
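The net effect of these two hunks is that the inverse scheduler no longer rolls the timestep array and instead treats the larger timestep as the "previous" one inside `step`. A small self-contained check of the resulting schedule (default config, shown only as an illustration):

```python
from diffusers import DDIMInverseScheduler

scheduler = DDIMInverseScheduler()  # default config
scheduler.set_timesteps(num_inference_steps=50)
print(scheduler.timesteps[:5])  # ascending timesteps: inversion walks from clean latents toward noise
```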
2 changes: 2 additions & 0 deletions src/diffusers/utils/state_dict_utils.py
@@ -37,6 +37,8 @@ class StateDictType(enum.Enum):
".v_proj.lora_linear_layer.down": ".v_proj.lora_A",
".out_proj.lora_linear_layer.up": ".out_proj.lora_B",
".out_proj.lora_linear_layer.down": ".out_proj.lora_A",
".lora_linear_layer.up": ".lora_B",
".lora_linear_layer.down": ".lora_A",
}

DIFFUSERS_OLD_TO_PEFT = {
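The two new entries extend the diffusers→PEFT key mapping to LoRA layers that are not attention projections. A hypothetical helper (not the library's actual converter) showing how such a substring mapping is typically applied to a state dict:

```python
def rename_keys(state_dict, mapping):
    """Apply substring renames from `mapping` to every key of a state dict."""
    renamed = {}
    for key, value in state_dict.items():
        new_key = key
        for old, new in mapping.items():
            if old in new_key:
                new_key = new_key.replace(old, new)
        renamed[new_key] = value
    return renamed

mapping = {".lora_linear_layer.up": ".lora_B", ".lora_linear_layer.down": ".lora_A"}
print(rename_keys({"proj_out.lora_linear_layer.up.weight": 0.0}, mapping))
# -> {'proj_out.lora_B.weight': 0.0}
```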
4 changes: 2 additions & 2 deletions tests/pipelines/audio_diffusion/test_audio_diffusion.py
@@ -29,7 +29,7 @@
UNet2DConditionModel,
UNet2DModel,
)
-from diffusers.utils.testing_utils import enable_full_determinism, nightly, require_torch_gpu, slow, torch_device
+from diffusers.utils.testing_utils import enable_full_determinism, nightly, require_torch_gpu, torch_device


enable_full_determinism()
@@ -95,7 +95,7 @@ def dummy_vqvae_and_unet(self):
)
return vqvae, unet

-@slow
+@nightly
def test_audio_diffusion(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
mel = Mel(
4 changes: 2 additions & 2 deletions tests/pipelines/audioldm/test_audioldm.py
@@ -37,7 +37,7 @@
UNet2DConditionModel,
)
from diffusers.utils import is_xformers_available
-from diffusers.utils.testing_utils import enable_full_determinism, nightly, slow, torch_device
+from diffusers.utils.testing_utils import enable_full_determinism, nightly, torch_device

from ..pipeline_params import TEXT_TO_AUDIO_BATCH_PARAMS, TEXT_TO_AUDIO_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
@@ -369,7 +369,7 @@ def test_xformers_attention_forwardGenerator_pass(self):
self._test_xformers_attention_forwardGenerator_pass(test_mean_pixel_difference=False)


-@slow
+@nightly
class AudioLDMPipelineSlowTests(unittest.TestCase):
def tearDown(self):
super().tearDown()
@@ -33,8 +33,8 @@
floats_tensor,
load_image,
load_numpy,
+nightly,
require_torch_gpu,
-slow,
torch_device,
)

@@ -232,7 +232,7 @@ def test_float16_inference(self):
super().test_float16_inference(expected_max_diff=2e-1)


-@slow
+@nightly
@require_torch_gpu
class KandinskyV22ControlnetImg2ImgPipelineIntegrationTests(unittest.TestCase):
def tearDown(self):
@@ -20,7 +20,7 @@
from transformers import CLIPTextConfig, CLIPTextModel

from diffusers import DDIMScheduler, LDMPipeline, UNet2DModel, VQModel
-from diffusers.utils.testing_utils import enable_full_determinism, require_torch, slow, torch_device
+from diffusers.utils.testing_utils import enable_full_determinism, nightly, require_torch, torch_device


enable_full_determinism()
@@ -96,7 +96,7 @@ def test_inference_uncond(self):
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < tolerance


-@slow
+@nightly
@require_torch
class LDMPipelineIntegrationTests(unittest.TestCase):
def test_inference_uncond(self):
@@ -20,13 +20,13 @@
import torch

from diffusers import StableDiffusionKDiffusionPipeline
-from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, slow, torch_device
+from diffusers.utils.testing_utils import enable_full_determinism, nightly, require_torch_gpu, torch_device


enable_full_determinism()


-@slow
+@nightly
@require_torch_gpu
class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
def tearDown(self):
@@ -28,7 +28,7 @@
StableDiffusionModelEditingPipeline,
UNet2DConditionModel,
)
-from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, skip_mps, slow, torch_device
+from diffusers.utils.testing_utils import enable_full_determinism, nightly, require_torch_gpu, skip_mps, torch_device

from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
@@ -184,7 +184,7 @@ def test_attention_slicing_forward_pass(self):
super().test_attention_slicing_forward_pass(expected_max_diff=5e-3)


-@slow
+@nightly
@require_torch_gpu
class StableDiffusionModelEditingSlowTests(unittest.TestCase):
def tearDown(self):
@@ -29,8 +29,8 @@
)
from diffusers.utils.testing_utils import (
enable_full_determinism,
+nightly,
require_torch_gpu,
-slow,
torch_device,
)

@@ -188,7 +188,7 @@ def test_stable_diffusion_paradigms_negative_prompt(self):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2


-@slow
+@nightly
@require_torch_gpu
class StableDiffusionParadigmsPipelineSlowTests(unittest.TestCase):
def tearDown(self):