Animatediff Proposal (huggingface#5413)
* draft design

* clean up

* clean up

* clean up

* clean up

* clean up

* clean up

* clean up

* clean up

* clean up

* update pipeline

* clean up

* clean up

* clean up

* add tests

* change motion block

* clean up

* clean up

* clean up

* update

* update

* update

* update

* update

* update

* update

* update

* clean up

* update

* update

* update model test

* update

* update

* update

* update

* make style

* update

* fix embeddings

* update

* merge upstream

* make fix-copies

* fix bug

* fix mistake

* add docs

* update

* clean up

* update

* clean up

* clean up

* fix docstrings

* fix docstrings

* update

* update

* clean up

* update
DN6 authored Nov 2, 2023
1 parent c872219 commit 7f69ff6
Showing 12 changed files with 2,611 additions and 0 deletions.
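
This commit wires a new MotionAdapter and UNetMotionModel into the model registry and exposes an AnimateDiffPipeline. A minimal sketch of how the new pipeline might be driven once merged (the checkpoint names below are illustrative assumptions, not part of this diff):

import torch
from diffusers import AnimateDiffPipeline, MotionAdapter

# Load the temporal motion layers into the new adapter class
# (the repo id is an assumed example, not defined in this commit).
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")

# Any Stable Diffusion 1.5-style base checkpoint works; the pipeline
# merges the adapter's motion modules into a UNetMotionModel internally.
pipe = AnimateDiffPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    motion_adapter=adapter,
    torch_dtype=torch.float16,
).to("cuda")

frames = pipe(
    "a rocket launching into space, cinematic",
    num_frames=16,
    num_inference_steps=25,
).frames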
6 changes: 6 additions & 0 deletions __init__.py
@@ -79,6 +79,7 @@
"AutoencoderTiny",
"ControlNetModel",
"ModelMixin",
"MotionAdapter",
"MultiAdapter",
"PriorTransformer",
"T2IAdapter",
@@ -88,6 +89,7 @@
"UNet2DConditionModel",
"UNet2DModel",
"UNet3DConditionModel",
"UNetMotionModel",
"VQModel",
]
)
@@ -195,6 +197,7 @@
[
"AltDiffusionImg2ImgPipeline",
"AltDiffusionPipeline",
"AnimateDiffPipeline",
"AudioLDM2Pipeline",
"AudioLDM2ProjectionModel",
"AudioLDM2UNet2DConditionModel",
@@ -440,6 +443,7 @@
AutoencoderTiny,
ControlNetModel,
ModelMixin,
MotionAdapter,
MultiAdapter,
PriorTransformer,
T2IAdapter,
@@ -449,6 +453,7 @@
UNet2DConditionModel,
UNet2DModel,
UNet3DConditionModel,
UNetMotionModel,
VQModel,
)
from .optimization import (
@@ -537,6 +542,7 @@
from .pipelines import (
AltDiffusionImg2ImgPipeline,
AltDiffusionPipeline,
AnimateDiffPipeline,
AudioLDM2Pipeline,
AudioLDM2ProjectionModel,
AudioLDM2UNet2DConditionModel,
2 changes: 2 additions & 0 deletions models/__init__.py
@@ -35,6 +35,7 @@
_import_structure["unet_2d"] = ["UNet2DModel"]
_import_structure["unet_2d_condition"] = ["UNet2DConditionModel"]
_import_structure["unet_3d_condition"] = ["UNet3DConditionModel"]
_import_structure["unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
_import_structure["vq_model"] = ["VQModel"]

if is_flax_available():
@@ -60,6 +61,7 @@
from .unet_2d import UNet2DModel
from .unet_2d_condition import UNet2DConditionModel
from .unet_3d_condition import UNet3DConditionModel
from .unet_motion_model import MotionAdapter, UNetMotionModel
from .vq_model import VQModel

if is_flax_available():
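
With both __init__.py changes in place, the new classes resolve from the top-level package as well as the models subpackage. A quick import check, assuming a build that includes this commit:

from diffusers import AnimateDiffPipeline, MotionAdapter, UNetMotionModel
from diffusers.models import MotionAdapter, UNetMotionModel  # same classes via the subpackage path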
22 changes: 22 additions & 0 deletions models/attention.py
@@ -20,6 +20,7 @@
from ..utils.torch_utils import maybe_allow_in_graph
from .activations import GEGLU, GELU, ApproximateGELU
from .attention_processor import Attention
from .embeddings import SinusoidalPositionalEmbedding
from .lora import LoRACompatibleLinear
from .normalization import AdaLayerNorm, AdaLayerNormZero

@@ -96,6 +97,10 @@ class BasicTransformerBlock(nn.Module):
Whether to apply a final dropout after the last feed-forward layer.
attention_type (`str`, *optional*, defaults to `"default"`):
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
positional_embeddings (`str`, *optional*, defaults to `None`):
The type of positional embeddings to apply.
num_positional_embeddings (`int`, *optional*, defaults to `None`):
The maximum number of positional embeddings to apply.
"""

def __init__(
@@ -115,6 +120,8 @@ def __init__(
norm_type: str = "layer_norm",
final_dropout: bool = False,
attention_type: str = "default",
positional_embeddings: Optional[str] = None,
num_positional_embeddings: Optional[int] = None,
):
super().__init__()
self.only_cross_attention = only_cross_attention
@@ -128,6 +135,16 @@ def __init__(
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
)

if positional_embeddings and (num_positional_embeddings is None):
raise ValueError(
    "If `positional_embeddings` type is defined, `num_positional_embeddings` must also be defined."
)

if positional_embeddings == "sinusoidal":
self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
else:
self.pos_embed = None

# Define 3 blocks. Each block has its own normalization layer.
# 1. Self-Attn
if self.use_ada_layer_norm:
@@ -207,6 +224,9 @@ def forward(
else:
norm_hidden_states = self.norm1(hidden_states)

if self.pos_embed is not None:
norm_hidden_states = self.pos_embed(norm_hidden_states)

# 1. Retrieve lora scale.
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0

@@ -234,6 +254,8 @@
norm_hidden_states = (
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
)
if self.pos_embed is not None:
norm_hidden_states = self.pos_embed(norm_hidden_states)

attn_output = self.attn2(
norm_hidden_states,
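
Both hunks above apply the embedding right after the pre-attention norm, so self- and cross-attention each see position-aware hidden states. A standalone sketch of a block built with the new arguments (dimensions chosen only for illustration):

import torch
from diffusers.models.attention import BasicTransformerBlock

block = BasicTransformerBlock(
    dim=320,
    num_attention_heads=8,
    attention_head_dim=40,
    positional_embeddings="sinusoidal",
    num_positional_embeddings=32,  # must cover the longest sequence passed in
)
hidden_states = torch.randn(2, 16, 320)  # (batch, seq_length, dim)
out = block(hidden_states)  # same shape, now position-aware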
27 changes: 27 additions & 0 deletions models/embeddings.py
@@ -251,6 +251,33 @@ def forward(self, x):
return out


class SinusoidalPositionalEmbedding(nn.Module):
"""Apply positional information to a sequence of embeddings.
Takes in a sequence of embeddings with shape (batch_size, seq_length, embed_dim) and adds positional embeddings to
them.
Args:
    embed_dim (`int`): Dimension of the positional embedding.
    max_seq_length (`int`): Maximum sequence length over which to apply positional embeddings.
"""

def __init__(self, embed_dim: int, max_seq_length: int = 32):
super().__init__()
position = torch.arange(max_seq_length).unsqueeze(1)
div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim))
pe = torch.zeros(1, max_seq_length, embed_dim)
pe[0, :, 0::2] = torch.sin(position * div_term)
pe[0, :, 1::2] = torch.cos(position * div_term)
self.register_buffer("pe", pe)

def forward(self, x):
_, seq_length, _ = x.shape
x = x + self.pe[:, :seq_length]
return x


class ImagePositionalEmbeddings(nn.Module):
"""
Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
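
The registered buffer holds the standard Transformer sin/cos table: even channels take sine, odd channels cosine, which assumes an even embed_dim. A quick shape check under that assumption:

import torch
from diffusers.models.embeddings import SinusoidalPositionalEmbedding

pos_embed = SinusoidalPositionalEmbedding(embed_dim=64, max_seq_length=32)
x = torch.randn(2, 16, 64)   # (batch_size, seq_length, embed_dim)
out = pos_embed(x)           # adds pe[:, :16, :], broadcast over the batch
assert out.shape == x.shape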
8 changes: 8 additions & 0 deletions models/transformer_temporal.py
@@ -59,6 +59,10 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin):
Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization.
double_self_attention (`bool`, *optional*):
Configure if each `TransformerBlock` should contain two self-attention layers.
positional_embeddings (`str`, *optional*):
    The type of positional embeddings to apply to the sequence input before use.
num_positional_embeddings (`int`, *optional*):
    The maximum length of the sequence over which to apply positional embeddings.
"""

@register_to_config
@@ -77,6 +81,8 @@ def __init__(
activation_fn: str = "geglu",
norm_elementwise_affine: bool = True,
double_self_attention: bool = True,
positional_embeddings: Optional[str] = None,
num_positional_embeddings: Optional[int] = None,
):
super().__init__()
self.num_attention_heads = num_attention_heads
@@ -101,6 +107,8 @@
attention_bias=attention_bias,
double_self_attention=double_self_attention,
norm_elementwise_affine=norm_elementwise_affine,
positional_embeddings=positional_embeddings,
num_positional_embeddings=num_positional_embeddings,
)
for d in range(num_layers)
]
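
In the motion modules, these two arguments are what give the temporal transformer a notion of frame order. A sketch of a standalone instantiation (channel sizes are illustrative; in_channels must be divisible by the default norm_num_groups of 32):

import torch
from diffusers.models.transformer_temporal import TransformerTemporalModel

temporal = TransformerTemporalModel(
    num_attention_heads=8,
    attention_head_dim=40,
    in_channels=320,
    num_layers=1,
    positional_embeddings="sinusoidal",
    num_positional_embeddings=32,
)

# Frames are folded into the batch dimension, mirroring UNet3DConditionModel.
sample = torch.randn(2 * 16, 320, 8, 8)  # (batch * num_frames, channels, h, w)
out = temporal(sample, num_frames=16).sample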