diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 84acaef2fbb6..41fce1706e20 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -188,6 +188,8 @@
title: UNet2DConditionModel
- local: api/models/unet3d-cond
title: UNet3DConditionModel
+ - local: api/models/unet-motion
+ title: UNetMotionModel
- local: api/models/vq
title: VQModel
- local: api/models/autoencoderkl
@@ -210,6 +212,8 @@
title: Overview
- local: api/pipelines/alt_diffusion
title: AltDiffusion
+ - local: api/pipelines/animatediff
+ title: AnimateDiff
- local: api/pipelines/attend_and_excite
title: Attend-and-Excite
- local: api/pipelines/audio_diffusion
@@ -396,5 +400,5 @@
title: Utilities
- local: api/image_processor
title: VAE Image Processor
- title: Internal classes
+ title: Internal classes
title: API
diff --git a/docs/source/en/api/models/unet-motion.md b/docs/source/en/api/models/unet-motion.md
new file mode 100644
index 000000000000..07d4df64c35f
--- /dev/null
+++ b/docs/source/en/api/models/unet-motion.md
@@ -0,0 +1,13 @@
+# UNetMotionModel
+
+The [UNet](https://huggingface.co/papers/1505.04597) model was originally introduced by Ronneberger et al for biomedical image segmentation, but it is also commonly used in 🤗 Diffusers because it outputs images that are the same size as the input. It is one of the most important components of a diffusion system because it facilitates the actual diffusion process. There are several variants of the UNet model in 🤗 Diffusers, depending on it's number of dimensions and whether it is a conditional model or not. This is a 2D UNet model.
+
+The abstract from the paper is:
+
+*There is large consent that successful training of deep networks requires many thousand annotated training samples. In this paper, we present a network and training strategy that relies on the strong use of data augmentation to use the available annotated samples more efficiently. The architecture consists of a contracting path to capture context and a symmetric expanding path that enables precise localization. We show that such a network can be trained end-to-end from very few images and outperforms the prior best method (a sliding-window convolutional network) on the ISBI challenge for segmentation of neuronal structures in electron microscopic stacks. Using the same network trained on transmitted light microscopy images (phase contrast and DIC) we won the ISBI cell tracking challenge 2015 in these categories by a large margin. Moreover, the network is fast. Segmentation of a 512x512 image takes less than a second on a recent GPU. The full implementation (based on Caffe) and the trained networks are available at http://lmb.informatik.uni-freiburg.de/people/ronneber/u-net.*
+
+## UNetMotionModel
+[[autodoc]] UNetMotionModel
+
+## UNet3DConditionOutput
+[[autodoc]] models.unet_3d_condition.UNet3DConditionOutput
diff --git a/docs/source/en/api/pipelines/animatediff.md b/docs/source/en/api/pipelines/animatediff.md
new file mode 100644
index 000000000000..ff621c60221d
--- /dev/null
+++ b/docs/source/en/api/pipelines/animatediff.md
@@ -0,0 +1,108 @@
+
+
+# Text-to-Video Generation with AnimateDiff
+
+## Overview
+
+[AnimateDiff: Animate Your Personalized Text-to-Image Diffusion Models without Specific Tuning](https://arxiv.org/abs/2307.04725) by Yuwei Guo, Ceyuan Yang*, Anyi Rao, Yaohui Wang, Yu Qiao, Dahua Lin, Bo Dai
+
+The abstract of the paper is the following:
+
+With the advance of text-to-image models (e.g., Stable Diffusion) and corresponding personalization techniques such as DreamBooth and LoRA, everyone can manifest their imagination into high-quality images at an affordable cost. Subsequently, there is a great demand for image animation techniques to further combine generated static images with motion dynamics. In this report, we propose a practical framework to animate most of the existing personalized text-to-image models once and for all, saving efforts in model-specific tuning. At the core of the proposed framework is to insert a newly initialized motion modeling module into the frozen text-to-image model and train it on video clips to distill reasonable motion priors. Once trained, by simply injecting this motion modeling module, all personalized versions derived from the same base T2I readily become text-driven models that produce diverse and personalized animated images. We conduct our evaluation on several public representative personalized text-to-image models across anime pictures and realistic photographs, and demonstrate that our proposed framework helps these models generate temporally smooth animation clips while preserving the domain and diversity of their outputs. Code and pre-trained weights will be publicly available at this https URL .
+
+## Available Pipelines:
+
+| Pipeline | Tasks | Demo
+|---|---|:---:|
+| [AnimateDiffPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/animatediff/pipeline_animatediff.py) | *Text-to-Video Generation with AnimateDiff* |
+
+## Usage example
+
+AnimateDiff works with a MotionAdapter checkpoint and a Stable Diffusion model checkpoint. The MotionAdapter is a collection of Motion Modules that are responsible for adding coherent motion across image frames. These modules are applied after the Resnet and Attention blocks in Stable Diffusion UNet.
+
+The following example demonstrates how to use a *MotionAdapter* checkpoint with Diffusers for inference based on StableDiffusion-1.4/1.5.
+
+```python
+import torch
+from diffusers import MotionAdapter, AnimateDiffPipeline, DDIMScheduler
+from diffusers.utils import export_to_gif
+
+# Load the motion adapter
+adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
+# load SD 1.5 based finetuned model
+model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
+pipe = AnimateDiffPipeline.from_pretrained(model_id, motion_adapter=adapter)
+scheduler = DDIMScheduler.from_pretrained(
+ model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1
+)
+pipe.scheduler = scheduler
+
+# enable memory savings
+pipe.enable_vae_slicing()
+pipe.enable_model_cpu_offload()
+
+output = pipe(
+ prompt=(
+ "masterpiece, bestquality, highlydetailed, ultradetailed, sunset, "
+ "orange sky, warm lighting, fishing boats, ocean waves seagulls, "
+ "rippling water, wharf, silhouette, serene atmosphere, dusk, evening glow, "
+ "golden hour, coastal landscape, seaside scenery"
+ ),
+ negative_prompt="bad quality, worse quality",
+ num_frames=16,
+ guidance_scale=7.5,
+ num_inference_steps=25,
+ generator=torch.Generator("cpu").manual_seed(42),
+)
+frames = output.frames[0]
+export_to_gif(frames, "animation.gif")
+```
+
+Here are some sample outputs:
+
+
+
+
+ masterpiece, bestquality, sunset.
+
+
+ |
+
+
+
+
+
+AnimateDiff tends to work better with finetuned Stable Diffusion models. If you plan on using a scheduler that can clip samples, make sure to disable it by setting `clip_sample=False` in the scheduler as this can also have an adverse effect on generated samples.
+
+
+
+## AnimateDiffPipeline
+[[autodoc]] AnimateDiffPipeline
+ - all
+ - __call__
+ - enable_freeu
+ - disable_freeu
+ - enable_vae_slicing
+ - disable_vae_slicing
+ - enable_vae_tiling
+ - disable_vae_tiling
+
+## AnimateDiffPipelineOutput
+
+[[autodoc]] pipelines.animatediff.AnimateDiffPipelineOutput
+
+## Available checkpoints
+
+Motion Adapter checkpoints can be found under [guoyww](https://huggingface.co/guoyww/). These checkpoints are meant to work with any model based on Stable Diffusion 1.4/1.5
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 9d146ac233c2..18266df1eadf 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -79,6 +79,7 @@
"AutoencoderTiny",
"ControlNetModel",
"ModelMixin",
+ "MotionAdapter",
"MultiAdapter",
"PriorTransformer",
"T2IAdapter",
@@ -88,6 +89,7 @@
"UNet2DConditionModel",
"UNet2DModel",
"UNet3DConditionModel",
+ "UNetMotionModel",
"VQModel",
]
)
@@ -195,6 +197,7 @@
[
"AltDiffusionImg2ImgPipeline",
"AltDiffusionPipeline",
+ "AnimateDiffPipeline",
"AudioLDM2Pipeline",
"AudioLDM2ProjectionModel",
"AudioLDM2UNet2DConditionModel",
@@ -440,6 +443,7 @@
AutoencoderTiny,
ControlNetModel,
ModelMixin,
+ MotionAdapter,
MultiAdapter,
PriorTransformer,
T2IAdapter,
@@ -449,6 +453,7 @@
UNet2DConditionModel,
UNet2DModel,
UNet3DConditionModel,
+ UNetMotionModel,
VQModel,
)
from .optimization import (
@@ -537,6 +542,7 @@
from .pipelines import (
AltDiffusionImg2ImgPipeline,
AltDiffusionPipeline,
+ AnimateDiffPipeline,
AudioLDM2Pipeline,
AudioLDM2ProjectionModel,
AudioLDM2UNet2DConditionModel,
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index a5d0066d5c40..f807353312d1 100644
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -35,6 +35,7 @@
_import_structure["unet_2d"] = ["UNet2DModel"]
_import_structure["unet_2d_condition"] = ["UNet2DConditionModel"]
_import_structure["unet_3d_condition"] = ["UNet3DConditionModel"]
+ _import_structure["unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
_import_structure["vq_model"] = ["VQModel"]
if is_flax_available():
@@ -60,6 +61,7 @@
from .unet_2d import UNet2DModel
from .unet_2d_condition import UNet2DConditionModel
from .unet_3d_condition import UNet3DConditionModel
+ from .unet_motion_model import MotionAdapter, UNetMotionModel
from .vq_model import VQModel
if is_flax_available():
diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 80e2afa94a87..cb2f24a52786 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -20,6 +20,7 @@
from ..utils.torch_utils import maybe_allow_in_graph
from .activations import GEGLU, GELU, ApproximateGELU
from .attention_processor import Attention
+from .embeddings import SinusoidalPositionalEmbedding
from .lora import LoRACompatibleLinear
from .normalization import AdaLayerNorm, AdaLayerNormZero
@@ -96,6 +97,10 @@ class BasicTransformerBlock(nn.Module):
Whether to apply a final dropout after the last feed-forward layer.
attention_type (`str`, *optional*, defaults to `"default"`):
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
+ positional_embeddings (`str`, *optional*, defaults to `None`):
+ The type of positional embeddings to apply to.
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
+ The maximum number of positional embeddings to apply.
"""
def __init__(
@@ -115,6 +120,8 @@ def __init__(
norm_type: str = "layer_norm",
final_dropout: bool = False,
attention_type: str = "default",
+ positional_embeddings: Optional[str] = None,
+ num_positional_embeddings: Optional[int] = None,
):
super().__init__()
self.only_cross_attention = only_cross_attention
@@ -128,6 +135,16 @@ def __init__(
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
)
+ if positional_embeddings and (num_positional_embeddings is None):
+ raise ValueError(
+ "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
+ )
+
+ if positional_embeddings == "sinusoidal":
+ self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
+ else:
+ self.pos_embed = None
+
# Define 3 blocks. Each block has its own normalization layer.
# 1. Self-Attn
if self.use_ada_layer_norm:
@@ -207,6 +224,9 @@ def forward(
else:
norm_hidden_states = self.norm1(hidden_states)
+ if self.pos_embed is not None:
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
+
# 1. Retrieve lora scale.
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
@@ -234,6 +254,8 @@ def forward(
norm_hidden_states = (
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
)
+ if self.pos_embed is not None:
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
attn_output = self.attn2(
norm_hidden_states,
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index d3422c8f58b2..f1128e518e2a 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -251,6 +251,33 @@ def forward(self, x):
return out
+class SinusoidalPositionalEmbedding(nn.Module):
+ """Apply positional information to a sequence of embeddings.
+
+ Takes in a sequence of embeddings with shape (batch_size, seq_length, embed_dim) and adds positional embeddings to
+ them
+
+ Args:
+ embed_dim: (int): Dimension of the positional embedding.
+ max_seq_length: Maximum sequence length to apply positional embeddings
+
+ """
+
+ def __init__(self, embed_dim: int, max_seq_length: int = 32):
+ super().__init__()
+ position = torch.arange(max_seq_length).unsqueeze(1)
+ div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim))
+ pe = torch.zeros(1, max_seq_length, embed_dim)
+ pe[0, :, 0::2] = torch.sin(position * div_term)
+ pe[0, :, 1::2] = torch.cos(position * div_term)
+ self.register_buffer("pe", pe)
+
+ def forward(self, x):
+ _, seq_length, _ = x.shape
+ x = x + self.pe[:, :seq_length]
+ return x
+
+
class ImagePositionalEmbeddings(nn.Module):
"""
Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
diff --git a/src/diffusers/models/transformer_temporal.py b/src/diffusers/models/transformer_temporal.py
index 55c9e6968a32..2e053d70eaa7 100644
--- a/src/diffusers/models/transformer_temporal.py
+++ b/src/diffusers/models/transformer_temporal.py
@@ -59,6 +59,10 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin):
Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization.
double_self_attention (`bool`, *optional*):
Configure if each `TransformerBlock` should contain two self-attention layers.
+ positional_embeddings: (`str`, *optional*):
+ The type of positional embeddings to apply to the sequence input before passing use.
+ num_positional_embeddings: (`int`, *optional*):
+ The maximum length of the sequence over which to apply positional embeddings.
"""
@register_to_config
@@ -77,6 +81,8 @@ def __init__(
activation_fn: str = "geglu",
norm_elementwise_affine: bool = True,
double_self_attention: bool = True,
+ positional_embeddings: Optional[str] = None,
+ num_positional_embeddings: Optional[int] = None,
):
super().__init__()
self.num_attention_heads = num_attention_heads
@@ -101,6 +107,8 @@ def __init__(
attention_bias=attention_bias,
double_self_attention=double_self_attention,
norm_elementwise_affine=norm_elementwise_affine,
+ positional_embeddings=positional_embeddings,
+ num_positional_embeddings=num_positional_embeddings,
)
for d in range(num_layers)
]
diff --git a/src/diffusers/models/unet_3d_blocks.py b/src/diffusers/models/unet_3d_blocks.py
index 180ae0dc1a81..e8e42cf5615f 100644
--- a/src/diffusers/models/unet_3d_blocks.py
+++ b/src/diffusers/models/unet_3d_blocks.py
@@ -12,10 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Any, Dict, Optional, Tuple
+
import torch
from torch import nn
+from ..utils import is_torch_version
from ..utils.torch_utils import apply_freeu
+from .dual_transformer_2d import DualTransformer2DModel
from .resnet import Downsample2D, ResnetBlock2D, TemporalConvLayer, Upsample2D
from .transformer_2d import Transformer2DModel
from .transformer_temporal import TransformerTemporalModel
@@ -39,6 +43,8 @@ def get_down_block(
only_cross_attention=False,
upcast_attention=False,
resnet_time_scale_shift="default",
+ temporal_num_attention_heads=8,
+ temporal_max_seq_length=32,
):
if down_block_type == "DownBlock3D":
return DownBlock3D(
@@ -74,6 +80,45 @@ def get_down_block(
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
)
+ if down_block_type == "DownBlockMotion":
+ return DownBlockMotion(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ temporal_num_attention_heads=temporal_num_attention_heads,
+ temporal_max_seq_length=temporal_max_seq_length,
+ )
+ elif down_block_type == "CrossAttnDownBlockMotion":
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockMotion")
+ return CrossAttnDownBlockMotion(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ temporal_num_attention_heads=temporal_num_attention_heads,
+ temporal_max_seq_length=temporal_max_seq_length,
+ )
+
raise ValueError(f"{down_block_type} does not exist.")
@@ -96,6 +141,9 @@ def get_up_block(
only_cross_attention=False,
upcast_attention=False,
resnet_time_scale_shift="default",
+ temporal_num_attention_heads=8,
+ temporal_cross_attention_dim=None,
+ temporal_max_seq_length=32,
):
if up_block_type == "UpBlock3D":
return UpBlock3D(
@@ -133,6 +181,46 @@ def get_up_block(
resnet_time_scale_shift=resnet_time_scale_shift,
resolution_idx=resolution_idx,
)
+ if up_block_type == "UpBlockMotion":
+ return UpBlockMotion(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ resolution_idx=resolution_idx,
+ temporal_num_attention_heads=temporal_num_attention_heads,
+ temporal_max_seq_length=temporal_max_seq_length,
+ )
+ elif up_block_type == "CrossAttnUpBlockMotion":
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockMotion")
+ return CrossAttnUpBlockMotion(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ resolution_idx=resolution_idx,
+ temporal_num_attention_heads=temporal_num_attention_heads,
+ temporal_max_seq_length=temporal_max_seq_length,
+ )
raise ValueError(f"{up_block_type} does not exist.")
@@ -724,3 +812,800 @@ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_si
hidden_states = upsampler(hidden_states, upsample_size)
return hidden_states
+
+
+class DownBlockMotion(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor=1.0,
+ add_downsample=True,
+ downsample_padding=1,
+ temporal_num_attention_heads=1,
+ temporal_cross_attention_dim=None,
+ temporal_max_seq_length=32,
+ ):
+ super().__init__()
+ resnets = []
+ motion_modules = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ motion_modules.append(
+ TransformerTemporalModel(
+ num_attention_heads=temporal_num_attention_heads,
+ in_channels=out_channels,
+ norm_num_groups=resnet_groups,
+ cross_attention_dim=temporal_cross_attention_dim,
+ attention_bias=False,
+ activation_fn="geglu",
+ positional_embeddings="sinusoidal",
+ num_positional_embeddings=temporal_max_seq_length,
+ attention_head_dim=out_channels // temporal_num_attention_heads,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+ self.motion_modules = nn.ModuleList(motion_modules)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states, temb=None, scale: float = 1.0, num_frames=1):
+ output_states = ()
+
+ blocks = zip(self.resnets, self.motion_modules)
+ for resnet, motion_module in blocks:
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ if is_torch_version(">=", "1.11.0"):
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+ )
+ else:
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb, scale
+ )
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, num_frames
+ )
+
+ else:
+ hidden_states = resnet(hidden_states, temb, scale=scale)
+ hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]
+
+ output_states = output_states + (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states, scale=scale)
+
+ output_states = output_states + (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class CrossAttnDownBlockMotion(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ transformer_layers_per_block: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ num_attention_heads=1,
+ cross_attention_dim=1280,
+ output_scale_factor=1.0,
+ downsample_padding=1,
+ add_downsample=True,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ upcast_attention=False,
+ attention_type="default",
+ temporal_cross_attention_dim=None,
+ temporal_num_attention_heads=8,
+ temporal_max_seq_length=32,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+ motion_modules = []
+
+ self.has_cross_attention = True
+ self.num_attention_heads = num_attention_heads
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ if not dual_cross_attention:
+ attentions.append(
+ Transformer2DModel(
+ num_attention_heads,
+ out_channels // num_attention_heads,
+ in_channels=out_channels,
+ num_layers=transformer_layers_per_block,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ attention_type=attention_type,
+ )
+ )
+ else:
+ attentions.append(
+ DualTransformer2DModel(
+ num_attention_heads,
+ out_channels // num_attention_heads,
+ in_channels=out_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ )
+ )
+
+ motion_modules.append(
+ TransformerTemporalModel(
+ num_attention_heads=temporal_num_attention_heads,
+ in_channels=out_channels,
+ norm_num_groups=resnet_groups,
+ cross_attention_dim=temporal_cross_attention_dim,
+ attention_bias=False,
+ activation_fn="geglu",
+ positional_embeddings="sinusoidal",
+ num_positional_embeddings=temporal_max_seq_length,
+ attention_head_dim=out_channels // temporal_num_attention_heads,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+ self.motion_modules = nn.ModuleList(motion_modules)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states,
+ temb=None,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ num_frames=1,
+ encoder_attention_mask=None,
+ cross_attention_kwargs=None,
+ additional_residuals=None,
+ ):
+ output_states = ()
+
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+
+ blocks = list(zip(self.resnets, self.attentions, self.motion_modules))
+ for i, (resnet, attn, motion_module) in enumerate(blocks):
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ **ckpt_kwargs,
+ )
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+ else:
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+ hidden_states = motion_module(
+ hidden_states,
+ num_frames=num_frames,
+ )[0]
+
+ # apply additional residuals to the output of the last pair of resnet and attention blocks
+ if i == len(blocks) - 1 and additional_residuals is not None:
+ hidden_states = hidden_states + additional_residuals
+
+ output_states = output_states + (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states, scale=lora_scale)
+
+ output_states = output_states + (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class CrossAttnUpBlockMotion(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ prev_output_channel: int,
+ temb_channels: int,
+ resolution_idx: int = None,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ transformer_layers_per_block: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ num_attention_heads=1,
+ cross_attention_dim=1280,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ upcast_attention=False,
+ attention_type="default",
+ temporal_cross_attention_dim=None,
+ temporal_num_attention_heads=8,
+ temporal_max_seq_length=32,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+ motion_modules = []
+
+ self.has_cross_attention = True
+ self.num_attention_heads = num_attention_heads
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ if not dual_cross_attention:
+ attentions.append(
+ Transformer2DModel(
+ num_attention_heads,
+ out_channels // num_attention_heads,
+ in_channels=out_channels,
+ num_layers=transformer_layers_per_block,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ attention_type=attention_type,
+ )
+ )
+ else:
+ attentions.append(
+ DualTransformer2DModel(
+ num_attention_heads,
+ out_channels // num_attention_heads,
+ in_channels=out_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ )
+ )
+ motion_modules.append(
+ TransformerTemporalModel(
+ num_attention_heads=temporal_num_attention_heads,
+ in_channels=out_channels,
+ norm_num_groups=resnet_groups,
+ cross_attention_dim=temporal_cross_attention_dim,
+ attention_bias=False,
+ activation_fn="geglu",
+ positional_embeddings="sinusoidal",
+ num_positional_embeddings=temporal_max_seq_length,
+ attention_head_dim=out_channels // temporal_num_attention_heads,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+ self.motion_modules = nn.ModuleList(motion_modules)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+ self.resolution_idx = resolution_idx
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ upsample_size: Optional[int] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ num_frames=1,
+ ):
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+ is_freeu_enabled = (
+ getattr(self, "s1", None)
+ and getattr(self, "s2", None)
+ and getattr(self, "b1", None)
+ and getattr(self, "b2", None)
+ )
+
+ blocks = zip(self.resnets, self.attentions, self.motion_modules)
+ for resnet, attn, motion_module in blocks:
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+
+ # FreeU: Only operate on the first two stages
+ if is_freeu_enabled:
+ hidden_states, res_hidden_states = apply_freeu(
+ self.resolution_idx,
+ hidden_states,
+ res_hidden_states,
+ s1=self.s1,
+ s2=self.s2,
+ b1=self.b1,
+ b2=self.b2,
+ )
+
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ **ckpt_kwargs,
+ )
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+ else:
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+ hidden_states = motion_module(
+ hidden_states,
+ num_frames=num_frames,
+ )[0]
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale)
+
+ return hidden_states
+
+
+class UpBlockMotion(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ resolution_idx: int = None,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ temporal_norm_num_groups=32,
+ temporal_cross_attention_dim=None,
+ temporal_num_attention_heads=8,
+ temporal_max_seq_length=32,
+ ):
+ super().__init__()
+ resnets = []
+ motion_modules = []
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ motion_modules.append(
+ TransformerTemporalModel(
+ num_attention_heads=temporal_num_attention_heads,
+ in_channels=out_channels,
+ norm_num_groups=temporal_norm_num_groups,
+ cross_attention_dim=temporal_cross_attention_dim,
+ attention_bias=False,
+ activation_fn="geglu",
+ positional_embeddings="sinusoidal",
+ num_positional_embeddings=temporal_max_seq_length,
+ attention_head_dim=out_channels // temporal_num_attention_heads,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+ self.motion_modules = nn.ModuleList(motion_modules)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+ self.resolution_idx = resolution_idx
+
+ def forward(
+ self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0, num_frames=1
+ ):
+ is_freeu_enabled = (
+ getattr(self, "s1", None)
+ and getattr(self, "s2", None)
+ and getattr(self, "b1", None)
+ and getattr(self, "b2", None)
+ )
+
+ blocks = zip(self.resnets, self.motion_modules)
+
+ for resnet, motion_module in blocks:
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+
+ # FreeU: Only operate on the first two stages
+ if is_freeu_enabled:
+ hidden_states, res_hidden_states = apply_freeu(
+ self.resolution_idx,
+ hidden_states,
+ res_hidden_states,
+ s1=self.s1,
+ s2=self.s2,
+ b1=self.b1,
+ b2=self.b2,
+ )
+
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ if is_torch_version(">=", "1.11.0"):
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+ )
+ else:
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb
+ )
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ )
+
+ else:
+ hidden_states = resnet(hidden_states, temb, scale=scale)
+ hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size, scale=scale)
+
+ return hidden_states
+
+
+class UNetMidBlockCrossAttnMotion(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ transformer_layers_per_block: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ num_attention_heads=1,
+ output_scale_factor=1.0,
+ cross_attention_dim=1280,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ upcast_attention=False,
+ attention_type="default",
+ temporal_num_attention_heads=1,
+ temporal_cross_attention_dim=None,
+ temporal_max_seq_length=32,
+ ):
+ super().__init__()
+
+ self.has_cross_attention = True
+ self.num_attention_heads = num_attention_heads
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+
+ # there is always at least one resnet
+ resnets = [
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ ]
+ attentions = []
+ motion_modules = []
+
+ for _ in range(num_layers):
+ if not dual_cross_attention:
+ attentions.append(
+ Transformer2DModel(
+ num_attention_heads,
+ in_channels // num_attention_heads,
+ in_channels=in_channels,
+ num_layers=transformer_layers_per_block,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ upcast_attention=upcast_attention,
+ attention_type=attention_type,
+ )
+ )
+ else:
+ attentions.append(
+ DualTransformer2DModel(
+ num_attention_heads,
+ in_channels // num_attention_heads,
+ in_channels=in_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ )
+ )
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ motion_modules.append(
+ TransformerTemporalModel(
+ num_attention_heads=temporal_num_attention_heads,
+ attention_head_dim=in_channels // temporal_num_attention_heads,
+ in_channels=in_channels,
+ norm_num_groups=resnet_groups,
+ cross_attention_dim=temporal_cross_attention_dim,
+ attention_bias=False,
+ positional_embeddings="sinusoidal",
+ num_positional_embeddings=temporal_max_seq_length,
+ activation_fn="geglu",
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+ self.motion_modules = nn.ModuleList(motion_modules)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ num_frames=1,
+ ) -> torch.FloatTensor:
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+ hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
+
+ blocks = zip(self.attentions, self.resnets[1:], self.motion_modules)
+ for attn, resnet, motion_module in blocks:
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(motion_module),
+ hidden_states,
+ temb,
+ **ckpt_kwargs,
+ )
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ **ckpt_kwargs,
+ )
+ else:
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+ hidden_states = motion_module(
+ hidden_states,
+ num_frames=num_frames,
+ )[0]
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
+
+ return hidden_states
diff --git a/src/diffusers/models/unet_motion_model.py b/src/diffusers/models/unet_motion_model.py
new file mode 100644
index 000000000000..5d528a34ec96
--- /dev/null
+++ b/src/diffusers/models/unet_motion_model.py
@@ -0,0 +1,874 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..loaders import UNet2DConditionLoadersMixin
+from ..utils import logging
+from .attention_processor import (
+ ADDED_KV_ATTENTION_PROCESSORS,
+ CROSS_ATTENTION_PROCESSORS,
+ AttentionProcessor,
+ AttnAddedKVProcessor,
+ AttnProcessor,
+)
+from .embeddings import TimestepEmbedding, Timesteps
+from .modeling_utils import ModelMixin
+from .transformer_temporal import TransformerTemporalModel
+from .unet_2d_blocks import UNetMidBlock2DCrossAttn
+from .unet_2d_condition import UNet2DConditionModel
+from .unet_3d_blocks import (
+ CrossAttnDownBlockMotion,
+ CrossAttnUpBlockMotion,
+ DownBlockMotion,
+ UNetMidBlockCrossAttnMotion,
+ UpBlockMotion,
+ get_down_block,
+ get_up_block,
+)
+from .unet_3d_condition import UNet3DConditionOutput
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class MotionModules(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ layers_per_block=2,
+ num_attention_heads=8,
+ attention_bias=False,
+ cross_attention_dim=None,
+ activation_fn="geglu",
+ norm_num_groups=32,
+ max_seq_length=32,
+ ):
+ super().__init__()
+ self.motion_modules = nn.ModuleList([])
+
+ for i in range(layers_per_block):
+ self.motion_modules.append(
+ TransformerTemporalModel(
+ in_channels=in_channels,
+ norm_num_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim,
+ activation_fn=activation_fn,
+ attention_bias=attention_bias,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=in_channels // num_attention_heads,
+ positional_embeddings="sinusoidal",
+ num_positional_embeddings=max_seq_length,
+ )
+ )
+
+
+class MotionAdapter(ModelMixin, ConfigMixin):
+ @register_to_config
+ def __init__(
+ self,
+ block_out_channels=(320, 640, 1280, 1280),
+ motion_layers_per_block=2,
+ motion_mid_block_layers_per_block=1,
+ motion_num_attention_heads=8,
+ motion_norm_num_groups=32,
+ motion_max_seq_length=32,
+ use_motion_mid_block=True,
+ ):
+ """Container to store AnimateDiff Motion Modules
+
+ Args:
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
+ The tuple of output channels for each UNet block.
+ motion_layers_per_block (`int`, *optional*, defaults to 2):
+ The number of motion layers per UNet block.
+ motion_mid_block_layers_per_block (`int`, *optional*, defaults to 1):
+ The number of motion layers in the middle UNet block.
+ motion_num_attention_heads (`int`, *optional*, defaults to 8):
+ The number of heads to use in each attention layer of the motion module.
+ motion_norm_num_groups (`int`, *optional*, defaults to 32):
+ The number of groups to use in each group normalization layer of the motion module.
+ motion_max_seq_length (`int`, *optional*, defaults to 32):
+ The maximum sequence length to use in the motion module.
+ use_motion_mid_block (`bool`, *optional*, defaults to True):
+ Whether to use a motion module in the middle of the UNet.
+ """
+
+ super().__init__()
+ down_blocks = []
+ up_blocks = []
+
+ for i, channel in enumerate(block_out_channels):
+ output_channel = block_out_channels[i]
+ down_blocks.append(
+ MotionModules(
+ in_channels=output_channel,
+ norm_num_groups=motion_norm_num_groups,
+ cross_attention_dim=None,
+ activation_fn="geglu",
+ attention_bias=False,
+ num_attention_heads=motion_num_attention_heads,
+ max_seq_length=motion_max_seq_length,
+ layers_per_block=motion_layers_per_block,
+ )
+ )
+
+ if use_motion_mid_block:
+ self.mid_block = MotionModules(
+ in_channels=block_out_channels[-1],
+ norm_num_groups=motion_norm_num_groups,
+ cross_attention_dim=None,
+ activation_fn="geglu",
+ attention_bias=False,
+ num_attention_heads=motion_num_attention_heads,
+ layers_per_block=motion_mid_block_layers_per_block,
+ max_seq_length=motion_max_seq_length,
+ )
+ else:
+ self.mid_block = None
+
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ output_channel = reversed_block_out_channels[0]
+ for i, channel in enumerate(reversed_block_out_channels):
+ output_channel = reversed_block_out_channels[i]
+ up_blocks.append(
+ MotionModules(
+ in_channels=output_channel,
+ norm_num_groups=motion_norm_num_groups,
+ cross_attention_dim=None,
+ activation_fn="geglu",
+ attention_bias=False,
+ num_attention_heads=motion_num_attention_heads,
+ max_seq_length=motion_max_seq_length,
+ layers_per_block=motion_layers_per_block + 1,
+ )
+ )
+
+ self.down_blocks = nn.ModuleList(down_blocks)
+ self.up_blocks = nn.ModuleList(up_blocks)
+
+ def forward(self, sample):
+ pass
+
+
+class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
+ r"""
+ A modified conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a
+ sample shaped output.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
+ for all models (such as downloading or saving).
+ """
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: Optional[int] = None,
+ in_channels: int = 4,
+ out_channels: int = 4,
+ down_block_types: Tuple[str] = (
+ "CrossAttnDownBlockMotion",
+ "CrossAttnDownBlockMotion",
+ "CrossAttnDownBlockMotion",
+ "DownBlockMotion",
+ ),
+ up_block_types: Tuple[str] = (
+ "UpBlockMotion",
+ "CrossAttnUpBlockMotion",
+ "CrossAttnUpBlockMotion",
+ "CrossAttnUpBlockMotion",
+ ),
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ layers_per_block: int = 2,
+ downsample_padding: int = 1,
+ mid_block_scale_factor: float = 1,
+ act_fn: str = "silu",
+ norm_num_groups: Optional[int] = 32,
+ norm_eps: float = 1e-5,
+ cross_attention_dim: int = 1280,
+ use_linear_projection: bool = False,
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = 8,
+ motion_max_seq_length: Optional[int] = 32,
+ motion_num_attention_heads: int = 8,
+ use_motion_mid_block: int = True,
+ ):
+ super().__init__()
+
+ self.sample_size = sample_size
+
+ # Check inputs
+ if len(down_block_types) != len(up_block_types):
+ raise ValueError(
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
+ )
+
+ if len(block_out_channels) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
+ )
+
+ # input
+ conv_in_kernel = 3
+ conv_out_kernel = 3
+ conv_in_padding = (conv_in_kernel - 1) // 2
+ self.conv_in = nn.Conv2d(
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
+ )
+
+ # time
+ time_embed_dim = block_out_channels[0] * 4
+ self.time_proj = Timesteps(block_out_channels[0], True, 0)
+ timestep_input_dim = block_out_channels[0]
+
+ self.time_embedding = TimestepEmbedding(
+ timestep_input_dim,
+ time_embed_dim,
+ act_fn=act_fn,
+ )
+
+ # class embedding
+ self.down_blocks = nn.ModuleList([])
+ self.up_blocks = nn.ModuleList([])
+
+ if isinstance(num_attention_heads, int):
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
+
+ # down
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads[i],
+ downsample_padding=downsample_padding,
+ use_linear_projection=use_linear_projection,
+ dual_cross_attention=False,
+ temporal_num_attention_heads=motion_num_attention_heads,
+ temporal_max_seq_length=motion_max_seq_length,
+ )
+ self.down_blocks.append(down_block)
+
+ # mid
+ if use_motion_mid_block:
+ self.mid_block = UNetMidBlockCrossAttnMotion(
+ in_channels=block_out_channels[-1],
+ temb_channels=time_embed_dim,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads[-1],
+ resnet_groups=norm_num_groups,
+ dual_cross_attention=False,
+ temporal_num_attention_heads=motion_num_attention_heads,
+ temporal_max_seq_length=motion_max_seq_length,
+ )
+
+ else:
+ self.mid_block = UNetMidBlock2DCrossAttn(
+ in_channels=block_out_channels[-1],
+ temb_channels=time_embed_dim,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads[-1],
+ resnet_groups=norm_num_groups,
+ dual_cross_attention=False,
+ )
+
+ # count how many layers upsample the images
+ self.num_upsamplers = 0
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
+
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ is_final_block = i == len(block_out_channels) - 1
+
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+ # add upsample block for all BUT final layer
+ if not is_final_block:
+ add_upsample = True
+ self.num_upsamplers += 1
+ else:
+ add_upsample = False
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=layers_per_block + 1,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ prev_output_channel=prev_output_channel,
+ temb_channels=time_embed_dim,
+ add_upsample=add_upsample,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=reversed_num_attention_heads[i],
+ dual_cross_attention=False,
+ resolution_idx=i,
+ use_linear_projection=use_linear_projection,
+ temporal_num_attention_heads=motion_num_attention_heads,
+ temporal_max_seq_length=motion_max_seq_length,
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ # out
+ if norm_num_groups is not None:
+ self.conv_norm_out = nn.GroupNorm(
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
+ )
+ self.conv_act = nn.SiLU()
+ else:
+ self.conv_norm_out = None
+ self.conv_act = None
+
+ conv_out_padding = (conv_out_kernel - 1) // 2
+ self.conv_out = nn.Conv2d(
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
+ )
+
+ @classmethod
+ def from_unet2d(
+ cls,
+ unet: UNet2DConditionModel,
+ motion_adapter: Optional[MotionAdapter] = None,
+ load_weights: bool = True,
+ ):
+ has_motion_adapter = motion_adapter is not None
+
+ # based on https://github.com/guoyww/AnimateDiff/blob/895f3220c06318ea0760131ec70408b466c49333/animatediff/models/unet.py#L459
+ config = unet.config
+ config["_class_name"] = cls.__name__
+
+ down_blocks = []
+ for down_blocks_type in config["down_block_types"]:
+ if "CrossAttn" in down_blocks_type:
+ down_blocks.append("CrossAttnDownBlockMotion")
+ else:
+ down_blocks.append("DownBlockMotion")
+ config["down_block_types"] = down_blocks
+
+ up_blocks = []
+ for down_blocks_type in config["up_block_types"]:
+ if "CrossAttn" in down_blocks_type:
+ up_blocks.append("CrossAttnUpBlockMotion")
+ else:
+ up_blocks.append("UpBlockMotion")
+
+ config["up_block_types"] = up_blocks
+
+ if has_motion_adapter:
+ config["motion_num_attention_heads"] = motion_adapter.config["motion_num_attention_heads"]
+ config["motion_max_seq_length"] = motion_adapter.config["motion_max_seq_length"]
+ config["use_motion_mid_block"] = motion_adapter.config["use_motion_mid_block"]
+
+ # Need this for backwards compatibility with UNet2DConditionModel checkpoints
+ if not config.get("num_attention_heads"):
+ config["num_attention_heads"] = config["attention_head_dim"]
+
+ model = cls.from_config(config)
+
+ if not load_weights:
+ return model
+
+ model.conv_in.load_state_dict(unet.conv_in.state_dict())
+ model.time_proj.load_state_dict(unet.time_proj.state_dict())
+ model.time_embedding.load_state_dict(unet.time_embedding.state_dict())
+
+ for i, down_block in enumerate(unet.down_blocks):
+ model.down_blocks[i].resnets.load_state_dict(down_block.resnets.state_dict())
+ if hasattr(model.down_blocks[i], "attentions"):
+ model.down_blocks[i].attentions.load_state_dict(down_block.attentions.state_dict())
+ if model.down_blocks[i].downsamplers:
+ model.down_blocks[i].downsamplers.load_state_dict(down_block.downsamplers.state_dict())
+
+ for i, up_block in enumerate(unet.up_blocks):
+ model.up_blocks[i].resnets.load_state_dict(up_block.resnets.state_dict())
+ if hasattr(model.up_blocks[i], "attentions"):
+ model.up_blocks[i].attentions.load_state_dict(up_block.attentions.state_dict())
+ if model.up_blocks[i].upsamplers:
+ model.up_blocks[i].upsamplers.load_state_dict(up_block.upsamplers.state_dict())
+
+ model.mid_block.resnets.load_state_dict(unet.mid_block.resnets.state_dict())
+ model.mid_block.attentions.load_state_dict(unet.mid_block.attentions.state_dict())
+
+ if unet.conv_norm_out is not None:
+ model.conv_norm_out.load_state_dict(unet.conv_norm_out.state_dict())
+ if unet.conv_act is not None:
+ model.conv_act.load_state_dict(unet.conv_act.state_dict())
+ model.conv_out.load_state_dict(unet.conv_out.state_dict())
+
+ if has_motion_adapter:
+ model.load_motion_modules(motion_adapter)
+
+ # ensure that the Motion UNet is the same dtype as the UNet2DConditionModel
+ model.to(unet.dtype)
+
+ return model
+
+ def freeze_unet2d_params(self):
+ """Freeze the weights of just the UNet2DConditionModel, and leave the motion modules
+ unfrozen for fine tuning.
+ """
+ # Freeze everything
+ for param in self.parameters():
+ param.requires_grad = False
+
+ # Unfreeze Motion Modules
+ for down_block in self.down_blocks:
+ motion_modules = down_block.motion_modules
+ for param in motion_modules.parameters():
+ param.requires_grad = True
+
+ for up_block in self.up_blocks:
+ motion_modules = up_block.motion_modules
+ for param in motion_modules.parameters():
+ param.requires_grad = True
+
+ if hasattr(self.mid_block, "motion_modules"):
+ motion_modules = self.mid_block.motion_modules
+ for param in motion_modules.parameters():
+ param.requires_grad = True
+
+ return
+
+ def load_motion_modules(self, motion_adapter: Optional[MotionAdapter]):
+ for i, down_block in enumerate(motion_adapter.down_blocks):
+ self.down_blocks[i].motion_modules.load_state_dict(down_block.motion_modules.state_dict())
+ for i, up_block in enumerate(motion_adapter.up_blocks):
+ self.up_blocks[i].motion_modules.load_state_dict(up_block.motion_modules.state_dict())
+
+ # to support older motion modules that don't have a mid_block
+ if hasattr(self.mid_block, "motion_modules"):
+ self.mid_block.motion_modules.load_state_dict(motion_adapter.mid_block.motion_modules.state_dict())
+
+ def save_motion_modules(
+ self,
+ save_directory: str,
+ is_main_process: bool = True,
+ safe_serialization: bool = True,
+ variant: Optional[str] = None,
+ push_to_hub: bool = False,
+ **kwargs,
+ ):
+ state_dict = self.state_dict()
+
+ # Extract all motion modules
+ motion_state_dict = {}
+ for k, v in state_dict.items():
+ if "motion_modules" in k:
+ motion_state_dict[k] = v
+
+ adapter = MotionAdapter(
+ block_out_channels=self.config["block_out_channels"],
+ motion_layers_per_block=self.config["layers_per_block"],
+ motion_norm_num_groups=self.config["norm_num_groups"],
+ motion_num_attention_heads=self.config["motion_num_attention_heads"],
+ motion_max_seq_length=self.config["motion_max_seq_length"],
+ use_motion_mid_block=self.config["use_motion_mid_block"],
+ )
+ adapter.load_state_dict(motion_state_dict)
+ adapter.save_pretrained(
+ save_directory=save_directory,
+ is_main_process=is_main_process,
+ safe_serialization=safe_serialization,
+ variant=variant,
+ push_to_hub=push_to_hub,
+ **kwargs,
+ )
+
+ @property
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
+ indexed by its weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
+ ):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor, _remove_lora=_remove_lora)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
+ def enable_forward_chunking(self, chunk_size=None, dim=0):
+ """
+ Sets the attention processor to use [feed forward
+ chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
+
+ Parameters:
+ chunk_size (`int`, *optional*):
+ The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
+ over each tensor of dim=`dim`.
+ dim (`int`, *optional*, defaults to `0`):
+ The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
+ or dim=1 (sequence length).
+ """
+ if dim not in [0, 1]:
+ raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
+
+ # By default chunk size is 1
+ chunk_size = chunk_size or 1
+
+ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
+ if hasattr(module, "set_chunk_feed_forward"):
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
+
+ for child in module.children():
+ fn_recursive_feed_forward(child, chunk_size, dim)
+
+ for module in self.children():
+ fn_recursive_feed_forward(module, chunk_size, dim)
+
+ # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
+ def disable_forward_chunking(self):
+ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
+ if hasattr(module, "set_chunk_feed_forward"):
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
+
+ for child in module.children():
+ fn_recursive_feed_forward(child, chunk_size, dim)
+
+ for module in self.children():
+ fn_recursive_feed_forward(module, None, 0)
+
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
+ def set_default_attn_processor(self):
+ """
+ Disables custom attention processors and sets the default attention implementation.
+ """
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnAddedKVProcessor()
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnProcessor()
+ else:
+ raise ValueError(
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
+ )
+
+ self.set_attn_processor(processor, _remove_lora=True)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion, CrossAttnUpBlockMotion, UpBlockMotion)):
+ module.gradient_checkpointing = value
+
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.enable_freeu
+ def enable_freeu(self, s1, s2, b1, b2):
+ r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
+
+ The suffixes after the scaling factors represent the stage blocks where they are being applied.
+
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
+ are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
+
+ Args:
+ s1 (`float`):
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
+ s2 (`float`):
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
+ """
+ for i, upsample_block in enumerate(self.up_blocks):
+ setattr(upsample_block, "s1", s1)
+ setattr(upsample_block, "s2", s2)
+ setattr(upsample_block, "b1", b1)
+ setattr(upsample_block, "b2", b2)
+
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.disable_freeu
+ def disable_freeu(self):
+ """Disables the FreeU mechanism."""
+ freeu_keys = {"s1", "s2", "b1", "b2"}
+ for i, upsample_block in enumerate(self.up_blocks):
+ for k in freeu_keys:
+ if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
+ setattr(upsample_block, k, None)
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ timestep_cond: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ ) -> Union[UNet3DConditionOutput, Tuple]:
+ r"""
+ The [`UNetMotionModel`] forward method.
+
+ Args:
+ sample (`torch.FloatTensor`):
+ The noisy input tensor with the following shape `(batch, num_frames, channel, height, width`.
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
+ encoder_hidden_states (`torch.FloatTensor`):
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
+ timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
+ Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
+ through the `self.time_embedding` layer to obtain the timestep embeddings.
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
+ negative values to the attention scores corresponding to "discard" tokens.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
+ A tuple of tensors that if specified are added to the residuals of down unet blocks.
+ mid_block_additional_residual: (`torch.Tensor`, *optional*):
+ A tensor that if specified is added to the residual of the middle unet block.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.unet_3d_condition.UNet3DConditionOutput`] instead of a plain
+ tuple.
+
+ Returns:
+ [`~models.unet_3d_condition.UNet3DConditionOutput`] or `tuple`:
+ If `return_dict` is True, an [`~models.unet_3d_condition.UNet3DConditionOutput`] is returned, otherwise
+ a `tuple` is returned where the first element is the sample tensor.
+ """
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
+ # on the fly if necessary.
+ default_overall_up_factor = 2**self.num_upsamplers
+
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
+ forward_upsample_size = False
+ upsample_size = None
+
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
+ logger.info("Forward upsample size to force interpolation output size.")
+ forward_upsample_size = True
+
+ # prepare attention_mask
+ if attention_mask is not None:
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ num_frames = sample.shape[2]
+ timesteps = timesteps.expand(sample.shape[0])
+
+ t_emb = self.time_proj(timesteps)
+
+ # timesteps does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=self.dtype)
+
+ emb = self.time_embedding(t_emb, timestep_cond)
+ emb = emb.repeat_interleave(repeats=num_frames, dim=0)
+ encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
+
+ # 2. pre-process
+ sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:])
+ sample = self.conv_in(sample)
+
+ # 3. down
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ num_frames=num_frames,
+ cross_attention_kwargs=cross_attention_kwargs,
+ )
+ else:
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames)
+
+ down_block_res_samples += res_samples
+
+ if down_block_additional_residuals is not None:
+ new_down_block_res_samples = ()
+
+ for down_block_res_sample, down_block_additional_residual in zip(
+ down_block_res_samples, down_block_additional_residuals
+ ):
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
+ new_down_block_res_samples += (down_block_res_sample,)
+
+ down_block_res_samples = new_down_block_res_samples
+
+ # 4. mid
+ if self.mid_block is not None:
+ # To support older versions of motion modules that don't have a mid_block
+ if hasattr(self.mid_block, "motion_modules"):
+ sample = self.mid_block(
+ sample,
+ emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ num_frames=num_frames,
+ cross_attention_kwargs=cross_attention_kwargs,
+ )
+ else:
+ sample = self.mid_block(
+ sample,
+ emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ )
+
+ if mid_block_additional_residual is not None:
+ sample = sample + mid_block_additional_residual
+
+ # 5. up
+ for i, upsample_block in enumerate(self.up_blocks):
+ is_final_block = i == len(self.up_blocks) - 1
+
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+ # if we have not reached the final block and need to forward the
+ # upsample size, we do it here
+ if not is_final_block and forward_upsample_size:
+ upsample_size = down_block_res_samples[-1].shape[2:]
+
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ encoder_hidden_states=encoder_hidden_states,
+ upsample_size=upsample_size,
+ attention_mask=attention_mask,
+ num_frames=num_frames,
+ cross_attention_kwargs=cross_attention_kwargs,
+ )
+ else:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ upsample_size=upsample_size,
+ num_frames=num_frames,
+ )
+
+ # 6. post-process
+ if self.conv_norm_out:
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+
+ sample = self.conv_out(sample)
+
+ # reshape to (batch, channel, framerate, width, height)
+ sample = sample[None, :].reshape((-1, num_frames) + sample.shape[1:]).permute(0, 2, 1, 3, 4)
+
+ if not return_dict:
+ return (sample,)
+
+ return UNet3DConditionOutput(sample=sample)
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index df7a89fc1b81..9c69706560ca 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -62,6 +62,7 @@
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["alt_diffusion"] = ["AltDiffusionImg2ImgPipeline", "AltDiffusionPipeline"]
+ _import_structure["animatediff"] = ["AnimateDiffPipeline"]
_import_structure["audioldm"] = ["AudioLDMPipeline"]
_import_structure["audioldm2"] = [
"AudioLDM2Pipeline",
@@ -291,6 +292,7 @@
from ..utils.dummy_torch_and_transformers_objects import *
else:
from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline
+ from .animatediff import AnimateDiffPipeline
from .audioldm import AudioLDMPipeline
from .audioldm2 import AudioLDM2Pipeline, AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel
from .blip_diffusion import BlipDiffusionPipeline
diff --git a/src/diffusers/pipelines/animatediff/__init__.py b/src/diffusers/pipelines/animatediff/__init__.py
new file mode 100644
index 000000000000..503352fec865
--- /dev/null
+++ b/src/diffusers/pipelines/animatediff/__init__.py
@@ -0,0 +1,46 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_animatediff"] = ["AnimateDiffPipeline", "AnimateDiffPipelineOutput"]
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import *
+
+ else:
+ from .pipeline_animatediff import AnimateDiffPipeline, AnimateDiffPipelineOutput
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py
new file mode 100644
index 000000000000..650b447cd23a
--- /dev/null
+++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py
@@ -0,0 +1,694 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import CLIPTextModel, CLIPTokenizer
+
+from ...image_processor import VaeImageProcessor
+from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, UNet2DConditionModel, UNetMotionModel
+from ...models.lora import adjust_lora_scale_text_encoder
+from ...models.unet_motion_model import MotionAdapter
+from ...schedulers import (
+ DDIMScheduler,
+ DPMSolverMultistepScheduler,
+ EulerAncestralDiscreteScheduler,
+ EulerDiscreteScheduler,
+ LMSDiscreteScheduler,
+ PNDMScheduler,
+)
+from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import MotionAdapter, AnimateDiffPipeline, DDIMScheduler
+ >>> from diffusers.utils import export_to_gif
+
+ >>> adapter = MotionAdapter.from_pretrained("diffusers/motion-adapter")
+ >>> pipe = AnimateDiffPipeline.from_pretrained("frankjoshua/toonyou_beta6", motion_adapter=adapter)
+ >>> pipe.scheduler = DDIMScheduler(beta_schedule="linear", steps_offset=1, clip_sample=False)
+ >>> output = pipe(prompt="A corgi walking in the park")
+ >>> frames = output.frames[0]
+ >>> export_to_gif(frames, "animation.gif")
+ ```
+"""
+
+
+def tensor2vid(video: torch.Tensor, processor, output_type="np"):
+ # Based on:
+ # https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
+
+ batch_size, channels, num_frames, height, width = video.shape
+ outputs = []
+ for batch_idx in range(batch_size):
+ batch_vid = video[batch_idx].permute(1, 0, 2, 3)
+ batch_output = processor.postprocess(batch_vid, output_type)
+
+ outputs.append(batch_output)
+
+ return outputs
+
+
+@dataclass
+class AnimateDiffPipelineOutput(BaseOutput):
+ frames: Union[torch.Tensor, np.ndarray]
+
+
+class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
+ r"""
+ Pipeline for text-to-video generation.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
+ tokenizer (`CLIPTokenizer`):
+ A [`~transformers.CLIPTokenizer`] to tokenize text.
+ unet ([`UNet2DConditionModel`]):
+ A [`UNet2DConditionModel`] used to create a UNetMotionModel to denoise the encoded video latents.
+ motion_adapter ([`MotionAdapter`]):
+ A [`MotionAdapter`] to be used in combination with `unet` to denoise the encoded video latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ """
+ model_cpu_offload_seq = "text_encoder->unet->vae"
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ motion_adapter: MotionAdapter,
+ scheduler: Union[
+ DDIMScheduler,
+ PNDMScheduler,
+ LMSDiscreteScheduler,
+ EulerDiscreteScheduler,
+ EulerAncestralDiscreteScheduler,
+ DPMSolverMultistepScheduler,
+ ],
+ ):
+ super().__init__()
+ unet = UNetMotionModel.from_unet2d(unet, motion_adapter)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ motion_adapter=motion_adapter,
+ scheduler=scheduler,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
+ def encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ clip_skip: Optional[int] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: procecss multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ if clip_skip is None:
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
+ prompt_embeds = prompt_embeds[0]
+ else:
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
+ )
+ # Access the `hidden_states` first, that contains a tuple of
+ # all the hidden states from the encoder layers. Then index into
+ # the tuple to access the hidden states from the desired layer.
+ prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
+ # We also need to apply the final LayerNorm here to not mess with the
+ # representations. The `last_hidden_states` that we typically use for
+ # obtaining the final prompt representations passes through the LayerNorm
+ # layer.
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
+
+ if self.text_encoder is not None:
+ prompt_embeds_dtype = self.text_encoder.dtype
+ elif self.unet is not None:
+ prompt_embeds_dtype = self.unet.dtype
+ else:
+ prompt_embeds_dtype = prompt_embeds.dtype
+
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # textual inversion: procecss multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ return prompt_embeds, negative_prompt_embeds
+
+ # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
+ def decode_latents(self, latents):
+ latents = 1 / self.vae.config.scaling_factor * latents
+
+ batch_size, channels, num_frames, height, width = latents.shape
+ latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
+
+ image = self.vae.decode(latents).sample
+ video = (
+ image[None, :]
+ .reshape(
+ (
+ batch_size,
+ num_frames,
+ -1,
+ )
+ + image.shape[2:]
+ )
+ .permute(0, 2, 1, 3, 4)
+ )
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ video = video.float()
+ return video
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ self.vae.enable_tiling()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu
+ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
+ r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
+
+ The suffixes after the scaling factors represent the stages where they are being applied.
+
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
+ that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
+
+ Args:
+ s1 (`float`):
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
+ mitigate "oversmoothing effect" in the enhanced denoising process.
+ s2 (`float`):
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
+ mitigate "oversmoothing effect" in the enhanced denoising process.
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
+ """
+ if not hasattr(self, "unet"):
+ raise ValueError("The pipeline must have `unet` for using FreeU.")
+ self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu
+ def disable_freeu(self):
+ """Disables the FreeU mechanism if enabled."""
+ self.unet.disable_freeu()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents
+ def prepare_latents(
+ self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
+ ):
+ shape = (
+ batch_size,
+ num_channels_latents,
+ num_frames,
+ height // self.vae_scale_factor,
+ width // self.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ num_frames: Optional[int] = 16,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_videos_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: Optional[int] = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ clip_skip: Optional[int] = None,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The height in pixels of the generated video.
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The width in pixels of the generated video.
+ num_frames (`int`, *optional*, defaults to 16):
+ The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds
+ amounts to 2 seconds of video.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality videos at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ A higher guidance scale value encourages the model to generate images closely linked to the text
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`. Latents should be of shape
+ `(batch_size, num_channel, num_frames, height, width)`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated video. Choose between `torch.FloatTensor`, `PIL.Image` or
+ `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead
+ of a plain tuple.
+ callback (`Callable`, *optional*):
+ A function that calls every `callback_steps` steps during inference. The function is called with the
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
+ every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ Examples:
+
+ Returns:
+ [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] is
+ returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
+ """
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ num_videos_per_prompt = 1
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt,
+ device,
+ num_videos_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ clip_skip=clip_skip,
+ )
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ num_frames,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ ).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ if output_type == "latent":
+ return AnimateDiffPipelineOutput(frames=latents)
+
+ # Post-processing
+ video_tensor = self.decode_latents(latents)
+
+ if output_type == "pt":
+ video = video_tensor
+ else:
+ video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return AnimateDiffPipelineOutput(frames=video)
diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py
index 890f836c73c6..d6d74a89cafb 100644
--- a/src/diffusers/utils/dummy_pt_objects.py
+++ b/src/diffusers/utils/dummy_pt_objects.py
@@ -77,6 +77,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class MotionAdapter(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class MultiAdapter(metaclass=DummyObject):
_backends = ["torch"]
@@ -212,6 +227,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class UNetMotionModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class VQModel(metaclass=DummyObject):
_backends = ["torch"]
diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
index 3b5e3ad4e07d..2fd80f321e6b 100644
--- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py
+++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
@@ -32,6 +32,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
+class AnimateDiffPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class AudioLDM2Pipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
diff --git a/tests/models/test_models_unet_motion.py b/tests/models/test_models_unet_motion.py
new file mode 100644
index 000000000000..60c3399db537
--- /dev/null
+++ b/tests/models/test_models_unet_motion.py
@@ -0,0 +1,306 @@
+# coding=utf-8
+# Copyright 2023 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import os
+import tempfile
+import unittest
+
+import numpy as np
+import torch
+
+from diffusers import MotionAdapter, UNet2DConditionModel, UNetMotionModel
+from diffusers.utils import logging
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.testing_utils import (
+ enable_full_determinism,
+ floats_tensor,
+ torch_device,
+)
+
+from .test_modeling_common import ModelTesterMixin, UNetTesterMixin
+
+
+logger = logging.get_logger(__name__)
+
+enable_full_determinism()
+
+
+class UNetMotionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+ model_class = UNetMotionModel
+ main_input_name = "sample"
+
+ @property
+ def dummy_input(self):
+ batch_size = 4
+ num_channels = 4
+ num_frames = 8
+ sizes = (32, 32)
+
+ noise = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
+ time_step = torch.tensor([10]).to(torch_device)
+ encoder_hidden_states = floats_tensor((batch_size, 4, 32)).to(torch_device)
+
+ return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}
+
+ @property
+ def input_shape(self):
+ return (4, 8, 32, 32)
+
+ @property
+ def output_shape(self):
+ return (4, 8, 32, 32)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = {
+ "block_out_channels": (32, 64),
+ "down_block_types": ("CrossAttnDownBlockMotion", "DownBlockMotion"),
+ "up_block_types": ("UpBlockMotion", "CrossAttnUpBlockMotion"),
+ "cross_attention_dim": 32,
+ "num_attention_heads": 4,
+ "out_channels": 4,
+ "in_channels": 4,
+ "layers_per_block": 1,
+ "sample_size": 32,
+ }
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def test_from_unet2d(self):
+ torch.manual_seed(0)
+ unet2d = UNet2DConditionModel()
+
+ torch.manual_seed(1)
+ model = self.model_class.from_unet2d(unet2d)
+ model_state_dict = model.state_dict()
+
+ for param_name, param_value in unet2d.named_parameters():
+ self.assertTrue(torch.equal(model_state_dict[param_name], param_value))
+
+ def test_freeze_unet2d(self):
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**init_dict)
+ model.freeze_unet2d_params()
+
+ for param_name, param_value in model.named_parameters():
+ if "motion_modules" not in param_name:
+ self.assertFalse(param_value.requires_grad)
+
+ else:
+ self.assertTrue(param_value.requires_grad)
+
+ def test_loading_motion_adapter(self):
+ model = self.model_class()
+ adapter = MotionAdapter()
+ model.load_motion_modules(adapter)
+
+ for idx, down_block in enumerate(model.down_blocks):
+ adapter_state_dict = adapter.down_blocks[idx].motion_modules.state_dict()
+ for param_name, param_value in down_block.motion_modules.named_parameters():
+ self.assertTrue(torch.equal(adapter_state_dict[param_name], param_value))
+
+ for idx, up_block in enumerate(model.up_blocks):
+ adapter_state_dict = adapter.up_blocks[idx].motion_modules.state_dict()
+ for param_name, param_value in up_block.motion_modules.named_parameters():
+ self.assertTrue(torch.equal(adapter_state_dict[param_name], param_value))
+
+ mid_block_adapter_state_dict = adapter.mid_block.motion_modules.state_dict()
+ for param_name, param_value in model.mid_block.motion_modules.named_parameters():
+ self.assertTrue(torch.equal(mid_block_adapter_state_dict[param_name], param_value))
+
+ def test_saving_motion_modules(self):
+ torch.manual_seed(0)
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**init_dict)
+ model.to(torch_device)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_motion_modules(tmpdirname)
+ self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "diffusion_pytorch_model.safetensors")))
+
+ adapter_loaded = MotionAdapter.from_pretrained(tmpdirname)
+
+ torch.manual_seed(0)
+ model_loaded = self.model_class(**init_dict)
+ model_loaded.load_motion_modules(adapter_loaded)
+ model_loaded.to(torch_device)
+
+ with torch.no_grad():
+ output = model(**inputs_dict)[0]
+ output_loaded = model_loaded(**inputs_dict)[0]
+
+ max_diff = (output - output_loaded).abs().max().item()
+ self.assertLessEqual(max_diff, 1e-4, "Models give different forward passes")
+
+ @unittest.skipIf(
+ torch_device != "cuda" or not is_xformers_available(),
+ reason="XFormers attention is only available with CUDA and `xformers` installed",
+ )
+ def test_xformers_enable_works(self):
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**init_dict)
+
+ model.enable_xformers_memory_efficient_attention()
+
+ assert (
+ model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
+ == "XFormersAttnProcessor"
+ ), "xformers is not enabled"
+
+ def test_gradient_checkpointing_is_applied(self):
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ model_class_copy = copy.copy(self.model_class)
+
+ modules_with_gc_enabled = {}
+
+ # now monkey patch the following function:
+ # def _set_gradient_checkpointing(self, module, value=False):
+ # if hasattr(module, "gradient_checkpointing"):
+ # module.gradient_checkpointing = value
+
+ def _set_gradient_checkpointing_new(self, module, value=False):
+ if hasattr(module, "gradient_checkpointing"):
+ module.gradient_checkpointing = value
+ modules_with_gc_enabled[module.__class__.__name__] = True
+
+ model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new
+
+ model = model_class_copy(**init_dict)
+ model.enable_gradient_checkpointing()
+
+ EXPECTED_SET = {
+ "CrossAttnUpBlockMotion",
+ "CrossAttnDownBlockMotion",
+ "UNetMidBlockCrossAttnMotion",
+ "UpBlockMotion",
+ "Transformer2DModel",
+ "DownBlockMotion",
+ }
+
+ assert set(modules_with_gc_enabled.keys()) == EXPECTED_SET
+ assert all(modules_with_gc_enabled.values()), "All modules should be enabled"
+
+ def test_feed_forward_chunking(self):
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ init_dict["norm_num_groups"] = 32
+
+ model = self.model_class(**init_dict)
+ model.to(torch_device)
+ model.eval()
+
+ with torch.no_grad():
+ output = model(**inputs_dict)[0]
+
+ model.enable_forward_chunking()
+ with torch.no_grad():
+ output_2 = model(**inputs_dict)[0]
+
+ self.assertEqual(output.shape, output_2.shape, "Shape doesn't match")
+ assert np.abs(output.cpu() - output_2.cpu()).max() < 1e-2
+
+ def test_pickle(self):
+ # enable deterministic behavior for gradient checkpointing
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**init_dict)
+ model.to(torch_device)
+
+ with torch.no_grad():
+ sample = model(**inputs_dict).sample
+
+ sample_copy = copy.copy(sample)
+
+ assert (sample - sample_copy).abs().max() < 1e-4
+
+ def test_from_save_pretrained(self, expected_max_diff=5e-5):
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+ torch.manual_seed(0)
+ model = self.model_class(**init_dict)
+ model.to(torch_device)
+ model.eval()
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname, safe_serialization=False)
+ torch.manual_seed(0)
+ new_model = self.model_class.from_pretrained(tmpdirname)
+ new_model.to(torch_device)
+
+ with torch.no_grad():
+ image = model(**inputs_dict)
+ if isinstance(image, dict):
+ image = image.to_tuple()[0]
+
+ new_image = new_model(**inputs_dict)
+
+ if isinstance(new_image, dict):
+ new_image = new_image.to_tuple()[0]
+
+ max_diff = (image - new_image).abs().max().item()
+ self.assertLessEqual(max_diff, expected_max_diff, "Models give different forward passes")
+
+ def test_from_save_pretrained_variant(self, expected_max_diff=5e-5):
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+ torch.manual_seed(0)
+ model = self.model_class(**init_dict)
+ model.to(torch_device)
+ model.eval()
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname, variant="fp16", safe_serialization=False)
+
+ torch.manual_seed(0)
+ new_model = self.model_class.from_pretrained(tmpdirname, variant="fp16")
+ # non-variant cannot be loaded
+ with self.assertRaises(OSError) as error_context:
+ self.model_class.from_pretrained(tmpdirname)
+
+ # make sure that error message states what keys are missing
+ assert "Error no file named diffusion_pytorch_model.bin found in directory" in str(error_context.exception)
+
+ new_model.to(torch_device)
+
+ with torch.no_grad():
+ image = model(**inputs_dict)
+ if isinstance(image, dict):
+ image = image.to_tuple()[0]
+
+ new_image = new_model(**inputs_dict)
+
+ if isinstance(new_image, dict):
+ new_image = new_image.to_tuple()[0]
+
+ max_diff = (image - new_image).abs().max().item()
+ self.assertLessEqual(max_diff, expected_max_diff, "Models give different forward passes")
+
+ def test_forward_with_norm_groups(self):
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+ init_dict["norm_num_groups"] = 16
+ init_dict["block_out_channels"] = (16, 32)
+
+ model = self.model_class(**init_dict)
+ model.to(torch_device)
+ model.eval()
+
+ with torch.no_grad():
+ output = model(**inputs_dict)
+
+ if isinstance(output, dict):
+ output = output.to_tuple()[0]
+
+ self.assertIsNotNone(output)
+ expected_shape = inputs_dict["sample"].shape
+ self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
diff --git a/tests/pipelines/animatediff/__init__.py b/tests/pipelines/animatediff/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/pipelines/animatediff/test_animatediff.py b/tests/pipelines/animatediff/test_animatediff.py
new file mode 100644
index 000000000000..baba8ba4d655
--- /dev/null
+++ b/tests/pipelines/animatediff/test_animatediff.py
@@ -0,0 +1,279 @@
+import gc
+import unittest
+
+import numpy as np
+import torch
+from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
+
+import diffusers
+from diffusers import (
+ AnimateDiffPipeline,
+ AutoencoderKL,
+ DDIMScheduler,
+ MotionAdapter,
+ UNet2DConditionModel,
+ UNetMotionModel,
+)
+from diffusers.utils import logging
+from diffusers.utils.testing_utils import numpy_cosine_similarity_distance, require_torch_gpu, slow, torch_device
+
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin
+
+
+def to_np(tensor):
+ if isinstance(tensor, torch.Tensor):
+ tensor = tensor.detach().cpu().numpy()
+
+ return tensor
+
+
+class AnimateDiffPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = AnimateDiffPipeline
+ params = TEXT_TO_IMAGE_PARAMS
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback",
+ "callback_steps",
+ ]
+ )
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ unet = UNet2DConditionModel(
+ block_out_channels=(32, 64),
+ layers_per_block=2,
+ sample_size=32,
+ in_channels=4,
+ out_channels=4,
+ down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
+ up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
+ cross_attention_dim=32,
+ norm_num_groups=2,
+ )
+ scheduler = DDIMScheduler(
+ beta_start=0.00085,
+ beta_end=0.012,
+ beta_schedule="linear",
+ clip_sample=False,
+ )
+ torch.manual_seed(0)
+ vae = AutoencoderKL(
+ block_out_channels=[32, 64],
+ in_channels=3,
+ out_channels=3,
+ down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
+ up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+ latent_channels=4,
+ )
+ torch.manual_seed(0)
+ text_encoder_config = CLIPTextConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=32,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=5,
+ pad_token_id=1,
+ vocab_size=1000,
+ )
+ text_encoder = CLIPTextModel(text_encoder_config)
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+ motion_adapter = MotionAdapter(
+ block_out_channels=(32, 64),
+ motion_layers_per_block=2,
+ motion_norm_num_groups=2,
+ motion_num_attention_heads=4,
+ )
+
+ components = {
+ "unet": unet,
+ "scheduler": scheduler,
+ "vae": vae,
+ "motion_adapter": motion_adapter,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
+ inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 7.5,
+ "output_type": "pt",
+ }
+ return inputs
+
+ def test_motion_unet_loading(self):
+ components = self.get_dummy_components()
+ pipe = AnimateDiffPipeline(**components)
+
+ assert isinstance(pipe.unet, UNetMotionModel)
+
+ @unittest.skip("Attention slicing is not enabled in this pipeline")
+ def test_attention_slicing_forward_pass(self):
+ pass
+
+ def test_inference_batch_single_identical(
+ self,
+ batch_size=2,
+ expected_max_diff=1e-4,
+ additional_params_copy_to_batched_inputs=["num_inference_steps"],
+ ):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for components in pipe.components.values():
+ if hasattr(components, "set_default_attn_processor"):
+ components.set_default_attn_processor()
+
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ inputs = self.get_dummy_inputs(torch_device)
+ # Reset generator in case it is has been used in self.get_dummy_inputs
+ inputs["generator"] = self.get_generator(0)
+
+ logger = logging.get_logger(pipe.__module__)
+ logger.setLevel(level=diffusers.logging.FATAL)
+
+ # batchify inputs
+ batched_inputs = {}
+ batched_inputs.update(inputs)
+
+ for name in self.batch_params:
+ if name not in inputs:
+ continue
+
+ value = inputs[name]
+ if name == "prompt":
+ len_prompt = len(value)
+ batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
+ batched_inputs[name][-1] = 100 * "very long"
+
+ else:
+ batched_inputs[name] = batch_size * [value]
+
+ if "generator" in inputs:
+ batched_inputs["generator"] = [self.get_generator(i) for i in range(batch_size)]
+
+ if "batch_size" in inputs:
+ batched_inputs["batch_size"] = batch_size
+
+ for arg in additional_params_copy_to_batched_inputs:
+ batched_inputs[arg] = inputs[arg]
+
+ output = pipe(**inputs)
+ output_batch = pipe(**batched_inputs)
+
+ assert output_batch[0].shape[0] == batch_size
+
+ max_diff = np.abs(to_np(output_batch[0][0]) - to_np(output[0][0])).max()
+ assert max_diff < expected_max_diff
+
+ @unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices")
+ def test_to_device(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.set_progress_bar_config(disable=None)
+
+ pipe.to("cpu")
+ # pipeline creates a new motion UNet under the hood. So we need to check the device from pipe.components
+ model_devices = [
+ component.device.type for component in pipe.components.values() if hasattr(component, "device")
+ ]
+ self.assertTrue(all(device == "cpu" for device in model_devices))
+
+ output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0]
+ self.assertTrue(np.isnan(output_cpu).sum() == 0)
+
+ pipe.to("cuda")
+ model_devices = [
+ component.device.type for component in pipe.components.values() if hasattr(component, "device")
+ ]
+ self.assertTrue(all(device == "cuda" for device in model_devices))
+
+ output_cuda = pipe(**self.get_dummy_inputs("cuda"))[0]
+ self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0)
+
+ def test_to_dtype(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.set_progress_bar_config(disable=None)
+
+ # pipeline creates a new motion UNet under the hood. So we need to check the dtype from pipe.components
+ model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
+ self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
+
+ pipe.to(torch_dtype=torch.float16)
+ model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
+ self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
+
+
+@slow
+@require_torch_gpu
+class AnimateDiffPipelineSlowTests(unittest.TestCase):
+ def tearDown(self):
+ # clean up the VRAM after each test
+ super().tearDown()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def test_animatediff(self):
+ adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
+ pipe = AnimateDiffPipeline.from_pretrained("frankjoshua/toonyou_beta6", motion_adapter=adapter)
+ pipe = pipe.to(torch_device)
+ pipe.scheduler = DDIMScheduler(
+ beta_start=0.00085,
+ beta_end=0.012,
+ beta_schedule="linear",
+ steps_offset=1,
+ clip_sample=False,
+ )
+ pipe.enable_vae_slicing()
+ pipe.enable_model_cpu_offload()
+ pipe.set_progress_bar_config(disable=None)
+
+ prompt = "night, b&w photo of old house, post apocalypse, forest, storm weather, wind, rocks, 8k uhd, dslr, soft lighting, high quality, film grain"
+ negative_prompt = "bad quality, worse quality"
+
+ generator = torch.Generator("cpu").manual_seed(0)
+ output = pipe(
+ prompt,
+ negative_prompt=negative_prompt,
+ num_frames=16,
+ generator=generator,
+ guidance_scale=7.5,
+ num_inference_steps=3,
+ output_type="np",
+ )
+
+ image = output.frames[0]
+ assert image.shape == (16, 512, 512, 3)
+
+ image_slice = image[0, -3:, -3:, -1]
+ expected_slice = np.array(
+ [
+ 0.11357737,
+ 0.11285847,
+ 0.11180121,
+ 0.11084166,
+ 0.11414117,
+ 0.09785956,
+ 0.10742754,
+ 0.10510018,
+ 0.08045256,
+ ]
+ )
+ assert numpy_cosine_similarity_distance(image_slice.flatten(), expected_slice.flatten()) < 1e-3