Animatediff Proposal #5413

Merged
merged 56 commits into main from animatediff-model on Nov 2, 2023

Changes from 45 commits
56 commits
d8ced0f
draft design
DN6 Oct 15, 2023
9e4c700
clean up
DN6 Oct 16, 2023
a026ea5
clean up
DN6 Oct 16, 2023
bbb2b6c
clean up
DN6 Oct 16, 2023
36b3a44
clean up
DN6 Oct 18, 2023
2db7bd3
clean up
DN6 Oct 20, 2023
72e0fa6
clean up
DN6 Oct 20, 2023
d8d3515
clean up
DN6 Oct 21, 2023
7a5fbf8
clean up
DN6 Oct 22, 2023
9eeee36
clean up
DN6 Oct 22, 2023
86a4d31
update pipeline
DN6 Oct 22, 2023
c7ba4b8
clean up
DN6 Oct 23, 2023
6ec184a
clean up
DN6 Oct 24, 2023
79f402f
clean up
DN6 Oct 24, 2023
b24f58a
add tests
DN6 Oct 25, 2023
2688d07
change motion block
DN6 Oct 25, 2023
0deab59
clean up
DN6 Oct 25, 2023
9c66c21
clean up
DN6 Oct 25, 2023
1bd65de
clean up
DN6 Oct 25, 2023
22c9f7b
update
DN6 Oct 25, 2023
0e1f7a8
update
DN6 Oct 25, 2023
c7e1b14
update
DN6 Oct 25, 2023
ee79cf3
update
DN6 Oct 25, 2023
fe3828a
update
DN6 Oct 25, 2023
bcbc2d1
update
DN6 Oct 26, 2023
4df582e
update
DN6 Oct 26, 2023
bf5b65a
update
DN6 Oct 26, 2023
3ba1ba0
clean up
DN6 Oct 26, 2023
4d0b5ec
update
DN6 Oct 26, 2023
8be5f1f
update
DN6 Oct 26, 2023
313db1d
update model test
DN6 Oct 27, 2023
e82331e
update
DN6 Oct 30, 2023
37de1de
update
DN6 Oct 31, 2023
71dc350
update
DN6 Oct 31, 2023
5d65837
update
DN6 Oct 31, 2023
2b78f1e
make style
DN6 Nov 1, 2023
3f5d8de
update
DN6 Nov 1, 2023
d939379
fix embeddings
DN6 Nov 1, 2023
9e6a146
update
DN6 Nov 1, 2023
5e43f24
Merge branch 'main' into animatediff-model
DN6 Nov 1, 2023
dc6eb04
merge upstream
DN6 Nov 1, 2023
5f003e5
max fix copies
DN6 Nov 1, 2023
6f6f8aa
fix bug
DN6 Nov 1, 2023
ec8bb6e
fix mistake
DN6 Nov 1, 2023
d41f717
add docs
DN6 Nov 1, 2023
6d81f2a
update
DN6 Nov 2, 2023
840f576
clean up
DN6 Nov 2, 2023
a6d025b
update
DN6 Nov 2, 2023
ee51b90
clean up
DN6 Nov 2, 2023
dfa52fb
clean up
DN6 Nov 2, 2023
c24c97b
fix docstrings
DN6 Nov 2, 2023
ef893c4
fix docstrings
DN6 Nov 2, 2023
a2e38cc
update
DN6 Nov 2, 2023
0d6f5be
update
DN6 Nov 2, 2023
beb1646
clean up
DN6 Nov 2, 2023
88e76c6
update
DN6 Nov 2, 2023
6 changes: 5 additions & 1 deletion docs/source/en/_toctree.yml
@@ -184,6 +184,8 @@
title: UNet2DConditionModel
- local: api/models/unet3d-cond
title: UNet3DConditionModel
- local: api/models/unet-motion
title: UNetMotionModel
- local: api/models/vq
title: VQModel
- local: api/models/autoencoderkl
@@ -206,6 +208,8 @@
title: Overview
- local: api/pipelines/alt_diffusion
title: AltDiffusion
- local: api/pipelines/animatediff
title: AnimateDiff
- local: api/pipelines/attend_and_excite
title: Attend-and-Excite
- local: api/pipelines/audio_diffusion
@@ -392,5 +396,5 @@
title: Utilities
- local: api/image_processor
title: VAE Image Processor
title: Internal classes
title: Internal classes
title: API
13 changes: 13 additions & 0 deletions docs/source/en/api/models/unet-motion.md
@@ -0,0 +1,13 @@
# UNetMotionModel

The [UNet](https://huggingface.co/papers/1505.04597) model was originally introduced by Ronneberger et al. for biomedical image segmentation, but it is also commonly used in 🤗 Diffusers because it outputs images that are the same size as the input. It is one of the most important components of a diffusion system because it facilitates the actual diffusion process. There are several variants of the UNet model in 🤗 Diffusers, depending on its number of dimensions and whether it is a conditional model or not. The `UNetMotionModel` extends the 2D UNet with motion modules so it can produce temporally coherent sequences of frames.

The abstract from the paper is:

*There is large consent that successful training of deep networks requires many thousand annotated training samples. In this paper, we present a network and training strategy that relies on the strong use of data augmentation to use the available annotated samples more efficiently. The architecture consists of a contracting path to capture context and a symmetric expanding path that enables precise localization. We show that such a network can be trained end-to-end from very few images and outperforms the prior best method (a sliding-window convolutional network) on the ISBI challenge for segmentation of neuronal structures in electron microscopic stacks. Using the same network trained on transmitted light microscopy images (phase contrast and DIC) we won the ISBI cell tracking challenge 2015 in these categories by a large margin. Moreover, the network is fast. Segmentation of a 512x512 image takes less than a second on a recent GPU. The full implementation (based on Caffe) and the trained networks are available at http://lmb.informatik.uni-freiburg.de/people/ronneber/u-net.*

## UNetMotionModel
[[autodoc]] UNetMotionModel

## UNet3DConditionOutput
[[autodoc]] models.unet_3d_condition.UNet3DConditionOutput
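For context (not part of the diff above), here is a minimal sketch of how a `UNetMotionModel` can be assembled from an existing Stable Diffusion UNet and a `MotionAdapter`. The `from_unet2d` helper and the checkpoint names reflect the released library and are assumptions relative to this PR snapshot:

```python
from diffusers import MotionAdapter, UNet2DConditionModel, UNetMotionModel

# Spatial UNet from a Stable Diffusion 1.5 checkpoint (example checkpoint).
unet2d = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

# Motion modules released by the AnimateDiff authors.
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")

# Assumption: `from_unet2d` copies the 2D weights and inserts the adapter's motion
# modules after the spatial blocks, producing a video-capable UNet.
unet = UNetMotionModel.from_unet2d(unet2d, motion_adapter=adapter)
```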
63 changes: 63 additions & 0 deletions docs/source/en/api/pipelines/animatediff.md
@@ -0,0 +1,63 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Text-to-Video Generation with AnimateDiff

## Overview

[AnimateDiff: Animate Your Personalized Text-to-Image Diffusion Models without Specific Tuning](https://arxiv.org/abs/2307.04725) by Yuwei Guo, Ceyuan Yang*, Anyi Rao, Yaohui Wang, Yu Qiao, Dahua Lin, Bo Dai

The abstract of the paper is the following:

*With the advance of text-to-image models (e.g., Stable Diffusion) and corresponding personalization techniques such as DreamBooth and LoRA, everyone can manifest their imagination into high-quality images at an affordable cost. Subsequently, there is a great demand for image animation techniques to further combine generated static images with motion dynamics. In this report, we propose a practical framework to animate most of the existing personalized text-to-image models once and for all, saving efforts in model-specific tuning. At the core of the proposed framework is to insert a newly initialized motion modeling module into the frozen text-to-image model and train it on video clips to distill reasonable motion priors. Once trained, by simply injecting this motion modeling module, all personalized versions derived from the same base T2I readily become text-driven models that produce diverse and personalized animated images. We conduct our evaluation on several public representative personalized text-to-image models across anime pictures and realistic photographs, and demonstrate that our proposed framework helps these models generate temporally smooth animation clips while preserving the domain and diversity of their outputs. Code and pre-trained weights will be publicly available at this https URL.*

## Available Pipelines:

| Pipeline | Tasks | Demo |
|---|---|:---:|
| [AnimateDiffPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/animatediff/pipeline_animatediff.py) | *Text-to-Video Generation with AnimateDiff* | |

## Usage example

AnimateDiff works with a MotionAdapter checkpoint and a Stable Diffusion model checkpoint. The MotionAdapter is a collection of Motion Modules that are responsible for adding coherent motion across image frames. These modules are applied after the ResNet and Attention blocks in the Stable Diffusion UNet.

The following is a simple example of how to use a *MotionAdapter* checkpoint with Diffusers for inference, based on Stable Diffusion 1.4/1.5.

```python
import torch
from diffusers import MotionAdapter, AnimateDiffPipeline, DDIMScheduler
from diffusers.utils import export_to_gif

# Load the motion adapter
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
pipe = AnimateDiffPipeline.from_pretrained("frankjoshua/toonyou_beta6", motion_adapter=adapter)
pipe.scheduler = DDIMScheduler(beta_schedule="linear", steps_offset=1, clip_sample=False)
pipe.enable_model_cpu_offload()

output = pipe(
    prompt="masterpiece, best quality, 1boy, jacket, beard, walking, beanie, sunglasses, from below, looking up, fisheye, upper body, wasteland, sunset, solo focus, cloudy sky, backpack, hands in pockets",
    negative_prompt="human, worst quality, low quality, letterboxed",
    num_frames=16,
    guidance_scale=7.5,
    num_inference_steps=25,
    generator=torch.Generator("cpu").manual_seed(42),
)
frames = output.frames[0]
export_to_gif(frames, "animation.gif")
```

## Available checkpoints

Motion Adapter checkpoints can be found under [guoyww/animatediff](https://huggingface.co/guoyww/).

These checkpoints will work with any model based on Stable Diffusion 1.4/1.5.
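To illustrate that claim, here is a sketch (not part of the PR) that pairs the same motion adapter with the base Stable Diffusion 1.5 checkpoint instead of a fine-tuned one; the prompt and base-model name are illustrative, and the scheduler override mirrors the example above:

```python
import torch
from diffusers import AnimateDiffPipeline, DDIMScheduler, MotionAdapter
from diffusers.utils import export_to_gif

adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")

# Any Stable Diffusion 1.4/1.5-derived checkpoint should work as the base model here.
pipe = AnimateDiffPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", motion_adapter=adapter)
pipe.scheduler = DDIMScheduler.from_config(
    pipe.scheduler.config, beta_schedule="linear", clip_sample=False
)
pipe.enable_model_cpu_offload()

frames = pipe(
    prompt="a rocket launching into space, cinematic lighting",
    num_frames=16,
    num_inference_steps=25,
    generator=torch.Generator("cpu").manual_seed(0),
).frames[0]
export_to_gif(frames, "rocket.gif")
```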

6 changes: 6 additions & 0 deletions src/diffusers/__init__.py
@@ -79,6 +79,7 @@
"AutoencoderTiny",
"ControlNetModel",
"ModelMixin",
"MotionAdapter",
"MultiAdapter",
"PriorTransformer",
"T2IAdapter",
@@ -88,6 +89,7 @@
"UNet2DConditionModel",
"UNet2DModel",
"UNet3DConditionModel",
"UNetMotionModel",
"VQModel",
]
)
@@ -195,6 +197,7 @@
[
"AltDiffusionImg2ImgPipeline",
"AltDiffusionPipeline",
"AnimateDiffPipeline",
"AudioLDM2Pipeline",
"AudioLDM2ProjectionModel",
"AudioLDM2UNet2DConditionModel",
@@ -440,6 +443,7 @@
AutoencoderTiny,
ControlNetModel,
ModelMixin,
MotionAdapter,
MultiAdapter,
PriorTransformer,
T2IAdapter,
@@ -449,6 +453,7 @@
UNet2DConditionModel,
UNet2DModel,
UNet3DConditionModel,
UNetMotionModel,
VQModel,
)
from .optimization import (
@@ -537,6 +542,7 @@
from .pipelines import (
AltDiffusionImg2ImgPipeline,
AltDiffusionPipeline,
AnimateDiffPipeline,
AudioLDM2Pipeline,
AudioLDM2ProjectionModel,
AudioLDM2UNet2DConditionModel,
2 changes: 2 additions & 0 deletions src/diffusers/models/__init__.py
@@ -35,6 +35,7 @@
_import_structure["unet_2d"] = ["UNet2DModel"]
_import_structure["unet_2d_condition"] = ["UNet2DConditionModel"]
_import_structure["unet_3d_condition"] = ["UNet3DConditionModel"]
_import_structure["unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
_import_structure["vq_model"] = ["VQModel"]

if is_flax_available():
@@ -60,6 +61,7 @@
from .unet_2d import UNet2DModel
from .unet_2d_condition import UNet2DConditionModel
from .unet_3d_condition import UNet3DConditionModel
from .unet_motion_model import MotionAdapter, UNetMotionModel
from .vq_model import VQModel

if is_flax_available():
22 changes: 22 additions & 0 deletions src/diffusers/models/attention.py
@@ -20,6 +20,7 @@
from ..utils.torch_utils import maybe_allow_in_graph
from .activations import GEGLU, GELU, ApproximateGELU
from .attention_processor import Attention
from .embeddings import SinusoidalPositionalEmbedding
from .lora import LoRACompatibleLinear
from .normalization import AdaLayerNorm, AdaLayerNormZero

@@ -96,6 +97,10 @@ class BasicTransformerBlock(nn.Module):
Whether to apply a final dropout after the last feed-forward layer.
attention_type (`str`, *optional*, defaults to `"default"`):
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
positional_embeddings (`str`, *optional*, defaults to `None`):
The type of positional embeddings to apply.
num_positional_embeddings (`int`, *optional*, defaults to `None`):
The maximum number of positional embeddings to apply.
"""

def __init__(
@@ -115,6 +120,8 @@ def __init__(
norm_type: str = "layer_norm",
final_dropout: bool = False,
attention_type: str = "default",
positional_embeddings: Optional[str] = None,
num_positional_embeddings: Optional[int] = None,
):
super().__init__()
self.only_cross_attention = only_cross_attention
@@ -128,6 +135,16 @@ def __init__(
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
)

if positional_embeddings and (num_positional_embeddings is None):
raise ValueError(
"If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
)

if positional_embeddings == "sinusoidal":
self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
else:
self.pos_embed = None

# Define 3 blocks. Each block has its own normalization layer.
# 1. Self-Attn
if self.use_ada_layer_norm:
@@ -207,6 +224,9 @@ def forward(
else:
norm_hidden_states = self.norm1(hidden_states)

if self.pos_embed is not None:
norm_hidden_states = self.pos_embed(norm_hidden_states)

# 1. Retrieve lora scale.
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0

@@ -234,6 +254,8 @@
norm_hidden_states = (
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
)
if self.pos_embed is not None:
norm_hidden_states = self.pos_embed(norm_hidden_states)

attn_output = self.attn2(
norm_hidden_states,
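A short sketch (not from the PR) of how the new `positional_embeddings` arguments are exercised on a standalone `BasicTransformerBlock`; the dimensions here are illustrative assumptions:

```python
import torch
from diffusers.models.attention import BasicTransformerBlock

# dim must equal num_attention_heads * attention_head_dim.
block = BasicTransformerBlock(
    dim=64,
    num_attention_heads=8,
    attention_head_dim=8,
    positional_embeddings="sinusoidal",
    num_positional_embeddings=32,
)

hidden_states = torch.randn(2, 16, 64)  # (batch, num_frames, dim)
out = block(hidden_states)              # sinusoidal offsets applied before each attention call
assert out.shape == hidden_states.shape
```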
27 changes: 27 additions & 0 deletions src/diffusers/models/embeddings.py
@@ -251,6 +251,33 @@ def forward(self, x):
return out


class SinusoidalPositionalEmbedding(nn.Module):
"""Apply positional information to a sequence of embeddings.

Takes in a sequence of embeddings with shape (batch_size, seq_length, embed_dim) and adds positional embeddings to
them.

Args:
embed_dim (`int`): Dimension of the positional embedding.
max_seq_length (`int`, *optional*, defaults to 32): Maximum sequence length over which to apply positional embeddings.

"""

def __init__(self, embed_dim: int, max_seq_length: int = 32):
super().__init__()
position = torch.arange(max_seq_length).unsqueeze(1)
div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim))
pe = torch.zeros(1, max_seq_length, embed_dim)
pe[0, :, 0::2] = torch.sin(position * div_term)
pe[0, :, 1::2] = torch.cos(position * div_term)
self.register_buffer("pe", pe)

def forward(self, x):
_, seq_length, _ = x.shape
x = x + self.pe[:, :seq_length]
return x


class ImagePositionalEmbeddings(nn.Module):
"""
Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
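A minimal usage sketch of the new embedding module (not part of the diff), assuming this PR's module layout; shapes follow the docstring above:

```python
import torch
from diffusers.models.embeddings import SinusoidalPositionalEmbedding  # module path per this PR

pos_embed = SinusoidalPositionalEmbedding(embed_dim=64, max_seq_length=32)

x = torch.randn(2, 16, 64)  # (batch_size, seq_length, embed_dim), seq_length <= max_seq_length
out = pos_embed(x)          # fixed sin/cos offsets are added; the shape is unchanged
assert out.shape == x.shape
```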
14 changes: 10 additions & 4 deletions src/diffusers/models/transformer_temporal.py
@@ -50,15 +50,17 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin):
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
attention_bias (`bool`, *optional*):
Configure if the `TransformerBlock` attention should contain a bias parameter.
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
This is fixed during training since it is used to learn a number of position embeddings.
activation_fn (`str`, *optional*, defaults to `"geglu"`):
Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported
activation functions.
norm_elementwise_affine (`bool`, *optional*):
Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization.
double_self_attention (`bool`, *optional*):
Configure if each `TransformerBlock` should contain two self-attention layers.
positional_embeddings (`str`, *optional*):
The type of positional embeddings to apply to the sequence input.
num_positional_embeddings: (`int`, *optional*):
The maximum length of the sequence over which to apply positional embeddings.
"""

@register_to_config
@@ -67,22 +69,23 @@ def __init__(
num_attention_heads: int = 16,
attention_head_dim: int = 88,
in_channels: Optional[int] = None,
out_channels: Optional[int] = None,
num_layers: int = 1,
dropout: float = 0.0,
norm_num_groups: int = 32,
cross_attention_dim: Optional[int] = None,
attention_bias: bool = False,
sample_size: Optional[int] = None,
activation_fn: str = "geglu",
norm_elementwise_affine: bool = True,
double_self_attention: bool = True,
positional_embeddings: Optional[str] = None,
num_positional_embeddings: Optional[int] = None,
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
inner_dim = num_attention_heads * attention_head_dim

self.use_cross_attention = cross_attention_dim is not None
self.in_channels = in_channels

self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
@@ -101,6 +104,8 @@ def __init__(
attention_bias=attention_bias,
double_self_attention=double_self_attention,
norm_elementwise_affine=norm_elementwise_affine,
positional_embeddings=positional_embeddings,
num_positional_embeddings=num_positional_embeddings,
)
for d in range(num_layers)
]
@@ -160,6 +165,7 @@ def forward(
hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)

hidden_states = self.proj_in(hidden_states)
encoder_hidden_states = encoder_hidden_states if self.use_cross_attention else None
Review thread on the line above:

Reviewer (Contributor): do we need this?

Reviewer (Contributor): I still don't understand why we need this. If `cross_attention_dim` is None, then why do we have to manually set `encoder_hidden_states` to None? This looks more like a hacky bug correction. Why do we pass `encoder_hidden_states` in the first place if we don't have cross attention?

DN6 (Collaborator, author), Nov 2, 2023: @patrickvonplaten Initially added it so that users could train new motion modules with the option of using cross attention.

We can omit sending the encoder hidden states to this block from the higher-level blocks. It just means the MotionModules in the UNetMotionModel cannot support cross attention at all. We can then remove the `temporal_cross_attention_dim` argument.

Reviewer (Contributor): We should only add options that are needed to run the official AnimateDiff checkpoints; all possible customizations that the user could try out should not be added (only if they become necessary).

Reviewer (Contributor): Here, if the motion modules all have `temporal_cross_attention_dim` always set to None, then let's not give the possibility to customize it, as this unnecessarily bloats the code.


# 2. Blocks
for block in self.transformer_blocks:
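A small sketch (not from the PR) of the temporal transformer with the new positional-embedding options; the sizes are illustrative assumptions, and `in_channels` must be divisible by the default `norm_num_groups`:

```python
import torch
from diffusers.models.transformer_temporal import TransformerTemporalModel

model = TransformerTemporalModel(
    num_attention_heads=4,
    attention_head_dim=8,               # inner_dim = 4 * 8 = 32
    in_channels=32,                     # divisible by the default norm_num_groups=32
    positional_embeddings="sinusoidal",
    num_positional_embeddings=32,
)

# forward() expects hidden states of shape (batch_size * num_frames, channels, height, width).
num_frames, batch_size = 16, 1
sample = torch.randn(batch_size * num_frames, 32, 8, 8)
out = model(sample, num_frames=num_frames).sample  # same shape as the input
assert out.shape == sample.shape
```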