From d8ced0ff7aa6b9e5c66348b8f3ed84e3bae3f7cd Mon Sep 17 00:00:00 2001 From: DN6 Date: Sun, 15 Oct 2023 21:13:16 +0530 Subject: [PATCH 01/55] draft design --- src/diffusers/models/__init__.py | 2 + src/diffusers/models/embeddings.py | 16 + src/diffusers/models/unet_2d_condition.py | 7 +- src/diffusers/models/unet_motion_blocks.py | 1092 ++++++++++++++++++++ src/diffusers/models/unet_motion_model.py | 730 +++++++++++++ 5 files changed, 1841 insertions(+), 6 deletions(-) create mode 100644 src/diffusers/models/unet_motion_blocks.py create mode 100644 src/diffusers/models/unet_motion_model.py diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index a5d0066d5c40..af59749a3338 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -35,6 +35,7 @@ _import_structure["unet_2d"] = ["UNet2DModel"] _import_structure["unet_2d_condition"] = ["UNet2DConditionModel"] _import_structure["unet_3d_condition"] = ["UNet3DConditionModel"] + _import_structure["unet_motion_model"] = ["UNetMotionModel"] _import_structure["vq_model"] = ["VQModel"] if is_flax_available(): @@ -60,6 +61,7 @@ from .unet_2d import UNet2DModel from .unet_2d_condition import UNet2DConditionModel from .unet_3d_condition import UNet3DConditionModel + from .unet_motion_model import UNetMotionModel from .vq_model import VQModel if is_flax_available(): diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index e05092de3d10..6fdca133b6a9 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -249,6 +249,22 @@ def forward(self, x): return out +class PositionalEmbedding(nn.Module): + def __init__(self, embed_dim, max_seq_length=24): + super().__init__() + position = torch.arange(max_seq_length).unsqueeze(1) + + div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim)) + self.pos_embed = torch.zeros(1, max_seq_length, embed_dim) + self.pos_embed[0, :, 0::2] = torch.sin(position * div_term) + self.pos_embed[0, :, 1::2] = torch.cos(position * div_term) + + def forward(self, x): + seq_len = x.shape[1] + x = x + self.pos_embed[:, :seq_len] + return x + + class ImagePositionalEmbeddings(nn.Module): """ Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index f858a7685360..43dc12248e99 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -42,12 +42,7 @@ Timesteps, ) from .modeling_utils import ModelMixin -from .unet_2d_blocks import ( - UNetMidBlock2DCrossAttn, - UNetMidBlock2DSimpleCrossAttn, - get_down_block, - get_up_block, -) +from .unet_2d_blocks import UNetMidBlock2DCrossAttn, UNetMidBlock2DSimpleCrossAttn, get_down_block, get_up_block logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/models/unet_motion_blocks.py b/src/diffusers/models/unet_motion_blocks.py new file mode 100644 index 000000000000..38a1a6368207 --- /dev/null +++ b/src/diffusers/models/unet_motion_blocks.py @@ -0,0 +1,1092 @@ +from typing import Any, Dict, Optional, Tuple + +import torch +import torch.nn as nn + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import is_torch_version +from ..utils.torch_utils import apply_freeu +from .dual_transformer_2d import DualTransformer2DModel +from .embeddings import PositionalEmbedding +from .modeling_utils import ModelMixin +from .resnet import Downsample2D, ResnetBlock2D, Upsample2D +from .transformer_2d import Transformer2DModel +from .transformer_temporal import TransformerTemporalModel + + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + num_attention_heads, + resnet_groups=None, + cross_attention_dim=None, + downsample_padding=None, + dual_cross_attention=False, + use_linear_projection=True, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + motion_norm_num_groups=32, + motion_cross_attention_dim=None, + motion_num_attention_heads=8, + motion_attention_head_dim=40, + motion_attenion_bias=False, + motion_activation_fn="geglu", + motion_max_seq_length=24, +): + if down_block_type == "DownBlockMotion": + return DownBlockMotion( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + motion_norm_num_groups=motion_norm_num_groups, + motion_cross_attention_dim=motion_cross_attention_dim, + motion_num_attention_heads=motion_num_attention_heads, + motion_attention_head_dim=motion_attention_head_dim, + motion_attenion_bias=motion_attenion_bias, + motion_activation_fn=motion_activation_fn, + motion_max_seq_length=motion_max_seq_length, + ) + elif down_block_type == "CrossAttnDownBlockMotion": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockMotion") + return CrossAttnDownBlockMotion( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + motion_norm_num_groups=motion_norm_num_groups, + motion_cross_attention_dim=motion_cross_attention_dim, + motion_num_attention_heads=motion_num_attention_heads, + motion_attention_head_dim=motion_attention_head_dim, + motion_attenion_bias=motion_attenion_bias, + motion_activation_fn=motion_activation_fn, + motion_max_seq_length=motion_max_seq_length, + ) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + num_attention_heads, + resolution_idx=None, + resnet_groups=None, + cross_attention_dim=None, + dual_cross_attention=False, + use_linear_projection=True, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + motion_norm_num_groups=32, + motion_cross_attention_dim=None, + motion_num_attention_heads=8, + motion_attention_head_dim=40, + motion_attenion_bias=False, + motion_activation_fn="geglu", + motion_max_seq_length=24, +): + if up_block_type == "UpBlockMotion": + return UpBlockMotion( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + resolution_idx=resolution_idx, + motion_norm_num_groups=motion_norm_num_groups, + motion_cross_attention_dim=motion_cross_attention_dim, + motion_num_attention_heads=motion_num_attention_heads, + motion_attention_head_dim=motion_attention_head_dim, + motion_attenion_bias=motion_attenion_bias, + motion_activation_fn=motion_activation_fn, + motion_max_seq_length=motion_max_seq_length, + ) + elif up_block_type == "CrossAttnUpBlockMotion": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockMotion") + return CrossAttnUpBlockMotion( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + resolution_idx=resolution_idx, + motion_norm_num_groups=motion_norm_num_groups, + motion_cross_attention_dim=motion_cross_attention_dim, + motion_num_attention_heads=motion_num_attention_heads, + motion_attention_head_dim=motion_attention_head_dim, + motion_attenion_bias=motion_attenion_bias, + motion_activation_fn=motion_activation_fn, + motion_max_seq_length=motion_max_seq_length, + ) + raise ValueError(f"{up_block_type} does not exist.") + + +class MotionBlock(nn.Module): + def __init__( + self, + in_channels, + norm_num_groups=32, + cross_attention_dim=None, + activation_fn="geglu", + attention_bias=False, + num_attention_heads=8, + max_seq_length=24, + ) -> None: + super().__init__() + + self.pos_embed = PositionalEmbedding(embed_dim=in_channels, max_seq_length=max_seq_length) + self.temporal_transformer = TransformerTemporalModel( + in_channels=in_channels, + norm_num_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attention_head_dim=in_channels // num_attention_heads, + activation_fn=activation_fn, + attention_bias=attention_bias, + num_attention_heads=num_attention_heads, + ) + + def forward(self, x): + x = self.pos_embed(x) + x = self.temporal_transformer(x) + """ + x = x.reshape([-1, self.max_seq_length, *x.shape[1:]]) + x = x.movedim((0, 1, 2, 3, 4), (0, 2, 1, 3, 4)) + + x = x.movedim((0, 1, 2, 3, 4), (0, 2, 1, 3, 4)) + x = x.flatten(0, 1) + """ + return x + + +class MotionModules(nn.Module): + def __init__( + self, + in_channels, + layers_per_block=2, + num_attention_heads=8, + attention_bias=False, + cross_attention_dim=None, + activation_fn="geglu", + norm_num_groups=32, + max_seq_length=24, + ): + super().__init__() + self.motion_modules = nn.ModuleList([]) + + for i in range(layers_per_block): + self.motion_modules.append( + MotionBlock( + in_channels=in_channels, + norm_num_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + attention_bias=attention_bias, + num_attention_heads=num_attention_heads, + max_seq_length=max_seq_length, + ) + ) + + +class MotionAdapter(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + block_out_channels=(320, 640, 1280, 1280), + down_block_types=( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + up_block_types=("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"), + layers_per_block=2, + num_attention_heads=8, + attenion_bias=False, + cross_attention_dim=None, + activation_fn="geglu", + norm_num_groups=32, + max_seq_length=24, + ): + """Container to store Motion Modules + + Args: + block_out_channels (tuple, optional): _description_. Defaults to (320, 640, 1280, 1280). + down_block_types (tuple, optional): _description_. Defaults to ( "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D", ). + up_block_types (tuple, optional): _description_. Defaults to ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"). + layers_per_block (int, optional): _description_. Defaults to 2. + num_attention_heads (int, optional): _description_. Defaults to 8. + attention_head_dim (int, optional): _description_. Defaults to 40. + attenion_bias (bool, optional): _description_. Defaults to False. + cross_attention_dim (_type_, optional): _description_. Defaults to None. + activation_fn (str, optional): _description_. Defaults to "geglu". + norm_num_groups (int, optional): _description_. Defaults to 32. + max_seq_length (int, optional): _description_. Defaults to 24. + """ + + super().__init__() + down_blocks = [] + up_blocks = [] + + for i, block_type in enumerate(down_block_types): + output_channel = block_out_channels[i] + down_blocks.append( + MotionModules( + in_channels=output_channel, + norm_num_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + attention_bias=attenion_bias, + num_attention_heads=num_attention_heads, + max_seq_length=max_seq_length, + ) + ) + + self.mid_block = MotionModules( + in_channels=block_out_channels[-1], + norm_num_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + attention_bias=attenion_bias, + num_attention_heads=num_attention_heads, + layers_per_block=1, + ) + + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, block_type in enumerate(up_block_types): + output_channel = reversed_block_out_channels[i] + up_blocks.append( + MotionModules( + in_channels=output_channel, + norm_num_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + attention_bias=attenion_bias, + num_attention_heads=num_attention_heads, + max_seq_length=max_seq_length, + layers_per_block=layers_per_block + 1, + ) + ) + + self.down_blocks = nn.ModuleList(down_blocks) + self.up_blocks = nn.ModuleList(up_blocks) + + def forward(self, sample): + pass + + +class DownBlockMotion(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + motion_norm_num_groups=32, + motion_cross_attention_dim=None, + motion_num_attention_heads=8, + motion_attention_head_dim=40, + motion_attenion_bias=False, + motion_activation_fn="geglu", + motion_max_seq_length=24, + ): + super().__init__() + resnets = [] + motion_modules = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + motion_modules.append( + MotionBlock( + in_channels=out_channels, + norm_num_groups=motion_norm_num_groups, + cross_attention_dim=motion_cross_attention_dim, + activation_fn=motion_activation_fn, + attention_bias=motion_attenion_bias, + num_attention_heads=motion_num_attention_heads, + max_seq_length=motion_max_seq_length, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None, scale: float = 1.0): + output_states = () + + for resnet, motion_module in self.resnets, self.motion_modules: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb, scale=scale) + hidden_states = motion_module(hidden_states) + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states, scale=scale) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnDownBlockMotion(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + attention_type="default", + motion_norm_num_groups=32, + motion_cross_attention_dim=None, + motion_num_attention_heads=8, + motion_attention_head_dim=40, + motion_attenion_bias=False, + motion_activation_fn="geglu", + motion_max_seq_length=24, + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.motion_modules = MotionModules( + in_channels=out_channels, + norm_num_groups=motion_norm_num_groups, + cross_attention_dim=motion_cross_attention_dim, + activation_fn=motion_activation_fn, + attention_bias=motion_attenion_bias, + num_attention_heads=motion_num_attention_heads, + max_seq_length=motion_max_seq_length, + ).motion_modules + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + additional_residuals=None, + ): + output_states = () + + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + + blocks = list(zip(self.resnets, self.attentions, self.motion_modules)) + + for i, (resnet, attn, motion_module) in enumerate(blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + else: + hidden_states = resnet(hidden_states, temb, scale=lora_scale) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = motion_module(hidden_states) + + # apply additional residuals to the output of the last pair of resnet and attention blocks + if i == len(blocks) - 1 and additional_residuals is not None: + hidden_states = hidden_states + additional_residuals + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states, scale=lora_scale) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnUpBlockMotion(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + resolution_idx: int = None, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + attention_type="default", + motion_norm_num_groups=32, + motion_cross_attention_dim=None, + motion_num_attention_heads=8, + motion_attention_head_dim=40, + motion_attenion_bias=False, + motion_activation_fn="geglu", + motion_max_seq_length=24, + ): + super().__init__() + resnets = [] + attentions = [] + motion_modules = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + motion_modules.append( + MotionBlock( + in_channels=out_channels, + norm_num_groups=motion_norm_num_groups, + cross_attention_dim=motion_cross_attention_dim, + activation_fn=motion_activation_fn, + attention_bias=motion_attenion_bias, + num_attention_heads=motion_num_attention_heads, + max_seq_length=motion_max_seq_length, + ) + ) + + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.motion_modules = MotionModules( + in_channels=out_channels, + norm_num_groups=motion_norm_num_groups, + cross_attention_dim=motion_cross_attention_dim, + activation_fn=motion_activation_fn, + attention_bias=motion_attenion_bias, + num_attention_heads=motion_num_attention_heads, + max_seq_length=motion_max_seq_length, + ).motion_modules + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ): + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + + for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + hidden_states, res_hidden_states = apply_freeu( + self.resolution_idx, + hidden_states, + res_hidden_states, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + else: + hidden_states = resnet(hidden_states, temb, scale=lora_scale) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = motion_module(hidden_states) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale) + + return hidden_states + + +class UpBlockMotion(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + resolution_idx: int = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + motion_norm_num_groups=32, + motion_cross_attention_dim=None, + motion_num_attention_heads=8, + motion_attention_head_dim=40, + motion_attenion_bias=False, + motion_activation_fn="geglu", + motion_max_seq_length=24, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.motion_modules = MotionModules( + in_channels=out_channels, + norm_num_groups=motion_norm_num_groups, + cross_attention_dim=motion_cross_attention_dim, + activation_fn=motion_activation_fn, + attention_bias=motion_attenion_bias, + num_attention_heads=motion_num_attention_heads, + max_seq_length=motion_max_seq_length, + ).motion_modules + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0): + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + + blocks = zip(self.resnets, self.motion_modules) + + for resnet, motion_module in blocks: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + hidden_states, res_hidden_states = apply_freeu( + self.resolution_idx, + hidden_states, + res_hidden_states, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb, scale=scale) + hidden_states = motion_module(hidden_states) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size, scale=scale) + + return hidden_states + + +class UNetMidBlockCrossAttnMotion(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + dual_cross_attention=False, + use_linear_projection=False, + upcast_attention=False, + attention_type="default", + motion_norm_num_groups=32, + motion_cross_attention_dim=None, + motion_num_attention_heads=8, + motion_attention_head_dim=40, + motion_attenion_bias=False, + motion_activation_fn="geglu", + motion_max_seq_length=24, + ): + super().__init__() + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + for _ in range(num_layers): + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + in_channels // num_attention_heads, + in_channels=in_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + in_channels // num_attention_heads, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.motion_modules = MotionModules( + in_channels=in_channels, + norm_num_groups=motion_norm_num_groups, + cross_attention_dim=motion_cross_attention_dim, + activation_fn=motion_activation_fn, + attention_bias=motion_attenion_bias, + num_attention_heads=motion_num_attention_heads, + max_seq_length=motion_max_seq_length, + layers_per_block=num_layers, + ).motion_modules + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale) + for attn, resnet, motion_module in zip(self.attentions, self.resnets[1:], self.motion_modules): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + else: + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = motion_module(hidden_states) + hidden_states = resnet(hidden_states, temb, scale=lora_scale) + + return hidden_states diff --git a/src/diffusers/models/unet_motion_model.py b/src/diffusers/models/unet_motion_model.py new file mode 100644 index 000000000000..c1595c42cddf --- /dev/null +++ b/src/diffusers/models/unet_motion_model.py @@ -0,0 +1,730 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput, logging +from .attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + AttentionProcessor, + AttnAddedKVProcessor, + AttnProcessor, +) +from .embeddings import ( + TimestepEmbedding, + Timesteps, +) +from .modeling_utils import ModelMixin +from .unet_3d_condition import UNet3DConditionOutput +from .unet_motion_blocks import ( + CrossAttnDownBlockMotion, + CrossAttnUpBlockMotion, + DownBlockMotion, + UNetMidBlockCrossAttnMotion, + UpBlockMotion, + get_down_block, + get_up_block, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class UNet2DConditionOutput(BaseOutput): + """ + The output of [`UNet2DConditionModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ + + sample: torch.FloatTensor = None + + +class UNetMotionModel(ModelMixin, ConfigMixin): + r""" + A modified conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample + shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + """ + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + in_channels: int = 4, + out_channels: int = 4, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlockMotion", + "CrossAttnDownBlockMotion", + "CrossAttnDownBlockMotion", + "DownBlockMotion", + ), + up_block_types: Tuple[str] = ( + "UpBlockMotion", + "CrossAttnUpBlockMotion", + "CrossAttnUpBlockMotion", + "CrossAttnUpBlockMotion", + ), + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + attention_head_dim: Union[int, Tuple[int]] = 8, + use_linear_projection: bool = False, + num_attention_heads: Optional[Union[int, Tuple[int]]] = None, + motion_norm_num_groups=32, + motion_cross_attention_dim=None, + motion_num_attention_heads=8, + motion_attenion_bias=False, + motion_activation_fn="geglu", + motion_max_seq_length=24, + ): + super().__init__() + + self.sample_size = sample_size + + if num_attention_heads is not None: + raise NotImplementedError( + "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." + ) + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = num_attention_heads or attention_head_dim + + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + # input + conv_in_kernel = 3 + conv_out_kernel = 3 + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + # time + time_embed_dim = block_out_channels[0] * 4 + self.time_proj = Timesteps(block_out_channels[0], True, 0) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + ) + + # class embedding + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads[i], + downsample_padding=downsample_padding, + use_linear_projection=use_linear_projection, + dual_cross_attention=False, + motion_norm_num_groups=motion_norm_num_groups, + motion_cross_attention_dim=motion_cross_attention_dim, + motion_num_attention_heads=motion_num_attention_heads, + motion_attenion_bias=motion_attenion_bias, + motion_activation_fn=motion_activation_fn, + motion_max_seq_length=motion_max_seq_length, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlockCrossAttnMotion( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=False, + motion_norm_num_groups=motion_norm_num_groups, + motion_cross_attention_dim=motion_cross_attention_dim, + motion_num_attention_heads=motion_num_attention_heads, + motion_attenion_bias=motion_attenion_bias, + motion_activation_fn=motion_activation_fn, + motion_max_seq_length=motion_max_seq_length, + ) + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=reversed_num_attention_heads[i], + dual_cross_attention=False, + resolution_idx=i, + use_linear_projection=use_linear_projection, + motion_norm_num_groups=motion_norm_num_groups, + motion_cross_attention_dim=motion_cross_attention_dim, + motion_num_attention_heads=motion_num_attention_heads, + motion_attenion_bias=motion_attenion_bias, + motion_activation_fn=motion_activation_fn, + motion_max_seq_length=motion_max_seq_length, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps + ) + self.conv_act = nn.SiLU() + else: + self.conv_norm_out = None + self.conv_act = None + + conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = nn.Conv2d( + block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding + ) + + @classmethod + def from_unet2d(cls, unet, motion_modules=None, **kwargs): + # based on https://github.com/guoyww/AnimateDiff/blob/895f3220c06318ea0760131ec70408b466c49333/animatediff/models/unet.py#L459 + config = unet.config + config["_class_name"] = cls.__name__ + + down_blocks = [] + for down_blocks_type in config["down_block_types"]: + if "CrossAttn" in down_blocks_type: + down_blocks.append("CrossAttnDownBlockMotion") + else: + down_blocks.append("DownBlockMotion") + config["down_block_types"] = down_blocks + + up_blocks = [] + for down_blocks_type in config["up_block_types"]: + if "CrossAttn" in down_blocks_type: + up_blocks.append("CrossAttnUpBlockMotion") + else: + up_blocks.append("UpBlockMotion") + config["up_block_types"] = up_blocks + + state_dict = unet.state_dict() + if motion_modules is not None: + state_dict.update(motion_modules.state_dict()) + + model = cls.from_config(config) + model.load_state_dict(state_dict, strict=False) + + return model + + def load_motion_modules(self, motion_modules): + self.state_dict().update(motion_modules.state_dict()) + + def save_motion_modules(self): + return + + @property + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_sliceable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_sliceable_dims(module) + + num_sliceable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_sliceable_layers * [1] + + slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor( + self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False + ): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor, _remove_lora=_remove_lora) + else: + module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def enable_forward_chunking(self, chunk_size=None, dim=0): + """ + Sets the attention processor to use [feed forward + chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). + + Parameters: + chunk_size (`int`, *optional*): + The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually + over each tensor of dim=`dim`. + dim (`int`, *optional*, defaults to `0`): + The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) + or dim=1 (sequence length). + """ + if dim not in [0, 1]: + raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") + + # By default chunk size is 1 + chunk_size = chunk_size or 1 + + def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): + if hasattr(module, "set_chunk_feed_forward"): + module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) + + for child in module.children(): + fn_recursive_feed_forward(child, chunk_size, dim) + + for module in self.children(): + fn_recursive_feed_forward(module, chunk_size, dim) + + def disable_forward_chunking(self): + def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): + if hasattr(module, "set_chunk_feed_forward"): + module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) + + for child in module.children(): + fn_recursive_feed_forward(child, chunk_size, dim) + + for module in self.children(): + fn_recursive_feed_forward(module, None, 0) + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor, _remove_lora=True) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion, CrossAttnUpBlockMotion, UpBlockMotion)): + module.gradient_checkpointing = value + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.enable_freeu + def enable_freeu(self, s1, s2, b1, b2): + r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497. + + The suffixes after the scaling factors represent the stage blocks where they are being applied. + + Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that + are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. + + Args: + s1 (`float`): + Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to + mitigate the "oversmoothing effect" in the enhanced denoising process. + s2 (`float`): + Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to + mitigate the "oversmoothing effect" in the enhanced denoising process. + b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. + """ + for i, upsample_block in enumerate(self.up_blocks): + setattr(upsample_block, "s1", s1) + setattr(upsample_block, "s2", s2) + setattr(upsample_block, "b1", b1) + setattr(upsample_block, "b2", b2) + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.disable_freeu + def disable_freeu(self): + """Disables the FreeU mechanism.""" + freeu_keys = {"s1", "s2", "b1", "b2"} + for i, upsample_block in enumerate(self.up_blocks): + for k in freeu_keys: + if hasattr(upsample_block, k) or getattr(upsample_block, k) is not None: + setattr(upsample_block, k, None) + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[UNet3DConditionOutput, Tuple]: + r""" + The [`UNet3DConditionModel`] forward method. + + Args: + sample (`torch.FloatTensor`): + The noisy input tensor with the following shape `(batch, channel, num_frames, height, width`. + timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. + encoder_hidden_states (`torch.FloatTensor`): + The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. + class_labels (`torch.Tensor`, *optional*, defaults to `None`): + Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. + timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`): + Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed + through the `self.time_embedding` layer to obtain the timestep embeddings. + attention_mask (`torch.Tensor`, *optional*, defaults to `None`): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*): + A tuple of tensors that if specified are added to the residuals of down unet blocks. + mid_block_additional_residual: (`torch.Tensor`, *optional*): + A tensor that if specified is added to the residual of the middle unet block. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_3d_condition.UNet3DConditionOutput`] instead of a plain + tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. + + Returns: + [`~models.unet_3d_condition.UNet3DConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unet_3d_condition.UNet3DConditionOutput`] is returned, otherwise + a `tuple` is returned where the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + num_frames = sample.shape[2] + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + emb = emb.repeat_interleave(repeats=num_frames, dim=0) + encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0) + + # 2. pre-process + sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:]) + sample = self.conv_in(sample) + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames) + + down_block_res_samples += res_samples + + if down_block_additional_residuals is not None: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples += (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + + if mid_block_additional_residual is not None: + sample = sample + mid_block_additional_residual + + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + num_frames=num_frames, + ) + + # 6. post-process + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + + sample = self.conv_out(sample) + + # reshape to (batch, channel, framerate, width, height) + sample = sample[None, :].reshape((-1, num_frames) + sample.shape[1:]).permute(0, 2, 1, 3, 4) + + if not return_dict: + return (sample,) + + return UNet3DConditionOutput(sample=sample) From 9e4c70044183f79ec3c54a367e6e72643fde2799 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 16 Oct 2023 12:14:59 +0000 Subject: [PATCH 02/55] clean up --- src/diffusers/models/embeddings.py | 15 +- src/diffusers/models/unet_motion_blocks.py | 106 ++-- src/diffusers/models/unet_motion_model.py | 19 +- src/diffusers/pipelines/__init__.py | 2 + .../pipelines/animatediff/__init__.py | 46 ++ .../animatediff/pipeline_animatediff.py | 510 ++++++++++++++++++ 6 files changed, 630 insertions(+), 68 deletions(-) create mode 100644 src/diffusers/pipelines/animatediff/__init__.py create mode 100644 src/diffusers/pipelines/animatediff/pipeline_animatediff.py diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 6fdca133b6a9..2f310ff151c8 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -250,18 +250,19 @@ def forward(self, x): class PositionalEmbedding(nn.Module): - def __init__(self, embed_dim, max_seq_length=24): + def __init__(self, embed_dim: int, max_seq_length: int = 24): super().__init__() position = torch.arange(max_seq_length).unsqueeze(1) - div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim)) - self.pos_embed = torch.zeros(1, max_seq_length, embed_dim) - self.pos_embed[0, :, 0::2] = torch.sin(position * div_term) - self.pos_embed[0, :, 1::2] = torch.cos(position * div_term) + + pe = torch.zeros(max_seq_length, 1, embed_dim) + pe[:, 0, 0::2] = torch.sin(position * div_term) + pe[:, 0, 1::2] = torch.cos(position * div_term) + self.register_buffer("pe", pe) def forward(self, x): - seq_len = x.shape[1] - x = x + self.pos_embed[:, :seq_len] + seq_length = x.shape[0] + x = x + self.pe[:seq_length] return x diff --git a/src/diffusers/models/unet_motion_blocks.py b/src/diffusers/models/unet_motion_blocks.py index 38a1a6368207..993125b43990 100644 --- a/src/diffusers/models/unet_motion_blocks.py +++ b/src/diffusers/models/unet_motion_blocks.py @@ -35,8 +35,7 @@ def get_down_block( motion_norm_num_groups=32, motion_cross_attention_dim=None, motion_num_attention_heads=8, - motion_attention_head_dim=40, - motion_attenion_bias=False, + motion_attention_bias=False, motion_activation_fn="geglu", motion_max_seq_length=24, ): @@ -55,8 +54,7 @@ def get_down_block( motion_norm_num_groups=motion_norm_num_groups, motion_cross_attention_dim=motion_cross_attention_dim, motion_num_attention_heads=motion_num_attention_heads, - motion_attention_head_dim=motion_attention_head_dim, - motion_attenion_bias=motion_attenion_bias, + motion_attention_bias=motion_attention_bias, motion_activation_fn=motion_activation_fn, motion_max_seq_length=motion_max_seq_length, ) @@ -83,8 +81,7 @@ def get_down_block( motion_norm_num_groups=motion_norm_num_groups, motion_cross_attention_dim=motion_cross_attention_dim, motion_num_attention_heads=motion_num_attention_heads, - motion_attention_head_dim=motion_attention_head_dim, - motion_attenion_bias=motion_attenion_bias, + motion_attention_bias=motion_attention_bias, motion_activation_fn=motion_activation_fn, motion_max_seq_length=motion_max_seq_length, ) @@ -113,8 +110,7 @@ def get_up_block( motion_norm_num_groups=32, motion_cross_attention_dim=None, motion_num_attention_heads=8, - motion_attention_head_dim=40, - motion_attenion_bias=False, + motion_attention_bias=False, motion_activation_fn="geglu", motion_max_seq_length=24, ): @@ -134,8 +130,7 @@ def get_up_block( motion_norm_num_groups=motion_norm_num_groups, motion_cross_attention_dim=motion_cross_attention_dim, motion_num_attention_heads=motion_num_attention_heads, - motion_attention_head_dim=motion_attention_head_dim, - motion_attenion_bias=motion_attenion_bias, + motion_attention_bias=motion_attention_bias, motion_activation_fn=motion_activation_fn, motion_max_seq_length=motion_max_seq_length, ) @@ -163,8 +158,7 @@ def get_up_block( motion_norm_num_groups=motion_norm_num_groups, motion_cross_attention_dim=motion_cross_attention_dim, motion_num_attention_heads=motion_num_attention_heads, - motion_attention_head_dim=motion_attention_head_dim, - motion_attenion_bias=motion_attenion_bias, + motion_attention_bias=motion_attention_bias, motion_activation_fn=motion_activation_fn, motion_max_seq_length=motion_max_seq_length, ) @@ -195,16 +189,15 @@ def __init__( num_attention_heads=num_attention_heads, ) - def forward(self, x): + def forward(self, x, num_frames=1): + batch_frames, channels, height, width = x.shape + + x = x.permute(0, 2, 3, 1).reshape(batch_frames, height * width, channels) x = self.pos_embed(x) - x = self.temporal_transformer(x) - """ - x = x.reshape([-1, self.max_seq_length, *x.shape[1:]]) - x = x.movedim((0, 1, 2, 3, 4), (0, 2, 1, 3, 4)) + x = x.reshape(batch_frames, height, width, channels).permute(0, 3, 1, 2) + + x = self.temporal_transformer(x, num_frames=num_frames) - x = x.movedim((0, 1, 2, 3, 4), (0, 2, 1, 3, 4)) - x = x.flatten(0, 1) - """ return x @@ -261,8 +254,11 @@ def __init__( Args: block_out_channels (tuple, optional): _description_. Defaults to (320, 640, 1280, 1280). - down_block_types (tuple, optional): _description_. Defaults to ( "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D", ). - up_block_types (tuple, optional): _description_. Defaults to ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"). + down_block_types (tuple, optional): + _description_. Defaults to ( "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", + "DownBlock2D", ). + up_block_types (tuple, optional): + _description_. Defaults to ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"). layers_per_block (int, optional): _description_. Defaults to 2. num_attention_heads (int, optional): _description_. Defaults to 8. attention_head_dim (int, optional): _description_. Defaults to 40. @@ -344,8 +340,7 @@ def __init__( motion_norm_num_groups=32, motion_cross_attention_dim=None, motion_num_attention_heads=8, - motion_attention_head_dim=40, - motion_attenion_bias=False, + motion_attention_bias=False, motion_activation_fn="geglu", motion_max_seq_length=24, ): @@ -375,7 +370,7 @@ def __init__( norm_num_groups=motion_norm_num_groups, cross_attention_dim=motion_cross_attention_dim, activation_fn=motion_activation_fn, - attention_bias=motion_attenion_bias, + attention_bias=motion_attention_bias, num_attention_heads=motion_num_attention_heads, max_seq_length=motion_max_seq_length, ) @@ -397,10 +392,11 @@ def __init__( self.gradient_checkpointing = False - def forward(self, hidden_states, temb=None, scale: float = 1.0): + def forward(self, hidden_states, temb=None, scale: float = 1.0, num_frames=1): output_states = () - for resnet, motion_module in self.resnets, self.motion_modules: + blocks = zip(self.resnets, self.motion_modules) + for resnet, motion_module in blocks: if self.training and self.gradient_checkpointing: def create_custom_forward(module): @@ -419,7 +415,7 @@ def custom_forward(*inputs): ) else: hidden_states = resnet(hidden_states, temb, scale=scale) - hidden_states = motion_module(hidden_states) + hidden_states = motion_module(hidden_states, num_frames=num_frames) output_states = output_states + (hidden_states,) @@ -459,8 +455,7 @@ def __init__( motion_norm_num_groups=32, motion_cross_attention_dim=None, motion_num_attention_heads=8, - motion_attention_head_dim=40, - motion_attenion_bias=False, + motion_attention_bias=False, motion_activation_fn="geglu", motion_max_seq_length=24, ): @@ -521,7 +516,7 @@ def __init__( norm_num_groups=motion_norm_num_groups, cross_attention_dim=motion_cross_attention_dim, activation_fn=motion_activation_fn, - attention_bias=motion_attenion_bias, + attention_bias=motion_attention_bias, num_attention_heads=motion_num_attention_heads, max_seq_length=motion_max_seq_length, ).motion_modules @@ -541,12 +536,13 @@ def __init__( def forward( self, - hidden_states: torch.FloatTensor, - temb: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, + hidden_states, + temb=None, + encoder_hidden_states=None, + attention_mask=None, + num_frames=1, + encoder_attention_mask=None, + cross_attention_kwargs=None, additional_residuals=None, ): output_states = () @@ -554,7 +550,6 @@ def forward( lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 blocks = list(zip(self.resnets, self.attentions, self.motion_modules)) - for i, (resnet, attn, motion_module) in enumerate(blocks): if self.training and self.gradient_checkpointing: @@ -592,7 +587,7 @@ def custom_forward(*inputs): encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] - hidden_states = motion_module(hidden_states) + hidden_states = motion_module(hidden_states, num_frames=num_frames) # apply additional residuals to the output of the last pair of resnet and attention blocks if i == len(blocks) - 1 and additional_residuals is not None: @@ -637,8 +632,7 @@ def __init__( motion_norm_num_groups=32, motion_cross_attention_dim=None, motion_num_attention_heads=8, - motion_attention_head_dim=40, - motion_attenion_bias=False, + motion_attention_bias=False, motion_activation_fn="geglu", motion_max_seq_length=24, ): @@ -675,7 +669,7 @@ def __init__( norm_num_groups=motion_norm_num_groups, cross_attention_dim=motion_cross_attention_dim, activation_fn=motion_activation_fn, - attention_bias=motion_attenion_bias, + attention_bias=motion_attention_bias, num_attention_heads=motion_num_attention_heads, max_seq_length=motion_max_seq_length, ) @@ -714,7 +708,7 @@ def __init__( norm_num_groups=motion_norm_num_groups, cross_attention_dim=motion_cross_attention_dim, activation_fn=motion_activation_fn, - attention_bias=motion_attenion_bias, + attention_bias=motion_attention_bias, num_attention_heads=motion_num_attention_heads, max_seq_length=motion_max_seq_length, ).motion_modules @@ -737,6 +731,7 @@ def forward( upsample_size: Optional[int] = None, attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, + num_frames=1, ): lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 is_freeu_enabled = ( @@ -746,7 +741,8 @@ def forward( and getattr(self, "b2", None) ) - for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules): + blocks = zip(self.resnets, self.attentions, self.motion_modules) + for resnet, attn, motion_module in blocks: # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] @@ -801,7 +797,7 @@ def custom_forward(*inputs): encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] - hidden_states = motion_module(hidden_states) + hidden_states = motion_module(hidden_states, num_frames=num_frames) if self.upsamplers is not None: for upsampler in self.upsamplers: @@ -831,7 +827,7 @@ def __init__( motion_cross_attention_dim=None, motion_num_attention_heads=8, motion_attention_head_dim=40, - motion_attenion_bias=False, + motion_attention_bias=False, motion_activation_fn="geglu", motion_max_seq_length=24, ): @@ -863,7 +859,7 @@ def __init__( norm_num_groups=motion_norm_num_groups, cross_attention_dim=motion_cross_attention_dim, activation_fn=motion_activation_fn, - attention_bias=motion_attenion_bias, + attention_bias=motion_attention_bias, num_attention_heads=motion_num_attention_heads, max_seq_length=motion_max_seq_length, ).motion_modules @@ -876,7 +872,9 @@ def __init__( self.gradient_checkpointing = False self.resolution_idx = resolution_idx - def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0): + def forward( + self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0, num_frames=1 + ): is_freeu_enabled = ( getattr(self, "s1", None) and getattr(self, "s2", None) @@ -884,7 +882,7 @@ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_si and getattr(self, "b2", None) ) - blocks = zip(self.resnets, self.motion_modules) + blocks = zip(self.resnets, self.motion_modules, self.motion_modules) for resnet, motion_module in blocks: # pop res hidden states @@ -923,7 +921,7 @@ def custom_forward(*inputs): ) else: hidden_states = resnet(hidden_states, temb, scale=scale) - hidden_states = motion_module(hidden_states) + hidden_states = motion_module(hidden_states, num_frames=num_frames) if self.upsamplers is not None: for upsampler in self.upsamplers: @@ -955,8 +953,7 @@ def __init__( motion_norm_num_groups=32, motion_cross_attention_dim=None, motion_num_attention_heads=8, - motion_attention_head_dim=40, - motion_attenion_bias=False, + motion_attention_bias=False, motion_activation_fn="geglu", motion_max_seq_length=24, ): @@ -1031,7 +1028,7 @@ def __init__( norm_num_groups=motion_norm_num_groups, cross_attention_dim=motion_cross_attention_dim, activation_fn=motion_activation_fn, - attention_bias=motion_attenion_bias, + attention_bias=motion_attention_bias, num_attention_heads=motion_num_attention_heads, max_seq_length=motion_max_seq_length, layers_per_block=num_layers, @@ -1047,6 +1044,7 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, + num_frames=1, ) -> torch.FloatTensor: lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale) @@ -1086,7 +1084,7 @@ def custom_forward(*inputs): encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] - hidden_states = motion_module(hidden_states) + hidden_states = motion_module(hidden_states, num_frames=num_frames) hidden_states = resnet(hidden_states, temb, scale=lora_scale) return hidden_states diff --git a/src/diffusers/models/unet_motion_model.py b/src/diffusers/models/unet_motion_model.py index c1595c42cddf..3ecd8e28611b 100644 --- a/src/diffusers/models/unet_motion_model.py +++ b/src/diffusers/models/unet_motion_model.py @@ -62,8 +62,8 @@ class UNet2DConditionOutput(BaseOutput): class UNetMotionModel(ModelMixin, ConfigMixin): r""" - A modified conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample - shaped output. + A modified conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a + sample shaped output. This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented for all models (such as downloading or saving). @@ -282,7 +282,7 @@ def __init__( ) @classmethod - def from_unet2d(cls, unet, motion_modules=None, **kwargs): + def from_unet2d(cls, unet, motion_adapter=None, **kwargs): # based on https://github.com/guoyww/AnimateDiff/blob/895f3220c06318ea0760131ec70408b466c49333/animatediff/models/unet.py#L459 config = unet.config config["_class_name"] = cls.__name__ @@ -304,8 +304,8 @@ def from_unet2d(cls, unet, motion_modules=None, **kwargs): config["up_block_types"] = up_blocks state_dict = unet.state_dict() - if motion_modules is not None: - state_dict.update(motion_modules.state_dict()) + if motion_adapter is not None: + state_dict.update(motion_adapter.state_dict()) model = cls.from_config(config) model.load_state_dict(state_dict, strict=False) @@ -313,10 +313,14 @@ def from_unet2d(cls, unet, motion_modules=None, **kwargs): return model def load_motion_modules(self, motion_modules): - self.state_dict().update(motion_modules.state_dict()) + state_dict = self.state_dict() + motion_modules_state_dict = motion_modules.state_dict() + state_dict.update(motion_modules_state_dict) + + self.load_state_dict(state_dict) def save_motion_modules(self): - return + raise NotImplementedError @property # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors @@ -540,6 +544,7 @@ def disable_freeu(self): if hasattr(upsample_block, k) or getattr(upsample_block, k) is not None: setattr(upsample_block, k, None) + # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.forward def forward( self, sample: torch.FloatTensor, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 19fe2f72d447..f7efb009f2c3 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -62,6 +62,7 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["alt_diffusion"] = ["AltDiffusionImg2ImgPipeline", "AltDiffusionPipeline"] + _import_structure["animatediff"] = ["AnimateDiffPipeline"] _import_structure["audioldm"] = ["AudioLDMPipeline"] _import_structure["audioldm2"] = [ "AudioLDM2Pipeline", @@ -290,6 +291,7 @@ from ..utils.dummy_torch_and_transformers_objects import * else: from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline + from .animatediff import AnimateDiffPipeline from .audioldm import AudioLDMPipeline from .audioldm2 import AudioLDM2Pipeline, AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel from .blip_diffusion import BlipDiffusionPipeline diff --git a/src/diffusers/pipelines/animatediff/__init__.py b/src/diffusers/pipelines/animatediff/__init__.py new file mode 100644 index 000000000000..503352fec865 --- /dev/null +++ b/src/diffusers/pipelines/animatediff/__init__.py @@ -0,0 +1,46 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_animatediff"] = ["AnimateDiffPipeline", "AnimateDiffPipelineOutput"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + + else: + from .pipeline_animatediff import AnimateDiffPipeline, AnimateDiffPipelineOutput + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py new file mode 100644 index 000000000000..2ca4602545b1 --- /dev/null +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -0,0 +1,510 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import CLIPTextModel, CLIPTokenizer + +from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet2DConditionModel, UNetMotionModel +from ...models.unet_motion_blocks import MotionAdapter +from ...schedulers import ( + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, +) +from ...utils import BaseOutput, logging +from ..pipeline_utils import DiffusionPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import TextToVideoSDPipeline + >>> from diffusers.utils import export_to_video + + >>> pipe = TextToVideoSDPipeline.from_pretrained( + ... "damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16" + ... ) + >>> pipe.enable_model_cpu_offload() + + >>> prompt = "Spiderman is surfing" + >>> video_frames = pipe(prompt).frames + >>> video_path = export_to_video(video_frames) + >>> video_path + ``` +""" + + +def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> List[np.ndarray]: + # This code is copied from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78 + # reshape to ncfhw + mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1) + std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1) + # unnormalize back to [0,1] + video = video.mul_(std).add_(mean) + video.clamp_(0, 1) + # prepare the final outputs + i, c, f, h, w = video.shape + images = video.permute(2, 3, 0, 4, 1).reshape( + f, h, i * w, c + ) # 1st (frames, h, batch_size, w, c) 2nd (frames, h, batch_size * w, c) + images = images.unbind(dim=0) # prepare a list of indvidual (consecutive frames) + images = [(image.cpu().numpy() * 255).astype("uint8") for image in images] # f h w c + return images + + +@dataclass +class AnimateDiffPipelineOutput(BaseOutput): + frames: Union[torch.Tensor, np.ndarray] + + +class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): + r""" + Pipeline for text-to-video generation. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer (`CLIPTokenizer`): + A [`~transformers.CLIPTokenizer`] to tokenize text. + unet ([`UNet3DConditionModel`]): + A [`UNet3DConditionModel`] to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + model_cpu_offload_seq = "text_encoder->unet->vae" + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + motion_adapter: MotionAdapter, + scheduler: Union[ + DDIMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + ], + ): + super().__init__() + unet = UNetMotionModel.from_unet2d(unet, motion_adapter) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + motion_adapter=motion_adapter, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt): + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + text_embeddings = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + text_embeddings = text_embeddings[0] + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = text_embeddings.shape + text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1) + text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + uncond_embeddings = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + uncond_embeddings = uncond_embeddings[0] + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + return text_embeddings + + def decode_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + + batch_size, channels, num_frames, height, width = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) + + image = self.vae.decode(latents).sample + video = ( + image[None, :] + .reshape( + ( + batch_size, + num_frames, + -1, + ) + + image.shape[2:] + ) + .permute(0, 2, 1, 3, 4) + ) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + video = video.float() + return video + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs(self, prompt, height, width, callback_steps): + if not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + shape = ( + batch_size, + num_channels_latents, + num_frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if latents is None: + rand_device = "cpu" if device.type == "mps" else device + + if isinstance(generator, list): + shape = shape + # shape = (1,) + shape[1:] + latents = [ + torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) + for i in range(batch_size) + ] + latents = torch.cat(latents, dim=0).to(device) + else: + latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + num_frames: Optional[int], + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_videos_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "tensor", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + clip_skip: Optional[int] = None, + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated video. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated video. + num_frames (`int`, *optional*, defaults to 16): + The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds + amounts to 2 seconds of video. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality videos at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. Latents should be of shape + `(batch_size, num_channel, num_frames, height, width)`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated video. Choose between `torch.FloatTensor` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead + of a plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + Examples: + + Returns: + [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] is + returned, otherwise a `tuple` is returned where the first element is a list with the generated frames. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + num_images_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_videos_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # Post-processing + video = self.decode_latents(latents) + + if not return_dict: + return (video,) + + # Offload all models + self.maybe_free_model_hooks() + + return AnimateDiffPipelineOutput(frames=video) From a026ea5024fa8829d64be35b27eee3a7a255d4bf Mon Sep 17 00:00:00 2001 From: DN6 Date: Mon, 16 Oct 2023 20:25:52 +0530 Subject: [PATCH 03/55] clean up --- src/diffusers/models/embeddings.py | 9 ++-- src/diffusers/models/unet_motion_blocks.py | 49 +++++++--------------- src/diffusers/models/unet_motion_model.py | 33 ++++++++------- 3 files changed, 37 insertions(+), 54 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 2f310ff151c8..b3bd37fa557b 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -255,14 +255,13 @@ def __init__(self, embed_dim: int, max_seq_length: int = 24): position = torch.arange(max_seq_length).unsqueeze(1) div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim)) - pe = torch.zeros(max_seq_length, 1, embed_dim) - pe[:, 0, 0::2] = torch.sin(position * div_term) - pe[:, 0, 1::2] = torch.cos(position * div_term) - self.register_buffer("pe", pe) + self.pe = torch.zeros(max_seq_length, 1, embed_dim) + self.pe[:, 0, 0::2] = torch.sin(position * div_term) + self.pe[:, 0, 1::2] = torch.cos(position * div_term) def forward(self, x): seq_length = x.shape[0] - x = x + self.pe[:seq_length] + x = x + self.pe[:seq_length].to(x.device) return x diff --git a/src/diffusers/models/unet_motion_blocks.py b/src/diffusers/models/unet_motion_blocks.py index 993125b43990..1bc86e26c533 100644 --- a/src/diffusers/models/unet_motion_blocks.py +++ b/src/diffusers/models/unet_motion_blocks.py @@ -254,14 +254,10 @@ def __init__( Args: block_out_channels (tuple, optional): _description_. Defaults to (320, 640, 1280, 1280). - down_block_types (tuple, optional): - _description_. Defaults to ( "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", - "DownBlock2D", ). - up_block_types (tuple, optional): - _description_. Defaults to ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"). + down_block_types (tuple, optional): _description_. Defaults to ( "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D", ). + up_block_types (tuple, optional): _description_. Defaults to ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"). layers_per_block (int, optional): _description_. Defaults to 2. num_attention_heads (int, optional): _description_. Defaults to 8. - attention_head_dim (int, optional): _description_. Defaults to 40. attenion_bias (bool, optional): _description_. Defaults to False. cross_attention_dim (_type_, optional): _description_. Defaults to None. activation_fn (str, optional): _description_. Defaults to "geglu". @@ -284,6 +280,7 @@ def __init__( attention_bias=attenion_bias, num_attention_heads=num_attention_heads, max_seq_length=max_seq_length, + layers_per_block=layers_per_block, ) ) @@ -346,7 +343,6 @@ def __init__( ): super().__init__() resnets = [] - motion_modules = [] for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels @@ -364,20 +360,18 @@ def __init__( pre_norm=resnet_pre_norm, ) ) - motion_modules.append( - MotionBlock( - in_channels=out_channels, - norm_num_groups=motion_norm_num_groups, - cross_attention_dim=motion_cross_attention_dim, - activation_fn=motion_activation_fn, - attention_bias=motion_attention_bias, - num_attention_heads=motion_num_attention_heads, - max_seq_length=motion_max_seq_length, - ) - ) self.resnets = nn.ModuleList(resnets) - self.motion_modules = nn.ModuleList(motion_modules) + self.motion_modules = MotionModules( + in_channels=out_channels, + norm_num_groups=motion_norm_num_groups, + cross_attention_dim=motion_cross_attention_dim, + activation_fn=motion_activation_fn, + attention_bias=motion_attention_bias, + num_attention_heads=motion_num_attention_heads, + max_seq_length=motion_max_seq_length, + layers_per_block=num_layers, + ).motion_modules if add_downsample: self.downsamplers = nn.ModuleList( @@ -519,6 +513,7 @@ def __init__( attention_bias=motion_attention_bias, num_attention_heads=motion_num_attention_heads, max_seq_length=motion_max_seq_length, + layers_per_block=num_layers, ).motion_modules if add_downsample: @@ -639,7 +634,6 @@ def __init__( super().__init__() resnets = [] attentions = [] - motion_modules = [] self.has_cross_attention = True self.num_attention_heads = num_attention_heads @@ -663,18 +657,6 @@ def __init__( ) ) - motion_modules.append( - MotionBlock( - in_channels=out_channels, - norm_num_groups=motion_norm_num_groups, - cross_attention_dim=motion_cross_attention_dim, - activation_fn=motion_activation_fn, - attention_bias=motion_attention_bias, - num_attention_heads=motion_num_attention_heads, - max_seq_length=motion_max_seq_length, - ) - ) - if not dual_cross_attention: attentions.append( Transformer2DModel( @@ -711,6 +693,7 @@ def __init__( attention_bias=motion_attention_bias, num_attention_heads=motion_num_attention_heads, max_seq_length=motion_max_seq_length, + layers_per_block=num_layers, ).motion_modules if add_upsample: @@ -826,7 +809,6 @@ def __init__( motion_norm_num_groups=32, motion_cross_attention_dim=None, motion_num_attention_heads=8, - motion_attention_head_dim=40, motion_attention_bias=False, motion_activation_fn="geglu", motion_max_seq_length=24, @@ -862,6 +844,7 @@ def __init__( attention_bias=motion_attention_bias, num_attention_heads=motion_num_attention_heads, max_seq_length=motion_max_seq_length, + layers_per_block=num_layers, ).motion_modules if add_upsample: diff --git a/src/diffusers/models/unet_motion_model.py b/src/diffusers/models/unet_motion_model.py index 3ecd8e28611b..33765d48d784 100644 --- a/src/diffusers/models/unet_motion_model.py +++ b/src/diffusers/models/unet_motion_model.py @@ -27,16 +27,14 @@ AttnAddedKVProcessor, AttnProcessor, ) -from .embeddings import ( - TimestepEmbedding, - Timesteps, -) +from .embeddings import TimestepEmbedding, Timesteps from .modeling_utils import ModelMixin from .unet_3d_condition import UNet3DConditionOutput from .unet_motion_blocks import ( CrossAttnDownBlockMotion, CrossAttnUpBlockMotion, DownBlockMotion, + MotionAdapter, UNetMidBlockCrossAttnMotion, UpBlockMotion, get_down_block, @@ -105,7 +103,7 @@ def __init__( motion_norm_num_groups=32, motion_cross_attention_dim=None, motion_num_attention_heads=8, - motion_attenion_bias=False, + motion_attention_bias=False, motion_activation_fn="geglu", motion_max_seq_length=24, ): @@ -193,7 +191,7 @@ def __init__( motion_norm_num_groups=motion_norm_num_groups, motion_cross_attention_dim=motion_cross_attention_dim, motion_num_attention_heads=motion_num_attention_heads, - motion_attenion_bias=motion_attenion_bias, + motion_attention_bias=motion_attention_bias, motion_activation_fn=motion_activation_fn, motion_max_seq_length=motion_max_seq_length, ) @@ -213,7 +211,7 @@ def __init__( motion_norm_num_groups=motion_norm_num_groups, motion_cross_attention_dim=motion_cross_attention_dim, motion_num_attention_heads=motion_num_attention_heads, - motion_attenion_bias=motion_attenion_bias, + motion_attention_bias=motion_attention_bias, motion_activation_fn=motion_activation_fn, motion_max_seq_length=motion_max_seq_length, ) @@ -259,7 +257,7 @@ def __init__( motion_norm_num_groups=motion_norm_num_groups, motion_cross_attention_dim=motion_cross_attention_dim, motion_num_attention_heads=motion_num_attention_heads, - motion_attenion_bias=motion_attenion_bias, + motion_attention_bias=motion_attention_bias, motion_activation_fn=motion_activation_fn, motion_max_seq_length=motion_max_seq_length, ) @@ -282,7 +280,7 @@ def __init__( ) @classmethod - def from_unet2d(cls, unet, motion_adapter=None, **kwargs): + def from_unet2d(cls, unet, motion_adapter: Optional[MotionAdapter] = None, **kwargs): # based on https://github.com/guoyww/AnimateDiff/blob/895f3220c06318ea0760131ec70408b466c49333/animatediff/models/unet.py#L459 config = unet.config config["_class_name"] = cls.__name__ @@ -312,15 +310,18 @@ def from_unet2d(cls, unet, motion_adapter=None, **kwargs): return model - def load_motion_modules(self, motion_modules): - state_dict = self.state_dict() - motion_modules_state_dict = motion_modules.state_dict() - state_dict.update(motion_modules_state_dict) - + def load_motion_modules(self, motion_adapter: Optional[MotionAdapter]): + state_dict = motion_adapter.state_dict() self.load_state_dict(state_dict) - def save_motion_modules(self): - raise NotImplementedError + def save_motion_modules(self, output_path): + state_dict = self.state_dict() + output = {} + for k, v in state_dict.items(): + if "motion_modules" in k: + output[k] = v + + torch.save(output, output_path) @property # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors From bbb2b6cb96fc288f6aa498c3b4945535e7c4e909 Mon Sep 17 00:00:00 2001 From: DN6 Date: Mon, 16 Oct 2023 22:20:45 +0530 Subject: [PATCH 04/55] clean up --- src/diffusers/models/unet_motion_blocks.py | 14 +++----- src/diffusers/models/unet_motion_model.py | 33 +++++++++++++++---- .../animatediff/pipeline_animatediff.py | 1 + 3 files changed, 31 insertions(+), 17 deletions(-) diff --git a/src/diffusers/models/unet_motion_blocks.py b/src/diffusers/models/unet_motion_blocks.py index 1bc86e26c533..2da0e58fb4d0 100644 --- a/src/diffusers/models/unet_motion_blocks.py +++ b/src/diffusers/models/unet_motion_blocks.py @@ -235,14 +235,8 @@ class MotionAdapter(ModelMixin, ConfigMixin): def __init__( self, block_out_channels=(320, 640, 1280, 1280), - down_block_types=( - "CrossAttnDownBlock2D", - "CrossAttnDownBlock2D", - "CrossAttnDownBlock2D", - "DownBlock2D", - ), - up_block_types=("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"), layers_per_block=2, + mid_block_num_layers=1, num_attention_heads=8, attenion_bias=False, cross_attention_dim=None, @@ -269,7 +263,7 @@ def __init__( down_blocks = [] up_blocks = [] - for i, block_type in enumerate(down_block_types): + for i, channel in enumerate(block_out_channels): output_channel = block_out_channels[i] down_blocks.append( MotionModules( @@ -291,12 +285,12 @@ def __init__( activation_fn=activation_fn, attention_bias=attenion_bias, num_attention_heads=num_attention_heads, - layers_per_block=1, + layers_per_block=mid_block_num_layers, ) reversed_block_out_channels = list(reversed(block_out_channels)) output_channel = reversed_block_out_channels[0] - for i, block_type in enumerate(up_block_types): + for i, channel in enumerate(reversed_block_out_channels): output_channel = reversed_block_out_channels[i] up_blocks.append( MotionModules( diff --git a/src/diffusers/models/unet_motion_model.py b/src/diffusers/models/unet_motion_model.py index 33765d48d784..f79380605065 100644 --- a/src/diffusers/models/unet_motion_model.py +++ b/src/diffusers/models/unet_motion_model.py @@ -311,17 +311,36 @@ def from_unet2d(cls, unet, motion_adapter: Optional[MotionAdapter] = None, **kwa return model def load_motion_modules(self, motion_adapter: Optional[MotionAdapter]): - state_dict = motion_adapter.state_dict() - self.load_state_dict(state_dict) + motion_state_dict = motion_adapter.state_dict() + self.load_state_dict(motion_state_dict) - def save_motion_modules(self, output_path): + def save_motion_modules( + self, + save_directory: str, + is_main_process: bool = True, + safe_serialization: bool = True, + variant: Optional[str] = None, + push_to_hub: bool = False, + **kwargs, + ): state_dict = self.state_dict() - output = {} + + # Extract all motion modules + motion_state_dict = {} for k, v in state_dict.items(): if "motion_modules" in k: - output[k] = v - - torch.save(output, output_path) + motion_state_dict[k] = v + + adapter = MotionAdapter.from_config(self.config) + adapter.load_state_dict(motion_state_dict) + adapter.save_pretrained( + save_directory=save_directory, + is_main_process=is_main_process, + safe_serialization=safe_serialization, + variant=variant, + push_to_hub=push_to_hub, + **kwargs, + ) @property # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index 2ca4602545b1..c269f7d2a904 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -18,6 +18,7 @@ import numpy as np import torch + from transformers import CLIPTextModel, CLIPTokenizer from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin From 36b3a44a4cb025d2a2c9ea1bf2967825b41538ed Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 18 Oct 2023 18:52:58 +0000 Subject: [PATCH 05/55] clean up --- src/diffusers/models/embeddings.py | 10 ++++--- src/diffusers/models/unet_motion_blocks.py | 35 ++++++++++++---------- 2 files changed, 26 insertions(+), 19 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index b3bd37fa557b..29823906aeee 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -255,13 +255,15 @@ def __init__(self, embed_dim: int, max_seq_length: int = 24): position = torch.arange(max_seq_length).unsqueeze(1) div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim)) - self.pe = torch.zeros(max_seq_length, 1, embed_dim) - self.pe[:, 0, 0::2] = torch.sin(position * div_term) - self.pe[:, 0, 1::2] = torch.cos(position * div_term) + pe = torch.zeros(max_seq_length, 1, embed_dim) + pe[:, 0, 0::2] = torch.sin(position * div_term) + pe[:, 0, 1::2] = torch.cos(position * div_term) + + self.register_buffer("pe", pe) def forward(self, x): seq_length = x.shape[0] - x = x + self.pe[:seq_length].to(x.device) + x = x + self.pe[:seq_length] return x diff --git a/src/diffusers/models/unet_motion_blocks.py b/src/diffusers/models/unet_motion_blocks.py index 2da0e58fb4d0..b4e34b09b0e6 100644 --- a/src/diffusers/models/unet_motion_blocks.py +++ b/src/diffusers/models/unet_motion_blocks.py @@ -195,8 +195,7 @@ def forward(self, x, num_frames=1): x = x.permute(0, 2, 3, 1).reshape(batch_frames, height * width, channels) x = self.pos_embed(x) x = x.reshape(batch_frames, height, width, channels).permute(0, 3, 1, 2) - - x = self.temporal_transformer(x, num_frames=num_frames) + x = self.temporal_transformer(x, num_frames=num_frames).sample return x @@ -238,7 +237,7 @@ def __init__( layers_per_block=2, mid_block_num_layers=1, num_attention_heads=8, - attenion_bias=False, + attention_bias=False, cross_attention_dim=None, activation_fn="geglu", norm_num_groups=32, @@ -247,12 +246,12 @@ def __init__( """Container to store Motion Modules Args: - block_out_channels (tuple, optional): _description_. Defaults to (320, 640, 1280, 1280). - down_block_types (tuple, optional): _description_. Defaults to ( "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D", ). - up_block_types (tuple, optional): _description_. Defaults to ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"). - layers_per_block (int, optional): _description_. Defaults to 2. - num_attention_heads (int, optional): _description_. Defaults to 8. - attenion_bias (bool, optional): _description_. Defaults to False. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + num_attention_heads (`int`, *optional*): + The number of attention heads. If not defined, defaults to `attention_head_dim` + attention_bias (bool, optional, defaults to False): Whether to include bias in attention layers. cross_attention_dim (_type_, optional): _description_. Defaults to None. activation_fn (str, optional): _description_. Defaults to "geglu". norm_num_groups (int, optional): _description_. Defaults to 32. @@ -271,7 +270,7 @@ def __init__( norm_num_groups=norm_num_groups, cross_attention_dim=cross_attention_dim, activation_fn=activation_fn, - attention_bias=attenion_bias, + attention_bias=attention_bias, num_attention_heads=num_attention_heads, max_seq_length=max_seq_length, layers_per_block=layers_per_block, @@ -283,7 +282,7 @@ def __init__( norm_num_groups=norm_num_groups, cross_attention_dim=cross_attention_dim, activation_fn=activation_fn, - attention_bias=attenion_bias, + attention_bias=attention_bias, num_attention_heads=num_attention_heads, layers_per_block=mid_block_num_layers, ) @@ -298,7 +297,7 @@ def __init__( norm_num_groups=norm_num_groups, cross_attention_dim=cross_attention_dim, activation_fn=activation_fn, - attention_bias=attenion_bias, + attention_bias=attention_bias, num_attention_heads=num_attention_heads, max_seq_length=max_seq_length, layers_per_block=layers_per_block + 1, @@ -399,8 +398,12 @@ def custom_forward(*inputs): ) else: hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb + create_custom_forward(resnet), hidden_states, temb, scale ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, num_frames + ) + else: hidden_states = resnet(hidden_states, temb, scale=scale) hidden_states = motion_module(hidden_states, num_frames=num_frames) @@ -859,7 +862,7 @@ def forward( and getattr(self, "b2", None) ) - blocks = zip(self.resnets, self.motion_modules, self.motion_modules) + blocks = zip(self.resnets, self.motion_modules) for resnet, motion_module in blocks: # pop res hidden states @@ -1025,7 +1028,9 @@ def forward( ) -> torch.FloatTensor: lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale) - for attn, resnet, motion_module in zip(self.attentions, self.resnets[1:], self.motion_modules): + + blocks = zip(self.attentions, self.resnets[1:], self.motion_modules) + for attn, resnet, motion_module in blocks: if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): From 2db7bd3516027da926ea4ba3798ac6d51db0f2a5 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Fri, 20 Oct 2023 07:31:06 +0000 Subject: [PATCH 06/55] clean up --- src/diffusers/__init__.py | 6 + src/diffusers/models/__init__.py | 4 +- src/diffusers/models/unet_motion_blocks.py | 104 ++++++- src/diffusers/models/unet_motion_model.py | 21 +- .../animatediff/pipeline_animatediff.py | 277 ++++++++++++++---- 5 files changed, 331 insertions(+), 81 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 42f352c029c8..7936fe08d4f3 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -79,6 +79,7 @@ "AutoencoderTiny", "ControlNetModel", "ModelMixin", + "MotionAdapter", "MultiAdapter", "PriorTransformer", "T2IAdapter", @@ -88,6 +89,7 @@ "UNet2DConditionModel", "UNet2DModel", "UNet3DConditionModel", + "UNetMotionModel", "VQModel", ] ) @@ -194,6 +196,7 @@ [ "AltDiffusionImg2ImgPipeline", "AltDiffusionPipeline", + "AnimateDiffPipeline", "AudioLDM2Pipeline", "AudioLDM2ProjectionModel", "AudioLDM2UNet2DConditionModel", @@ -438,6 +441,7 @@ AutoencoderTiny, ControlNetModel, ModelMixin, + MotionAdapter, MultiAdapter, PriorTransformer, T2IAdapter, @@ -447,6 +451,7 @@ UNet2DConditionModel, UNet2DModel, UNet3DConditionModel, + UNetMotionModel, VQModel, ) from .optimization import ( @@ -534,6 +539,7 @@ from .pipelines import ( AltDiffusionImg2ImgPipeline, AltDiffusionPipeline, + AnimateDiffPipeline, AudioLDM2Pipeline, AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index af59749a3338..f807353312d1 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -35,7 +35,7 @@ _import_structure["unet_2d"] = ["UNet2DModel"] _import_structure["unet_2d_condition"] = ["UNet2DConditionModel"] _import_structure["unet_3d_condition"] = ["UNet3DConditionModel"] - _import_structure["unet_motion_model"] = ["UNetMotionModel"] + _import_structure["unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"] _import_structure["vq_model"] = ["VQModel"] if is_flax_available(): @@ -61,7 +61,7 @@ from .unet_2d import UNet2DModel from .unet_2d_condition import UNet2DConditionModel from .unet_3d_condition import UNet3DConditionModel - from .unet_motion_model import UNetMotionModel + from .unet_motion_model import MotionAdapter, UNetMotionModel from .vq_model import VQModel if is_flax_available(): diff --git a/src/diffusers/models/unet_motion_blocks.py b/src/diffusers/models/unet_motion_blocks.py index b4e34b09b0e6..90169dc18989 100644 --- a/src/diffusers/models/unet_motion_blocks.py +++ b/src/diffusers/models/unet_motion_blocks.py @@ -2,10 +2,12 @@ import torch import torch.nn as nn +import torch.nn.functional as F from ..configuration_utils import ConfigMixin, register_to_config from ..utils import is_torch_version from ..utils.torch_utils import apply_freeu +from .attention_processor import Attention from .dual_transformer_2d import DualTransformer2DModel from .embeddings import PositionalEmbedding from .modeling_utils import ModelMixin @@ -165,6 +167,95 @@ def get_up_block( raise ValueError(f"{up_block_type} does not exist.") +class MotionAttnProcessor(nn.Module): + r""" + Attention Processor for performing attention-related computations in the Motion Modules. + """ + + def __init__(self, in_channels, max_seq_length=24): + super().__init__() + self.pos_embed = PositionalEmbedding(embed_dim=in_channels, max_seq_length=max_seq_length) + + def __call__( + self, + attn: Attention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + scale=1.0, + ): + # Apply position embedding + hidden_states = hidden_states.permute(1, 0, 2) + hidden_states = self.pos_embed(hidden_states) + hidden_states = hidden_states.permute(1, 0, 2) + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states, scale=scale) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states, scale=scale) + value = attn.to_v(encoder_hidden_states, scale=scale) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states, scale=scale) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + class MotionBlock(nn.Module): def __init__( self, @@ -178,7 +269,6 @@ def __init__( ) -> None: super().__init__() - self.pos_embed = PositionalEmbedding(embed_dim=in_channels, max_seq_length=max_seq_length) self.temporal_transformer = TransformerTemporalModel( in_channels=in_channels, norm_num_groups=norm_num_groups, @@ -188,14 +278,12 @@ def __init__( attention_bias=attention_bias, num_attention_heads=num_attention_heads, ) + for block in self.temporal_transformer.transformer_blocks: + block.attn1.set_processor(MotionAttnProcessor(in_channels=in_channels, max_seq_length=max_seq_length)) + block.attn2.set_processor(MotionAttnProcessor(in_channels=in_channels, max_seq_length=max_seq_length)) - def forward(self, x, num_frames=1): - batch_frames, channels, height, width = x.shape - - x = x.permute(0, 2, 3, 1).reshape(batch_frames, height * width, channels) - x = self.pos_embed(x) - x = x.reshape(batch_frames, height, width, channels).permute(0, 3, 1, 2) - x = self.temporal_transformer(x, num_frames=num_frames).sample + def forward(self, x, encoder_hidden_states=None, num_frames=1): + x = self.temporal_transformer(x, encoder_hidden_states=encoder_hidden_states, num_frames=num_frames).sample return x diff --git a/src/diffusers/models/unet_motion_model.py b/src/diffusers/models/unet_motion_model.py index f79380605065..20b272c40322 100644 --- a/src/diffusers/models/unet_motion_model.py +++ b/src/diffusers/models/unet_motion_model.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Union import torch @@ -19,7 +18,7 @@ import torch.utils.checkpoint from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput, logging +from ..utils import logging from .attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS, @@ -45,19 +44,6 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -@dataclass -class UNet2DConditionOutput(BaseOutput): - """ - The output of [`UNet2DConditionModel`]. - - Args: - sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. - """ - - sample: torch.FloatTensor = None - - class UNetMotionModel(ModelMixin, ConfigMixin): r""" A modified conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a @@ -648,7 +634,7 @@ def forward( timesteps = timesteps[None].to(sample.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - num_frames = sample.shape[2] + batch, channels, num_frames, height, width = sample.shape timesteps = timesteps.expand(sample.shape[0]) t_emb = self.time_proj(timesteps) @@ -663,7 +649,7 @@ def forward( encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0) # 2. pre-process - sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:]) + sample = sample.permute(0, 2, 1, 3, 4).reshape((batch * num_frames, channels, height, width)) sample = self.conv_in(sample) # 3. down @@ -676,6 +662,7 @@ def forward( encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, + num_frames=num_frames, ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index c269f7d2a904..e34f98b58e5e 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -18,12 +18,12 @@ import numpy as np import torch - from transformers import CLIPTextModel, CLIPTokenizer from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel, UNetMotionModel -from ...models.unet_motion_blocks import MotionAdapter +from ...models.lora import adjust_lora_scale_text_encoder +from ...models.unet_motion_model import MotionAdapter from ...schedulers import ( DDIMScheduler, DPMSolverMultistepScheduler, @@ -32,7 +32,7 @@ LMSDiscreteScheduler, PNDMScheduler, ) -from ...utils import BaseOutput, logging +from ...utils import BaseOutput, logging, scale_lora_layers, unscale_lora_layers from ..pipeline_utils import DiffusionPipeline @@ -42,7 +42,7 @@ Examples: ```py >>> import torch - >>> from diffusers import TextToVideoSDPipeline + >>> from diffusers import MotionAdapter, AnimatedDiffPipeline >>> from diffusers.utils import export_to_video >>> pipe = TextToVideoSDPipeline.from_pretrained( @@ -132,48 +132,133 @@ def __init__( ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt): - batch_size = len(prompt) if isinstance(prompt, list) else 1 + def encode_prompt( + self, + prompt, + device, + num_videos_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_videos_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not self.use_peft_backend: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = text_inputs.attention_mask.to(device) + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype else: - attention_mask = None + prompt_embeds_dtype = prompt_embeds.dtype - text_embeddings = self.text_encoder( - text_input_ids.to(device), - attention_mask=attention_mask, - ) - text_embeddings = text_embeddings[0] + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method - bs_embed, seq_len, _ = text_embeddings.shape - text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1) - text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1) + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance: + if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): + elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." @@ -189,7 +274,11 @@ def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_fr else: uncond_tokens = negative_prompt - max_length = text_input_ids.shape[-1] + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", @@ -203,23 +292,26 @@ def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_fr else: attention_mask = None - uncond_embeddings = self.text_encoder( + negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) - uncond_embeddings = uncond_embeddings[0] + negative_prompt_embeds = negative_prompt_embeds[0] + if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = uncond_embeddings.shape[1] - uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1) - uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1) + seq_len = negative_prompt_embeds.shape[1] - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) - return text_embeddings + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + if isinstance(self, LoraLoaderMixin) and self.use_peft_backend: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder) + + return prompt_embeds, negative_prompt_embeds def decode_latents(self, latents): latents = 1 / self.vae.config.scaling_factor * latents @@ -227,23 +319,56 @@ def decode_latents(self, latents): batch_size, channels, num_frames, height, width = latents.shape latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) - image = self.vae.decode(latents).sample + output_frames = [] + for frame_idx in range(latents.shape[0]): + output_frames.append(self.vae.decode(latents[frame_idx].unsqueeze(0)).sample) + + output = torch.cat(output_frames) video = ( - image[None, :] + output[None, :] .reshape( ( batch_size, num_frames, -1, ) - + image.shape[2:] + + output.shape[2:] ) .permute(0, 2, 1, 3, 4) ) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + video = (video / 2 + 0.5).clamp(0, 1) video = video.float() + return video + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu + def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): + r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497. + + The suffixes after the scaling factors represent the stages where they are being applied. + + Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values + that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. + + Args: + s1 (`float`): + Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. + s2 (`float`): + Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. + b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. + """ + if not hasattr(self, "unet"): + raise ValueError("The pipeline must have `unet` for using FreeU.") + self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu + def disable_freeu(self): + """Disables the FreeU mechanism if enabled.""" + self.unet.disable_freeu() + def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. @@ -261,10 +386,17 @@ def prepare_extra_step_kwargs(self, generator, eta): extra_step_kwargs["generator"] = generator return extra_step_kwargs - def check_inputs(self, prompt, height, width, callback_steps): - if not isinstance(prompt, str) and not isinstance(prompt, list): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") @@ -276,6 +408,32 @@ def check_inputs(self, prompt, height, width, callback_steps): f" {type(callback_steps)}." ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + def prepare_latents( self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None ): @@ -500,7 +658,18 @@ def __call__( callback(i, t, latents) # Post-processing - video = self.decode_latents(latents) + video_tensor = self.decode_latents(latents) + + if output_type == "latent": + return AnimateDiffPipelineOutput(frames=latents) + + if output_type == "pt": + video = video_tensor + else: + video = tensor2vid(video_tensor) + + # Offload all models + self.maybe_free_model_hooks() if not return_dict: return (video,) From 72e0fa65d9850342f0b103e34d2597ec7a164b33 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Fri, 20 Oct 2023 17:12:03 +0000 Subject: [PATCH 07/55] clean up --- .../animatediff/pipeline_animatediff.py | 62 +++++++++---------- 1 file changed, 28 insertions(+), 34 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index e34f98b58e5e..c4174988ddcf 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -45,9 +45,7 @@ >>> from diffusers import MotionAdapter, AnimatedDiffPipeline >>> from diffusers.utils import export_to_video - >>> pipe = TextToVideoSDPipeline.from_pretrained( - ... "damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16" - ... ) + >>> pipe = AnimateDiffPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", motion_adapter=adapter) >>> pipe.enable_model_cpu_offload() >>> prompt = "Spiderman is surfing" @@ -59,21 +57,27 @@ def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> List[np.ndarray]: - # This code is copied from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78 - # reshape to ncfhw - mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1) - std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1) + # Based on: + # https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78 + + device = video.device + + mean = torch.tensor(mean).reshape(1, -1, 1, 1, 1).to(device) + std = torch.tensor(std).reshape(1, -1, 1, 1, 1).to(device) + # unnormalize back to [0,1] video = video.mul_(std).add_(mean) video.clamp_(0, 1) - # prepare the final outputs - i, c, f, h, w = video.shape - images = video.permute(2, 3, 0, 4, 1).reshape( - f, h, i * w, c - ) # 1st (frames, h, batch_size, w, c) 2nd (frames, h, batch_size * w, c) - images = images.unbind(dim=0) # prepare a list of indvidual (consecutive frames) - images = [(image.cpu().numpy() * 255).astype("uint8") for image in images] # f h w c - return images + + batch_size, channels, num_frames, height, width = video.shape + outputs = [] + for batch_idx in range(batch_size): + batch_vid = video[batch_idx].permute(1, 2, 3, 0) + images = batch_vid.unbind(dim=0) + batch_output = [(image.cpu().numpy() * 255).astype("uint8") for image in images] + outputs.append(batch_output) + + return outputs @dataclass @@ -314,29 +318,19 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds def decode_latents(self, latents): - latents = 1 / self.vae.config.scaling_factor * latents - batch_size, channels, num_frames, height, width = latents.shape latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) + latents = latents / self.vae.config.scaling_factor output_frames = [] + + # decode frame by frame to avoid OOM for frame_idx in range(latents.shape[0]): - output_frames.append(self.vae.decode(latents[frame_idx].unsqueeze(0)).sample) + frame = self.vae.decode(latents[frame_idx].unsqueeze(0), return_dict=False)[0] + output_frames.append(frame) output = torch.cat(output_frames) - video = ( - output[None, :] - .reshape( - ( - batch_size, - num_frames, - -1, - ) - + output.shape[2:] - ) - .permute(0, 2, 1, 3, 4) - ) - video = (video / 2 + 0.5).clamp(0, 1) + video = output[None, :].permute(0, 2, 1, 3, 4) video = video.float() return video @@ -657,12 +651,12 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) - # Post-processing - video_tensor = self.decode_latents(latents) - if output_type == "latent": return AnimateDiffPipelineOutput(frames=latents) + # Post-processing + video_tensor = self.decode_latents(latents) + if output_type == "pt": video = video_tensor else: From d8d3515ed2a0888fba2775fee791b6f9d7822e09 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Sat, 21 Oct 2023 22:13:13 +0000 Subject: [PATCH 08/55] clean up --- src/diffusers/models/embeddings.py | 19 +++++++++++++------ src/diffusers/models/unet_motion_blocks.py | 17 +++++++---------- 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 29823906aeee..4abd9b362777 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -250,20 +250,27 @@ def forward(self, x): class PositionalEmbedding(nn.Module): - def __init__(self, embed_dim: int, max_seq_length: int = 24): + """Apply positional information to a sequence of embeddings. + + Args: + embed_dim: (int): Dimension of the positional embedding. + max_seq_length: Maximum sequence length to apply positional embeddings + + """ + + def __init__(self, embed_dim: int, max_seq_length: int = 32): super().__init__() position = torch.arange(max_seq_length).unsqueeze(1) div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim)) pe = torch.zeros(max_seq_length, 1, embed_dim) - pe[:, 0, 0::2] = torch.sin(position * div_term) - pe[:, 0, 1::2] = torch.cos(position * div_term) - + pe[0, :, 0::2] = torch.sin(position * div_term) + pe[0, :, 1::2] = torch.cos(position * div_term) self.register_buffer("pe", pe) def forward(self, x): - seq_length = x.shape[0] - x = x + self.pe[:seq_length] + _, seq_length, _ = x.shape + x = x + self.pe[:, :seq_length] return x diff --git a/src/diffusers/models/unet_motion_blocks.py b/src/diffusers/models/unet_motion_blocks.py index 90169dc18989..08a7eeb3fe9a 100644 --- a/src/diffusers/models/unet_motion_blocks.py +++ b/src/diffusers/models/unet_motion_blocks.py @@ -39,7 +39,7 @@ def get_down_block( motion_num_attention_heads=8, motion_attention_bias=False, motion_activation_fn="geglu", - motion_max_seq_length=24, + motion_max_seq_length=32, ): if down_block_type == "DownBlockMotion": return DownBlockMotion( @@ -114,7 +114,7 @@ def get_up_block( motion_num_attention_heads=8, motion_attention_bias=False, motion_activation_fn="geglu", - motion_max_seq_length=24, + motion_max_seq_length=32, ): if up_block_type == "UpBlockMotion": return UpBlockMotion( @@ -185,10 +185,7 @@ def __call__( temb=None, scale=1.0, ): - # Apply position embedding - hidden_states = hidden_states.permute(1, 0, 2) hidden_states = self.pos_embed(hidden_states) - hidden_states = hidden_states.permute(1, 0, 2) residual = hidden_states if attn.spatial_norm is not None: @@ -420,7 +417,7 @@ def __init__( motion_num_attention_heads=8, motion_attention_bias=False, motion_activation_fn="geglu", - motion_max_seq_length=24, + motion_max_seq_length=32, ): super().__init__() resnets = [] @@ -536,7 +533,7 @@ def __init__( motion_num_attention_heads=8, motion_attention_bias=False, motion_activation_fn="geglu", - motion_max_seq_length=24, + motion_max_seq_length=32, ): super().__init__() resnets = [] @@ -714,7 +711,7 @@ def __init__( motion_num_attention_heads=8, motion_attention_bias=False, motion_activation_fn="geglu", - motion_max_seq_length=24, + motion_max_seq_length=32, ): super().__init__() resnets = [] @@ -896,7 +893,7 @@ def __init__( motion_num_attention_heads=8, motion_attention_bias=False, motion_activation_fn="geglu", - motion_max_seq_length=24, + motion_max_seq_length=32, ): super().__init__() resnets = [] @@ -1023,7 +1020,7 @@ def __init__( motion_num_attention_heads=8, motion_attention_bias=False, motion_activation_fn="geglu", - motion_max_seq_length=24, + motion_max_seq_length=32, ): super().__init__() From 7a5fbf8e9ec89f965a0f88046e397adaf0f06fe4 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Sun, 22 Oct 2023 16:49:38 +0000 Subject: [PATCH 09/55] clean up --- src/diffusers/models/embeddings.py | 5 +- src/diffusers/models/unet_motion_blocks.py | 154 +++++++++++++++------ src/diffusers/models/unet_motion_model.py | 92 ++++++++++-- 3 files changed, 200 insertions(+), 51 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 4abd9b362777..1b6e43733ce0 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -252,6 +252,9 @@ def forward(self, x): class PositionalEmbedding(nn.Module): """Apply positional information to a sequence of embeddings. + Takes in a sequence of embeddings with shape (batch_size, seq_length, embed_dim) and adds positional embeddings to + them + Args: embed_dim: (int): Dimension of the positional embedding. max_seq_length: Maximum sequence length to apply positional embeddings @@ -263,7 +266,7 @@ def __init__(self, embed_dim: int, max_seq_length: int = 32): position = torch.arange(max_seq_length).unsqueeze(1) div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim)) - pe = torch.zeros(max_seq_length, 1, embed_dim) + pe = torch.zeros(1, max_seq_length, embed_dim) pe[0, :, 0::2] = torch.sin(position * div_term) pe[0, :, 1::2] = torch.cos(position * div_term) self.register_buffer("pe", pe) diff --git a/src/diffusers/models/unet_motion_blocks.py b/src/diffusers/models/unet_motion_blocks.py index 08a7eeb3fe9a..d83356f20852 100644 --- a/src/diffusers/models/unet_motion_blocks.py +++ b/src/diffusers/models/unet_motion_blocks.py @@ -168,11 +168,83 @@ def get_up_block( class MotionAttnProcessor(nn.Module): + r""" + Default processor for performing attention-related computations. + """ + + def __init__(self, in_channels, max_seq_length=32): + super().__init__() + self.pos_embed = PositionalEmbedding(embed_dim=in_channels, max_seq_length=max_seq_length) + + def __call__( + self, + attn: Attention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + scale=1.0, + ): + residual = hidden_states + hidden_states = self.pos_embed(hidden_states) + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states, scale=scale) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states, scale=scale) + value = attn.to_v(encoder_hidden_states, scale=scale) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states, scale=scale) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class MotionAttnProcessor2_0(nn.Module): r""" Attention Processor for performing attention-related computations in the Motion Modules. """ - def __init__(self, in_channels, max_seq_length=24): + def __init__(self, in_channels, max_seq_length=32): super().__init__() self.pos_embed = PositionalEmbedding(embed_dim=in_channels, max_seq_length=max_seq_length) @@ -275,14 +347,18 @@ def __init__( attention_bias=attention_bias, num_attention_heads=num_attention_heads, ) + processor_cls = MotionAttnProcessor2_0 if is_torch_version(">=", "2.0.0") else MotionAttnProcessor + for block in self.temporal_transformer.transformer_blocks: - block.attn1.set_processor(MotionAttnProcessor(in_channels=in_channels, max_seq_length=max_seq_length)) - block.attn2.set_processor(MotionAttnProcessor(in_channels=in_channels, max_seq_length=max_seq_length)) + block.attn1.set_processor(processor_cls(in_channels=in_channels, max_seq_length=max_seq_length)) + block.attn2.set_processor(processor_cls(in_channels=in_channels, max_seq_length=max_seq_length)) - def forward(self, x, encoder_hidden_states=None, num_frames=1): - x = self.temporal_transformer(x, encoder_hidden_states=encoder_hidden_states, num_frames=num_frames).sample + def forward(self, hidden_states, encoder_hidden_states=None, num_frames=1): + hidden_states = self.temporal_transformer( + hidden_states, encoder_hidden_states=encoder_hidden_states, num_frames=num_frames + ).sample - return x + return hidden_states class MotionModules(nn.Module): @@ -319,26 +395,26 @@ class MotionAdapter(ModelMixin, ConfigMixin): def __init__( self, block_out_channels=(320, 640, 1280, 1280), - layers_per_block=2, - mid_block_num_layers=1, - num_attention_heads=8, - attention_bias=False, - cross_attention_dim=None, - activation_fn="geglu", - norm_num_groups=32, - max_seq_length=24, + motion_layers_per_block=2, + motion_mid_block_layers_per_block=1, + motion_num_attention_heads=8, + motion_attention_bias=False, + motion_cross_attention_dim=None, + motion_activation_fn="geglu", + motion_norm_num_groups=32, + motion_max_seq_length=32, ): - """Container to store Motion Modules + """Container to store AnimateDiff Motion Modules Args: block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): - The tuple of output channels for each block. + The tuple of output channels for each UNet block. layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. num_attention_heads (`int`, *optional*): - The number of attention heads. If not defined, defaults to `attention_head_dim` + The number of heads to use in each attention layer. attention_bias (bool, optional, defaults to False): Whether to include bias in attention layers. - cross_attention_dim (_type_, optional): _description_. Defaults to None. - activation_fn (str, optional): _description_. Defaults to "geglu". + cross_attention_dim (int, optional, Defaults to None): Set in order to use cross attention. + activation_fn (str, optional, Defaults to "geglu"): Activation Function. norm_num_groups (int, optional): _description_. Defaults to 32. max_seq_length (int, optional): _description_. Defaults to 24. """ @@ -352,24 +428,24 @@ def __init__( down_blocks.append( MotionModules( in_channels=output_channel, - norm_num_groups=norm_num_groups, - cross_attention_dim=cross_attention_dim, - activation_fn=activation_fn, - attention_bias=attention_bias, - num_attention_heads=num_attention_heads, - max_seq_length=max_seq_length, - layers_per_block=layers_per_block, + norm_num_groups=motion_norm_num_groups, + cross_attention_dim=motion_cross_attention_dim, + activation_fn=motion_activation_fn, + attention_bias=motion_attention_bias, + num_attention_heads=motion_num_attention_heads, + max_seq_length=motion_max_seq_length, + layers_per_block=motion_layers_per_block, ) ) self.mid_block = MotionModules( in_channels=block_out_channels[-1], - norm_num_groups=norm_num_groups, - cross_attention_dim=cross_attention_dim, - activation_fn=activation_fn, - attention_bias=attention_bias, - num_attention_heads=num_attention_heads, - layers_per_block=mid_block_num_layers, + norm_num_groups=motion_norm_num_groups, + cross_attention_dim=motion_cross_attention_dim, + activation_fn=motion_activation_fn, + attention_bias=motion_attention_bias, + num_attention_heads=motion_num_attention_heads, + layers_per_block=motion_mid_block_layers_per_block, ) reversed_block_out_channels = list(reversed(block_out_channels)) @@ -379,13 +455,13 @@ def __init__( up_blocks.append( MotionModules( in_channels=output_channel, - norm_num_groups=norm_num_groups, - cross_attention_dim=cross_attention_dim, - activation_fn=activation_fn, - attention_bias=attention_bias, - num_attention_heads=num_attention_heads, - max_seq_length=max_seq_length, - layers_per_block=layers_per_block + 1, + norm_num_groups=motion_norm_num_groups, + cross_attention_dim=motion_cross_attention_dim, + activation_fn=motion_activation_fn, + attention_bias=motion_attention_bias, + num_attention_heads=motion_num_attention_heads, + max_seq_length=motion_max_seq_length, + layers_per_block=motion_layers_per_block + 1, ) ) diff --git a/src/diffusers/models/unet_motion_model.py b/src/diffusers/models/unet_motion_model.py index 20b272c40322..caf6d5704685 100644 --- a/src/diffusers/models/unet_motion_model.py +++ b/src/diffusers/models/unet_motion_model.py @@ -28,12 +28,15 @@ ) from .embeddings import TimestepEmbedding, Timesteps from .modeling_utils import ModelMixin +from .unet_2d_condition import UNet2DConditionModel from .unet_3d_condition import UNet3DConditionOutput from .unet_motion_blocks import ( CrossAttnDownBlockMotion, CrossAttnUpBlockMotion, DownBlockMotion, MotionAdapter, + MotionAttnProcessor, + MotionAttnProcessor2_0, UNetMidBlockCrossAttnMotion, UpBlockMotion, get_down_block, @@ -92,6 +95,8 @@ def __init__( motion_attention_bias=False, motion_activation_fn="geglu", motion_max_seq_length=24, + motion_layers_per_block=2, + motion_mid_block_layers_per_block=1, ): super().__init__() @@ -266,7 +271,15 @@ def __init__( ) @classmethod - def from_unet2d(cls, unet, motion_adapter: Optional[MotionAdapter] = None, **kwargs): + def from_unet2d( + cls, + unet: UNet2DConditionModel, + motion_adapter: Optional[MotionAdapter] = None, + load_weights: bool = True, + **kwargs, + ): + has_motion_adapter = motion_adapter is not None + # based on https://github.com/guoyww/AnimateDiff/blob/895f3220c06318ea0760131ec70408b466c49333/animatediff/models/unet.py#L459 config = unet.config config["_class_name"] = cls.__name__ @@ -285,20 +298,63 @@ def from_unet2d(cls, unet, motion_adapter: Optional[MotionAdapter] = None, **kwa up_blocks.append("CrossAttnUpBlockMotion") else: up_blocks.append("UpBlockMotion") + config["up_block_types"] = up_blocks - state_dict = unet.state_dict() - if motion_adapter is not None: - state_dict.update(motion_adapter.state_dict()) + if has_motion_adapter: + config["motion_norm_num_groups"] = motion_adapter.config["motion_norm_num_groups"] + config["motion_cross_attention_dim"] = motion_adapter.config["motion_cross_attention_dim"] + config["motion_num_attention_heads"] = motion_adapter.config["motion_num_attention_heads"] + config["motion_attention_bias"] = motion_adapter.config["motion_attention_bias"] + config["motion_activation_fn"] = motion_adapter.config["motion_activation_fn"] + config["motion_max_seq_length"] = motion_adapter.config["motion_max_seq_length"] + config["motion_layers_per_block"] = motion_adapter.config["motion_layers_per_block"] + config["motion_mid_block_layers_per_block"] = motion_adapter.config["motion_mid_block_layers_per_block"] model = cls.from_config(config) - model.load_state_dict(state_dict, strict=False) + + if load_weights: + model.conv_in.load_state_dict(unet.conv_in.state_dict()) + model.time_proj.load_state_dict(unet.time_proj.state_dict()) + model.time_embedding.load_state_dict(unet.time_embedding.state_dict()) + + for i, down_block in enumerate(unet.down_blocks): + model.down_blocks[i].resnets.load_state_dict(down_block.resnets.state_dict()) + if hasattr(model.down_blocks[i], "attentions"): + model.down_blocks[i].attentions.load_state_dict(down_block.attentions.state_dict()) + if model.down_blocks[i].downsamplers: + model.down_blocks[i].downsamplers.load_state_dict(down_block.downsamplers.state_dict()) + + for i, up_block in enumerate(unet.up_blocks): + model.up_blocks[i].resnets.load_state_dict(up_block.resnets.state_dict()) + if hasattr(model.up_blocks[i], "attentions"): + model.up_blocks[i].attentions.load_state_dict(up_block.attentions.state_dict()) + if model.up_blocks[i].upsamplers: + model.up_blocks[i].upsamplers.load_state_dict(up_block.upsamplers.state_dict()) + + model.mid_block.resnets.load_state_dict(unet.mid_block.resnets.state_dict()) + model.mid_block.attentions.load_state_dict(unet.mid_block.attentions.state_dict()) + + if unet.conv_norm_out is not None: + model.conv_norm_out.load_state_dict(unet.conv_norm_out.state_dict()) + + if unet.conv_act is not None: + model.conv_act.load_state_dict(unet.conv_act.state_dict()) + + model.conv_out.load_state_dict(unet.conv_out.state_dict()) + + if has_motion_adapter: + model.load_motion_modules(motion_adapter) return model def load_motion_modules(self, motion_adapter: Optional[MotionAdapter]): - motion_state_dict = motion_adapter.state_dict() - self.load_state_dict(motion_state_dict) + for i, down_block in enumerate(motion_adapter.down_blocks): + self.down_blocks[i].motion_modules.load_state_dict(down_block.motion_modules.state_dict()) + for i, up_block in enumerate(motion_adapter.up_blocks): + self.up_blocks[i].motion_modules.load_state_dict(up_block.motion_modules.state_dict()) + + self.mid_block.motion_modules.load_state_dict(motion_adapter.mid_block.motion_modules.state_dict()) def save_motion_modules( self, @@ -310,14 +366,22 @@ def save_motion_modules( **kwargs, ): state_dict = self.state_dict() - # Extract all motion modules motion_state_dict = {} for k, v in state_dict.items(): - if "motion_modules" in k: + if k.contains("motion_modules"): motion_state_dict[k] = v - adapter = MotionAdapter.from_config(self.config) + adapter = MotionAdapter( + motion_norm_num_groups=self.config["motion_norm_num_groups"], + motion_cross_attention_dim=self.config["motion_cross_attention_dim"], + motion_num_attention_heads=self.config["motion_num_attention_heads"], + motion_attention_bias=self.config["motion_attention_bias"], + motion_activation_fn=self.config["motion_activation_fn"], + motion_max_seq_length=self.config["motion_max_seq_length"], + motion_layers_per_block=self.config["motion_layers_per_block"], + motion_mid_block_layers_per_block=self.config["motion_mid_block_layers_per_block"], + ) adapter.load_state_dict(motion_state_dict) adapter.save_pretrained( save_directory=save_directory, @@ -444,6 +508,13 @@ def set_attn_processor( ) def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "processor"): + # ignore attention processors from motion modules + if isinstance(module.processor, MotionAttnProcessor) or isinstance( + module.processor, MotionAttnProcessor2_0 + ): + return + if hasattr(module, "set_processor"): if not isinstance(processor, dict): module.set_processor(processor, _remove_lora=_remove_lora) @@ -496,7 +567,6 @@ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int for module in self.children(): fn_recursive_feed_forward(module, None, 0) - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor def set_default_attn_processor(self): """ Disables custom attention processors and sets the default attention implementation. From 9eeee36d6e47cf20fe7d839cf65fa2885a242c78 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Sun, 22 Oct 2023 17:13:24 +0000 Subject: [PATCH 10/55] clean up --- src/diffusers/models/unet_motion_blocks.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/unet_motion_blocks.py b/src/diffusers/models/unet_motion_blocks.py index d83356f20852..a95b01a48071 100644 --- a/src/diffusers/models/unet_motion_blocks.py +++ b/src/diffusers/models/unet_motion_blocks.py @@ -371,7 +371,7 @@ def __init__( cross_attention_dim=None, activation_fn="geglu", norm_num_groups=32, - max_seq_length=24, + max_seq_length=32, ): super().__init__() self.motion_modules = nn.ModuleList([]) @@ -446,6 +446,7 @@ def __init__( attention_bias=motion_attention_bias, num_attention_heads=motion_num_attention_heads, layers_per_block=motion_mid_block_layers_per_block, + max_seq_length=motion_max_seq_length, ) reversed_block_out_channels = list(reversed(block_out_channels)) From 86a4d31cdb7f19d00c9faa5838233cb9161e6b27 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Sun, 22 Oct 2023 17:56:29 +0000 Subject: [PATCH 11/55] update pipeline --- .../animatediff/pipeline_animatediff.py | 31 +++++++------------ 1 file changed, 12 insertions(+), 19 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index c4174988ddcf..e7aa9cb43659 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -24,6 +24,7 @@ from ...models import AutoencoderKL, UNet2DConditionModel, UNetMotionModel from ...models.lora import adjust_lora_scale_text_encoder from ...models.unet_motion_model import MotionAdapter +from ...image_processor import VaeImageProcessor from ...schedulers import ( DDIMScheduler, DPMSolverMultistepScheduler, @@ -56,25 +57,16 @@ """ -def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> List[np.ndarray]: +def tensor2vid(video: torch.Tensor, processor, output_type="np"): # Based on: # https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78 - device = video.device - - mean = torch.tensor(mean).reshape(1, -1, 1, 1, 1).to(device) - std = torch.tensor(std).reshape(1, -1, 1, 1, 1).to(device) - - # unnormalize back to [0,1] - video = video.mul_(std).add_(mean) - video.clamp_(0, 1) - batch_size, channels, num_frames, height, width = video.shape outputs = [] for batch_idx in range(batch_size): - batch_vid = video[batch_idx].permute(1, 2, 3, 0) - images = batch_vid.unbind(dim=0) - batch_output = [(image.cpu().numpy() * 255).astype("uint8") for image in images] + batch_vid = video[batch_idx].permute(1, 0, 2, 3) + batch_output = processor.postprocess(batch_vid, output_type) + outputs.append(batch_output) return outputs @@ -135,6 +127,7 @@ def __init__( scheduler=scheduler, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) def encode_prompt( self, @@ -331,7 +324,7 @@ def decode_latents(self, latents): output = torch.cat(output_frames) video = output[None, :].permute(0, 2, 1, 3, 4) - video = video.float() + video = video return video @@ -469,7 +462,7 @@ def prepare_latents( def __call__( self, prompt: Union[str, List[str]], - num_frames: Optional[int], + num_frames: Optional[int] = 16, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, @@ -481,7 +474,7 @@ def __call__( latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "tensor", + output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: Optional[int] = 1, @@ -530,8 +523,8 @@ def __call__( negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. - output_type (`str`, *optional*, defaults to `"np"`): - The output format of the generated video. Choose between `torch.FloatTensor` or `np.array`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated video. Choose between `torch.FloatTensor`, `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead of a plain tuple. @@ -660,7 +653,7 @@ def __call__( if output_type == "pt": video = video_tensor else: - video = tensor2vid(video_tensor) + video = tensor2vid(video_tensor, self.image_processor, output_type=output_type) # Offload all models self.maybe_free_model_hooks() From c7ba4b8c13b74dacc802c8f22d83852d36f9b934 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 23 Oct 2023 14:09:17 +0000 Subject: [PATCH 12/55] clean up --- src/diffusers/models/unet_motion_model.py | 3 ++- .../pipelines/animatediff/pipeline_animatediff.py | 8 +++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/unet_motion_model.py b/src/diffusers/models/unet_motion_model.py index caf6d5704685..120b1aec9a04 100644 --- a/src/diffusers/models/unet_motion_model.py +++ b/src/diffusers/models/unet_motion_model.py @@ -18,6 +18,7 @@ import torch.utils.checkpoint from ..configuration_utils import ConfigMixin, register_to_config +from ..loaders import UNet2DConditionLoadersMixin from ..utils import logging from .attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, @@ -47,7 +48,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -class UNetMotionModel(ModelMixin, ConfigMixin): +class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): r""" A modified conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample shaped output. diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index e7aa9cb43659..e8bcc4ac73cf 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -20,11 +20,11 @@ import torch from transformers import CLIPTextModel, CLIPTokenizer +from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel, UNetMotionModel from ...models.lora import adjust_lora_scale_text_encoder from ...models.unet_motion_model import MotionAdapter -from ...image_processor import VaeImageProcessor from ...schedulers import ( DDIMScheduler, DPMSolverMultistepScheduler, @@ -524,7 +524,8 @@ def __call__( Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generated video. Choose between `torch.FloatTensor`, `PIL.Image` or `np.array`. + The output format of the generated video. Choose between `torch.FloatTensor`, `PIL.Image` or + `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead of a plain tuple. @@ -661,7 +662,4 @@ def __call__( if not return_dict: return (video,) - # Offload all models - self.maybe_free_model_hooks() - return AnimateDiffPipelineOutput(frames=video) From 6ec184ab96685edafab80690280da25a84ae65d7 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Tue, 24 Oct 2023 06:40:33 +0000 Subject: [PATCH 13/55] clean up --- src/diffusers/models/unet_motion_blocks.py | 42 ++++++++++++++++--- src/diffusers/models/unet_motion_model.py | 27 +++++++++++- .../animatediff/pipeline_animatediff.py | 36 ++++++---------- 3 files changed, 75 insertions(+), 30 deletions(-) diff --git a/src/diffusers/models/unet_motion_blocks.py b/src/diffusers/models/unet_motion_blocks.py index a95b01a48071..320891e6b9b8 100644 --- a/src/diffusers/models/unet_motion_blocks.py +++ b/src/diffusers/models/unet_motion_blocks.py @@ -347,15 +347,20 @@ def __init__( attention_bias=attention_bias, num_attention_heads=num_attention_heads, ) + self.use_cross_attention = cross_attention_dim is not None processor_cls = MotionAttnProcessor2_0 if is_torch_version(">=", "2.0.0") else MotionAttnProcessor for block in self.temporal_transformer.transformer_blocks: block.attn1.set_processor(processor_cls(in_channels=in_channels, max_seq_length=max_seq_length)) block.attn2.set_processor(processor_cls(in_channels=in_channels, max_seq_length=max_seq_length)) - def forward(self, hidden_states, encoder_hidden_states=None, num_frames=1): + def forward(self, hidden_states, encoder_hidden_states=None, num_frames=1, cross_attention_kwargs=None): + encoder_hidden_states = encoder_hidden_states if self.use_cross_attention else None hidden_states = self.temporal_transformer( - hidden_states, encoder_hidden_states=encoder_hidden_states, num_frames=num_frames + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + num_frames=num_frames, ).sample return hidden_states @@ -741,7 +746,12 @@ def custom_forward(*inputs): encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] - hidden_states = motion_module(hidden_states, num_frames=num_frames) + hidden_states = motion_module( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + num_frames=num_frames, + ) # apply additional residuals to the output of the last pair of resnet and attention blocks if i == len(blocks) - 1 and additional_residuals is not None: @@ -939,7 +949,12 @@ def custom_forward(*inputs): encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] - hidden_states = motion_module(hidden_states, num_frames=num_frames) + hidden_states = motion_module( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + num_frames=num_frames, + ) if self.upsamplers is not None: for upsampler in self.upsamplers: @@ -1061,6 +1076,12 @@ def custom_forward(*inputs): hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, temb ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + ) + else: hidden_states = resnet(hidden_states, temb, scale=scale) hidden_states = motion_module(hidden_states, num_frames=num_frames) @@ -1213,6 +1234,12 @@ def custom_forward(*inputs): encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(motion_module), + hidden_states, + temb, + **ckpt_kwargs, + ) hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, @@ -1228,7 +1255,12 @@ def custom_forward(*inputs): encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] - hidden_states = motion_module(hidden_states, num_frames=num_frames) + hidden_states = motion_module( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + num_frames=num_frames, + ) hidden_states = resnet(hidden_states, temb, scale=lora_scale) return hidden_states diff --git a/src/diffusers/models/unet_motion_model.py b/src/diffusers/models/unet_motion_model.py index 120b1aec9a04..c8f943e09d54 100644 --- a/src/diffusers/models/unet_motion_model.py +++ b/src/diffusers/models/unet_motion_model.py @@ -349,6 +349,31 @@ def from_unet2d( return model + def freeze_unet2d_params(self): + """Freeze the weights of just the UNet2DConditionModel, and leave the motion modules + unfrozen for fine tuning. + """ + # Freeze everything + for param in self.parameters(): + param.requires_grad = False + + # Unfreeze Motion Modules + for down_block in self.down_blocks: + motion_modules = down_block.motion_modules + for param in motion_modules.parameters(): + param.requires_grad = True + + for up_block in self.up_blocks: + motion_modules = up_block.motion_modules + for param in motion_modules.parameters(): + param.requires_grad = True + + motion_modules = self.mid_block.motion_modules + for param in motion_modules.parameters(): + param.requires_grad = True + + return + def load_motion_modules(self, motion_adapter: Optional[MotionAdapter]): for i, down_block in enumerate(motion_adapter.down_blocks): self.down_blocks[i].motion_modules.load_state_dict(down_block.motion_modules.state_dict()) @@ -510,7 +535,7 @@ def set_attn_processor( def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): if hasattr(module, "processor"): - # ignore attention processors from motion modules + # ignore the custom attention processors from motion modules if isinstance(module.processor, MotionAttnProcessor) or isinstance( module.processor, MotionAttnProcessor2_0 ): diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index e8bcc4ac73cf..b722952b903d 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -34,6 +34,7 @@ PNDMScheduler, ) from ...utils import BaseOutput, logging, scale_lora_layers, unscale_lora_layers +from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline @@ -43,16 +44,15 @@ Examples: ```py >>> import torch - >>> from diffusers import MotionAdapter, AnimatedDiffPipeline - >>> from diffusers.utils import export_to_video - - >>> pipe = AnimateDiffPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", motion_adapter=adapter) - >>> pipe.enable_model_cpu_offload() - - >>> prompt = "Spiderman is surfing" - >>> video_frames = pipe(prompt).frames - >>> video_path = export_to_video(video_frames) - >>> video_path + >>> from diffusers import MotionAdapter, AnimateDiffPipeline, DDIMScheduler + >>> from diffusers.utils import export_to_gif + + >>> adapter = MotionAdapter.from_pretrained("diffusers/motion-adapter") + >>> pipe = AnimateDiffPipeline.from_pretrained("frankjoshua/toonyou_beta6", motion_adapter=adapter) + >>> pipe.scheduler = DDIMScheduler(beta_schedule="linear", steps_offset=1, clip_sample=False) + >>> output = pipe(prompt="A corgi walking in the park") + >>> frames = output.frames[0] + >>> export_to_gif(frames, "animation.gif") ``` """ @@ -436,22 +436,10 @@ def prepare_latents( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) + if latents is None: - rand_device = "cpu" if device.type == "mps" else device - - if isinstance(generator, list): - shape = shape - # shape = (1,) + shape[1:] - latents = [ - torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) - for i in range(batch_size) - ] - latents = torch.cat(latents, dim=0).to(device) - else: - latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device) + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: - if latents.shape != shape: - raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") latents = latents.to(device) # scale the initial noise by the standard deviation required by the scheduler From 79f402f2d662abf93a03e1b793f4182e267ef812 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Tue, 24 Oct 2023 15:54:12 +0000 Subject: [PATCH 14/55] clean up --- src/diffusers/models/unet_motion_model.py | 76 +++-- tests/models/test_models_unet_motion.py | 334 ++++++++++++++++++++++ 2 files changed, 370 insertions(+), 40 deletions(-) create mode 100644 tests/models/test_models_unet_motion.py diff --git a/src/diffusers/models/unet_motion_model.py b/src/diffusers/models/unet_motion_model.py index c8f943e09d54..6ee3429ffe42 100644 --- a/src/diffusers/models/unet_motion_model.py +++ b/src/diffusers/models/unet_motion_model.py @@ -62,9 +62,6 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): def __init__( self, sample_size: Optional[int] = None, - center_input_sample: bool = False, - flip_sin_to_cos: bool = True, - freq_shift: int = 0, in_channels: int = 4, out_channels: int = 4, down_block_types: Tuple[str] = ( @@ -90,14 +87,12 @@ def __init__( attention_head_dim: Union[int, Tuple[int]] = 8, use_linear_projection: bool = False, num_attention_heads: Optional[Union[int, Tuple[int]]] = None, - motion_norm_num_groups=32, - motion_cross_attention_dim=None, - motion_num_attention_heads=8, - motion_attention_bias=False, - motion_activation_fn="geglu", - motion_max_seq_length=24, - motion_layers_per_block=2, - motion_mid_block_layers_per_block=1, + motion_norm_num_groups: Optional[int] = 32, + motion_cross_attention_dim: Optional[int] = None, + motion_num_attention_heads: Optional[int] = 8, + motion_attention_bias: bool = False, + motion_activation_fn: str = "geglu", + motion_max_seq_length: Optional[int] = 32, ): super().__init__() @@ -314,38 +309,38 @@ def from_unet2d( model = cls.from_config(config) - if load_weights: - model.conv_in.load_state_dict(unet.conv_in.state_dict()) - model.time_proj.load_state_dict(unet.time_proj.state_dict()) - model.time_embedding.load_state_dict(unet.time_embedding.state_dict()) + if not load_weights: + return model - for i, down_block in enumerate(unet.down_blocks): - model.down_blocks[i].resnets.load_state_dict(down_block.resnets.state_dict()) - if hasattr(model.down_blocks[i], "attentions"): - model.down_blocks[i].attentions.load_state_dict(down_block.attentions.state_dict()) - if model.down_blocks[i].downsamplers: - model.down_blocks[i].downsamplers.load_state_dict(down_block.downsamplers.state_dict()) + model.conv_in.load_state_dict(unet.conv_in.state_dict()) + model.time_proj.load_state_dict(unet.time_proj.state_dict()) + model.time_embedding.load_state_dict(unet.time_embedding.state_dict()) - for i, up_block in enumerate(unet.up_blocks): - model.up_blocks[i].resnets.load_state_dict(up_block.resnets.state_dict()) - if hasattr(model.up_blocks[i], "attentions"): - model.up_blocks[i].attentions.load_state_dict(up_block.attentions.state_dict()) - if model.up_blocks[i].upsamplers: - model.up_blocks[i].upsamplers.load_state_dict(up_block.upsamplers.state_dict()) + for i, down_block in enumerate(unet.down_blocks): + model.down_blocks[i].resnets.load_state_dict(down_block.resnets.state_dict()) + if hasattr(model.down_blocks[i], "attentions"): + model.down_blocks[i].attentions.load_state_dict(down_block.attentions.state_dict()) + if model.down_blocks[i].downsamplers: + model.down_blocks[i].downsamplers.load_state_dict(down_block.downsamplers.state_dict()) - model.mid_block.resnets.load_state_dict(unet.mid_block.resnets.state_dict()) - model.mid_block.attentions.load_state_dict(unet.mid_block.attentions.state_dict()) + for i, up_block in enumerate(unet.up_blocks): + model.up_blocks[i].resnets.load_state_dict(up_block.resnets.state_dict()) + if hasattr(model.up_blocks[i], "attentions"): + model.up_blocks[i].attentions.load_state_dict(up_block.attentions.state_dict()) + if model.up_blocks[i].upsamplers: + model.up_blocks[i].upsamplers.load_state_dict(up_block.upsamplers.state_dict()) - if unet.conv_norm_out is not None: - model.conv_norm_out.load_state_dict(unet.conv_norm_out.state_dict()) + model.mid_block.resnets.load_state_dict(unet.mid_block.resnets.state_dict()) + model.mid_block.attentions.load_state_dict(unet.mid_block.attentions.state_dict()) - if unet.conv_act is not None: - model.conv_act.load_state_dict(unet.conv_act.state_dict()) + if unet.conv_norm_out is not None: + model.conv_norm_out.load_state_dict(unet.conv_norm_out.state_dict()) + if unet.conv_act is not None: + model.conv_act.load_state_dict(unet.conv_act.state_dict()) + model.conv_out.load_state_dict(unet.conv_out.state_dict()) - model.conv_out.load_state_dict(unet.conv_out.state_dict()) - - if has_motion_adapter: - model.load_motion_modules(motion_adapter) + if has_motion_adapter: + model.load_motion_modules(motion_adapter) return model @@ -392,21 +387,22 @@ def save_motion_modules( **kwargs, ): state_dict = self.state_dict() + # Extract all motion modules motion_state_dict = {} for k, v in state_dict.items(): - if k.contains("motion_modules"): + if "motion_modules" in k: motion_state_dict[k] = v adapter = MotionAdapter( + block_out_channels=self.config["block_out_channels"], motion_norm_num_groups=self.config["motion_norm_num_groups"], motion_cross_attention_dim=self.config["motion_cross_attention_dim"], motion_num_attention_heads=self.config["motion_num_attention_heads"], motion_attention_bias=self.config["motion_attention_bias"], motion_activation_fn=self.config["motion_activation_fn"], motion_max_seq_length=self.config["motion_max_seq_length"], - motion_layers_per_block=self.config["motion_layers_per_block"], - motion_mid_block_layers_per_block=self.config["motion_mid_block_layers_per_block"], + motion_layers_per_block=self.config["layers_per_block"], ) adapter.load_state_dict(motion_state_dict) adapter.save_pretrained( diff --git a/tests/models/test_models_unet_motion.py b/tests/models/test_models_unet_motion.py new file mode 100644 index 000000000000..af4791a8d541 --- /dev/null +++ b/tests/models/test_models_unet_motion.py @@ -0,0 +1,334 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import gc +import os +import tempfile +import unittest + +import torch +import numpy as np + +from diffusers import UNetMotionModel, UNet2DConditionModel, MotionAdapter +from diffusers.utils import logging +from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.testing_utils import ( + enable_full_determinism, + floats_tensor, + torch_device, +) + +from .test_modeling_common import ModelTesterMixin, UNetTesterMixin + + +logger = logging.get_logger(__name__) + +enable_full_determinism() + + +class UNetMotionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): + model_class = UNetMotionModel + main_input_name = "sample" + + @property + def dummy_input(self): + batch_size = 4 + num_channels = 4 + num_frames = 8 + sizes = (32, 32) + + noise = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device) + time_step = torch.tensor([10]).to(torch_device) + encoder_hidden_states = floats_tensor((batch_size, 4, 32)).to(torch_device) + + return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states} + + @property + def input_shape(self): + return (4, 8, 32, 32) + + @property + def output_shape(self): + return (4, 8, 32, 32) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "block_out_channels": (32, 64), + "down_block_types": ("CrossAttnDownBlockMotion", "DownBlockMotion"), + "up_block_types": ("UpBlockMotion", "CrossAttnUpBlockMotion"), + "cross_attention_dim": 32, + "attention_head_dim": 8, + "out_channels": 4, + "in_channels": 4, + "layers_per_block": 1, + "sample_size": 32, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_from_unet2d(self): + torch.manual_seed(0) + unet2d = UNet2DConditionModel() + + torch.manual_seed(1) + model = self.model_class.from_unet2d(unet2d) + model_state_dict = model.state_dict() + + for param_name, param_value in unet2d.named_parameters(): + self.assertTrue(torch.equal(model_state_dict[param_name], param_value)) + + def test_freeze_unet2d(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + model.freeze_unet2d_params() + + for param_name, param_value in model.named_parameters(): + if "motion_modules" not in param_name: + self.assertFalse(param_value.requires_grad) + + else: + self.assertTrue(param_value.requires_grad) + + def test_loading_motion_adapter(self): + model = self.model_class() + adapter = MotionAdapter() + model.load_motion_modules(adapter) + + for idx, down_block in enumerate(model.down_blocks): + adapter_state_dict = adapter.down_blocks[idx].motion_modules.state_dict() + for param_name, param_value in down_block.motion_modules.named_parameters(): + self.assertTrue(torch.equal(adapter_state_dict[param_name], param_value)) + + for idx, up_block in enumerate(model.up_blocks): + adapter_state_dict = adapter.up_blocks[idx].motion_modules.state_dict() + for param_name, param_value in up_block.motion_modules.named_parameters(): + self.assertTrue(torch.equal(adapter_state_dict[param_name], param_value)) + + mid_block_adapter_state_dict = adapter.mid_block.motion_modules.state_dict() + for param_name, param_value in model.mid_block.motion_modules.named_parameters(): + self.assertTrue(torch.equal(mid_block_adapter_state_dict[param_name], param_value)) + + def test_saving_motion_modules(self): + torch.manual_seed(0) + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + model.to(torch_device) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_motion_modules(tmpdirname) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "diffusion_pytorch_model.safetensors"))) + + adapter_loaded = MotionAdapter.from_pretrained(tmpdirname) + + torch.manual_seed(0) + model_loaded = self.model_class(**init_dict) + model_loaded.load_motion_modules(adapter_loaded) + model_loaded.to(torch_device) + + with torch.no_grad(): + output = model(**inputs_dict)[0] + output_loaded = model_loaded(**inputs_dict)[0] + + assert np.abs(output - output_loaded).max() < 1e-4 + + @unittest.skipIf( + torch_device != "cuda" or not is_xformers_available(), + reason="XFormers attention is only available with CUDA and `xformers` installed", + ) + def test_xformers_enable_works(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + + model.enable_xformers_memory_efficient_attention() + + assert ( + model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__ + == "XFormersAttnProcessor" + ), "xformers is not enabled" + + def test_model_attention_slicing(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["attention_head_dim"] = 8 + + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + model.set_attention_slice("auto") + with torch.no_grad(): + output = model(**inputs_dict) + assert output is not None + + model.set_attention_slice("max") + with torch.no_grad(): + output = model(**inputs_dict) + assert output is not None + + model.set_attention_slice(2) + with torch.no_grad(): + output = model(**inputs_dict) + assert output is not None + + def test_gradient_checkpointing_is_applied(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model_class_copy = copy.copy(self.model_class) + + modules_with_gc_enabled = {} + + # now monkey patch the following function: + # def _set_gradient_checkpointing(self, module, value=False): + # if hasattr(module, "gradient_checkpointing"): + # module.gradient_checkpointing = value + + def _set_gradient_checkpointing_new(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + modules_with_gc_enabled[module.__class__.__name__] = True + + model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new + + model = model_class_copy(**init_dict) + model.enable_gradient_checkpointing() + + EXPECTED_SET = { + "CrossAttnUpBlockMotion", + "CrossAttnDownBlockMotion", + "UNetMidBlockCrossAttnMotion", + "UpBlockMotion", + "Transformer2DModel", + "DownBlockMotion", + } + + assert set(modules_with_gc_enabled.keys()) == EXPECTED_SET + assert all(modules_with_gc_enabled.values()), "All modules should be enabled" + + def test_feed_forward_chunking(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + init_dict["norm_num_groups"] = 32 + + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + output = model(**inputs_dict)[0] + + model.enable_forward_chunking() + with torch.no_grad(): + output_2 = model(**inputs_dict)[0] + + self.assertEqual(output.shape, output_2.shape, "Shape doesn't match") + assert np.abs(output.cpu() - output_2.cpu()).max() < 1e-2 + + def test_pickle(self): + # enable deterministic behavior for gradient checkpointing + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["attention_head_dim"] = (8, 16) + + model = self.model_class(**init_dict) + model.to(torch_device) + + with torch.no_grad(): + sample = model(**inputs_dict).sample + + sample_copy = copy.copy(sample) + + assert (sample - sample_copy).abs().max() < 1e-4 + + def test_from_save_pretrained(self, expected_max_diff=5e-5): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + torch.manual_seed(0) + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname, safe_serialization=False) + torch.manual_seed(0) + new_model = self.model_class.from_pretrained(tmpdirname) + new_model.to(torch_device) + + with torch.no_grad(): + image = model(**inputs_dict) + if isinstance(image, dict): + image = image.to_tuple()[0] + + new_image = new_model(**inputs_dict) + + if isinstance(new_image, dict): + new_image = new_image.to_tuple()[0] + + max_diff = (image - new_image).abs().max().item() + self.assertLessEqual(max_diff, expected_max_diff, "Models give different forward passes") + + def test_from_save_pretrained_variant(self, expected_max_diff=5e-5): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + torch.manual_seed(0) + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname, variant="fp16", safe_serialization=False) + + torch.manual_seed(0) + new_model = self.model_class.from_pretrained(tmpdirname, variant="fp16") + # non-variant cannot be loaded + with self.assertRaises(OSError) as error_context: + self.model_class.from_pretrained(tmpdirname) + + # make sure that error message states what keys are missing + assert "Error no file named diffusion_pytorch_model.bin found in directory" in str(error_context.exception) + + new_model.to(torch_device) + + with torch.no_grad(): + image = model(**inputs_dict) + if isinstance(image, dict): + image = image.to_tuple()[0] + + new_image = new_model(**inputs_dict) + + if isinstance(new_image, dict): + new_image = new_image.to_tuple()[0] + + max_diff = (image - new_image).abs().max().item() + self.assertLessEqual(max_diff, expected_max_diff, "Models give different forward passes") + + def test_forward_with_norm_groups(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["norm_num_groups"] = 16 + init_dict["motion_norm_num_groups"] = 16 + init_dict["block_out_channels"] = (16, 32) + + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + output = model(**inputs_dict) + + if isinstance(output, dict): + output = output.to_tuple()[0] + + self.assertIsNotNone(output) + expected_shape = inputs_dict["sample"].shape + self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") From b24f58a5f4d9c910d97651f1e7071e7ba74e1b7e Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 25 Oct 2023 10:04:37 +0000 Subject: [PATCH 15/55] add tests --- tests/pipelines/animatediff/__init__.py | 0 .../pipelines/animatediff/test_animatediff.py | 151 ++++++++++++++++++ 2 files changed, 151 insertions(+) create mode 100644 tests/pipelines/animatediff/__init__.py create mode 100644 tests/pipelines/animatediff/test_animatediff.py diff --git a/tests/pipelines/animatediff/__init__.py b/tests/pipelines/animatediff/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/animatediff/test_animatediff.py b/tests/pipelines/animatediff/test_animatediff.py new file mode 100644 index 000000000000..6a728bef375d --- /dev/null +++ b/tests/pipelines/animatediff/test_animatediff.py @@ -0,0 +1,151 @@ +import gc +import unittest + +import torch +import numpy as np +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer + +from diffusers import ( + AnimateDiffPipeline, + AutoencoderKL, + DDIMScheduler, + MotionAdapter, + UNet2DConditionModel, + UNetMotionModel, +) + +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin +from diffusers.utils.testing_utils import torch_device, slow, require_torch_gpu, numpy_cosine_similarity_distance + + +class AnimateDiffPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): + pipeline_class = AnimateDiffPipeline + params = TEXT_TO_IMAGE_PARAMS + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + + def get_dummy_components(self): + torch.manual_seed(0) + unet = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + norm_num_groups=2, + ) + scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + ) + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + text_encoder = CLIPTextModel(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + motion_adapter = MotionAdapter(block_out_channels=(4, 8), motion_layers_per_block=2, motion_norm_num_groups=2) + + components = { + "unet": unet, + "scheduler": scheduler, + "vae": vae, + "motion_adapter": motion_adapter, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "safety_checker": None, + "feature_extractor": None, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 7.5, + "output_type": "np", + } + return inputs + + def test_motion_unet_loading(self): + components = self.get_dummy_components() + pipe = AnimateDiffPipeline.from_pretrained(**components) + + assert isinstance(pipe.unet, UNetMotionModel) + + +@slow +@require_torch_gpu +class AnimateDiffPipelineSlowTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_animatediff(self): + # make sure here that pndm scheduler skips prk + adapter = MotionAdapter.from_pretrained("diffusers/motion-adapter-test") + pipe = AnimateDiffPipeline.from_pretrained("frankjoshua/toonyou_beta6", motion_adapter=adapter) + pipe = pipe.to(torch_device) + pipe.scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="linear", + steps_offset=1, + clip_sample=False, + ) + pipe.set_progress_bar_config(disable=None) + + prompt = "night, b&w photo of old house, post apocalypse, forest, storm weather, wind, rocks, 8k uhd, dslr, soft lighting, high quality, film grain" + negative_prompt = "bad quality, worse quality" + + generator = torch.manual_seed(0) + output = pipe( + prompt, + negative_prompt=negative_prompt, + num_frames=16, + generator=generator, + guidance_scale=7.5, + num_inference_steps=20, + output_type="np", + ) + + image = output.images + + image_slice = image[0, -3:, -3:, -1] + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array([0.1010, 0.0800, 0.0794, 0.0885, 0.0843, 0.0762, 0.0769, 0.0729, 0.0586]) + + assert numpy_cosine_similarity_distance(image_slice.flatten() - expected_slice.flatten()) < 1e-2 From 2688d07a6035dd99c70c901bf8e009c68214ad2b Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 25 Oct 2023 10:26:40 +0000 Subject: [PATCH 16/55] change motion block --- src/diffusers/models/attention.py | 210 ++++++++++++++++++- src/diffusers/models/transformer_temporal.py | 149 ++++++++++++- src/diffusers/models/unet_3d_blocks.py | 108 ++++++++++ src/diffusers/models/unet_motion_blocks.py | 9 +- 4 files changed, 467 insertions(+), 9 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 6f5d1da6c6ae..8e7727ac5985 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -20,7 +20,7 @@ from ..utils.torch_utils import maybe_allow_in_graph from .activations import get_activation from .attention_processor import Attention -from .embeddings import CombinedTimestepLabelEmbeddings +from .embeddings import CombinedTimestepLabelEmbeddings, get_timestep_embedding from .lora import LoRACompatibleLinear @@ -275,6 +275,214 @@ def forward( return hidden_states +@maybe_allow_in_graph +class TemporalPosEmbedTransformerBlock(nn.Module): + r""" + A Transformer block that applies a timestep positional embedding before each attention layer. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + upcast_attention (`bool`, *optional*): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. + norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_type (`str`, *optional*, defaults to `"layer_norm"`): + The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. + final_dropout (`bool` *optional*, defaults to False): + Whether to apply a final dropout after the last feed-forward layer. + attention_type (`str`, *optional*, defaults to `"default"`): + The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", + final_dropout: bool = False, + attention_type: str = "default", + max_seq_length=32, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + + self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + self.pos_embed = get_timestep_embedding(torch.arange(max_seq_length), dim)[None, :] + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + if self.use_ada_layer_norm: + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif self.use_ada_layer_norm_zero: + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + self.norm2 = ( + AdaLayerNorm(dim, num_embeds_ada_norm) + if self.use_ada_layer_norm + else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + ) + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) # is self-attn if encoder_hidden_states is none + else: + self.norm2 = None + self.attn2 = None + + # 3. Feed-forward + self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) + + # 4. Fuser + if attention_type == "gated" or attention_type == "gated-text-image": + self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + ) -> torch.FloatTensor: + _, seq_length, channels = hidden_states.shape + + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Self-Attention + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.use_ada_layer_norm_zero: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + else: + norm_hidden_states = self.norm1(hidden_states) + + # 1. Retrieve lora scale. + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + + cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + hidden_states = self.pos_embed[:, :seq_length] + norm_hidden_states + + attn_output = self.attn1( + hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + if self.use_ada_layer_norm_zero: + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = attn_output + hidden_states + + # 3. Cross-Attention + if self.attn2 is not None: + norm_hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + hidden_states = self.pos_embed[:, :seq_length] + norm_hidden_states + attn_output = self.attn2( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 4. Feed-forward + norm_hidden_states = self.norm3(hidden_states) + + if self.use_ada_layer_norm_zero: + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0: + raise ValueError( + f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." + ) + + num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size + ff_output = torch.cat( + [ + self.ff(hid_slice, scale=lora_scale) + for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim) + ], + dim=self._chunk_dim, + ) + else: + ff_output = self.ff(norm_hidden_states, scale=lora_scale) + + if self.use_ada_layer_norm_zero: + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = ff_output + hidden_states + + return hidden_states + + class FeedForward(nn.Module): r""" A feed-forward layer. diff --git a/src/diffusers/models/transformer_temporal.py b/src/diffusers/models/transformer_temporal.py index d002cb3315fa..d8f47856d4cb 100644 --- a/src/diffusers/models/transformer_temporal.py +++ b/src/diffusers/models/transformer_temporal.py @@ -19,7 +19,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput -from .attention import BasicTransformerBlock +from .attention import BasicTransformerBlock, TemporalPosEmbedTransformerBlock from .modeling_utils import ModelMixin @@ -183,3 +183,150 @@ def forward( return (output,) return TransformerTemporalModelOutput(sample=output) + + +class TransformerTemporalMotionModel(ModelMixin, ConfigMixin): + """ + A Transformer model to learn a Motion Prior for video-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). + This is fixed during training since it is used to learn a number of position embeddings. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. + attention_bias (`bool`, *optional*): + Configure if the `TransformerBlock` attention should contain a bias parameter. + double_self_attention (`bool`, *optional*): + Configure if each `TransformerBlock` should contain two self-attention layers. + """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + activation_fn: str = "geglu", + norm_elementwise_affine: bool = True, + double_self_attention: bool = True, + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + self.proj_in = nn.Linear(in_channels, inner_dim) + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + TemporalPosEmbedTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + attention_bias=attention_bias, + double_self_attention=double_self_attention, + norm_elementwise_affine=norm_elementwise_affine, + ) + for d in range(num_layers) + ] + ) + + self.proj_out = nn.Linear(inner_dim, in_channels) + + def forward( + self, + hidden_states, + encoder_hidden_states=None, + timestep=None, + class_labels=None, + num_frames=1, + cross_attention_kwargs=None, + return_dict: bool = True, + ): + """ + The [`TransformerTemporal`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): + Input hidden_states. + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.long`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + num_frames (`int`, *optional*, defaults to 1): + The number of frames to be processed per batch. This is used to reshape the hidden states. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`: + If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is + returned, otherwise a `tuple` where the first element is the sample tensor. + """ + # 1. Input + batch_frames, channel, height, width = hidden_states.shape + batch_size = batch_frames // num_frames + + residual = hidden_states + + hidden_states = self.norm(hidden_states) + hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width) + hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel) + + hidden_states = self.proj_in(hidden_states) + + # 2. Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + + # 3. Output + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states[None, None, :] + .reshape(batch_size, height, width, num_frames, channel) + .permute(0, 3, 4, 1, 2) + .contiguous() + ) + hidden_states = hidden_states.reshape(batch_frames, channel, height, width) + + output = hidden_states + residual + + if not return_dict: + return (output,) + + return TransformerTemporalModelOutput(sample=output) diff --git a/src/diffusers/models/unet_3d_blocks.py b/src/diffusers/models/unet_3d_blocks.py index 180ae0dc1a81..2d050cd5ca66 100644 --- a/src/diffusers/models/unet_3d_blocks.py +++ b/src/diffusers/models/unet_3d_blocks.py @@ -724,3 +724,111 @@ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_si hidden_states = upsampler(hidden_states, upsample_size) return hidden_states + + +class DownBlockMotion(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + motion_norm_num_groups=32, + motion_cross_attention_dim=None, + motion_num_attention_heads=8, + motion_attention_bias=False, + motion_activation_fn="geglu", + motion_max_seq_length=32, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.motion_modules = MotionModules( + in_channels=out_channels, + norm_num_groups=motion_norm_num_groups, + cross_attention_dim=motion_cross_attention_dim, + activation_fn=motion_activation_fn, + attention_bias=motion_attention_bias, + num_attention_heads=motion_num_attention_heads, + max_seq_length=motion_max_seq_length, + layers_per_block=num_layers, + ).motion_modules + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None, scale: float = 1.0, num_frames=1): + output_states = () + + blocks = zip(self.resnets, self.motion_modules) + for resnet, motion_module in blocks: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, scale + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, num_frames + ) + + else: + hidden_states = resnet(hidden_states, temb, scale=scale) + hidden_states = motion_module(hidden_states, num_frames=num_frames) + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states, scale=scale) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states diff --git a/src/diffusers/models/unet_motion_blocks.py b/src/diffusers/models/unet_motion_blocks.py index 320891e6b9b8..b5bb2cf76aae 100644 --- a/src/diffusers/models/unet_motion_blocks.py +++ b/src/diffusers/models/unet_motion_blocks.py @@ -13,7 +13,7 @@ from .modeling_utils import ModelMixin from .resnet import Downsample2D, ResnetBlock2D, Upsample2D from .transformer_2d import Transformer2DModel -from .transformer_temporal import TransformerTemporalModel +from .transformer_temporal import TransformerTemporalModel, TransformerTemporalMotionModel def get_down_block( @@ -338,7 +338,7 @@ def __init__( ) -> None: super().__init__() - self.temporal_transformer = TransformerTemporalModel( + self.temporal_transformer = TransformerTemporalMotionModel( in_channels=in_channels, norm_num_groups=norm_num_groups, cross_attention_dim=cross_attention_dim, @@ -348,11 +348,6 @@ def __init__( num_attention_heads=num_attention_heads, ) self.use_cross_attention = cross_attention_dim is not None - processor_cls = MotionAttnProcessor2_0 if is_torch_version(">=", "2.0.0") else MotionAttnProcessor - - for block in self.temporal_transformer.transformer_blocks: - block.attn1.set_processor(processor_cls(in_channels=in_channels, max_seq_length=max_seq_length)) - block.attn2.set_processor(processor_cls(in_channels=in_channels, max_seq_length=max_seq_length)) def forward(self, hidden_states, encoder_hidden_states=None, num_frames=1, cross_attention_kwargs=None): encoder_hidden_states = encoder_hidden_states if self.use_cross_attention else None From 0deab59ca8a388eee177f3160aa64d23fda8797a Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 25 Oct 2023 10:47:25 +0000 Subject: [PATCH 17/55] clean up --- src/diffusers/models/attention.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 8e7727ac5985..7c9c7b0d3a18 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -425,7 +425,10 @@ def forward( lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} - hidden_states = self.pos_embed[:, :seq_length] + norm_hidden_states + hidden_states = ( + self.pos_embed[:, :seq_length].to(norm_hidden_states.device).to(norm_hidden_states.dtype) + + norm_hidden_states + ) attn_output = self.attn1( hidden_states, @@ -442,7 +445,10 @@ def forward( norm_hidden_states = ( self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) ) - hidden_states = self.pos_embed[:, :seq_length] + norm_hidden_states + hidden_states = ( + self.pos_embed[:, :seq_length].to(norm_hidden_states.device).to(norm_hidden_states.dtype) + + norm_hidden_states + ) attn_output = self.attn2( hidden_states, encoder_hidden_states=encoder_hidden_states, From 9c66c21bfdce204f8a4673588b4400efc6c8a843 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 25 Oct 2023 11:04:52 +0000 Subject: [PATCH 18/55] clean up --- src/diffusers/models/attention.py | 2 +- src/diffusers/models/transformer_temporal.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 7c9c7b0d3a18..33ec4a03ca15 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -20,7 +20,7 @@ from ..utils.torch_utils import maybe_allow_in_graph from .activations import get_activation from .attention_processor import Attention -from .embeddings import CombinedTimestepLabelEmbeddings, get_timestep_embedding +from .embeddings import CombinedTimestepLabelEmbeddings, get_timestep_embedding, PositionalEmbedding from .lora import LoRACompatibleLinear diff --git a/src/diffusers/models/transformer_temporal.py b/src/diffusers/models/transformer_temporal.py index d8f47856d4cb..70bd840d9db4 100644 --- a/src/diffusers/models/transformer_temporal.py +++ b/src/diffusers/models/transformer_temporal.py @@ -298,8 +298,10 @@ def forward( residual = hidden_states - hidden_states = self.norm(hidden_states) hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width) + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + + hidden_states = self.norm(hidden_states) hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel) hidden_states = self.proj_in(hidden_states) From 1bd65de5d8324852ac5587cab6bff90664eab64b Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 25 Oct 2023 11:28:55 +0000 Subject: [PATCH 19/55] clean up --- src/diffusers/models/attention.py | 15 ++------------- src/diffusers/models/transformer_temporal.py | 4 +--- src/diffusers/models/unet_motion_blocks.py | 5 +++++ 3 files changed, 8 insertions(+), 16 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 33ec4a03ca15..9addcbfbc077 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -338,8 +338,6 @@ def __init__( f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." ) - self.pos_embed = get_timestep_embedding(torch.arange(max_seq_length), dim)[None, :] - # Define 3 blocks. Each block has its own normalization layer. # 1. Self-Attn if self.use_ada_layer_norm: @@ -425,13 +423,8 @@ def forward( lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} - hidden_states = ( - self.pos_embed[:, :seq_length].to(norm_hidden_states.device).to(norm_hidden_states.dtype) - + norm_hidden_states - ) - attn_output = self.attn1( - hidden_states, + norm_hidden_states, encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, attention_mask=attention_mask, **cross_attention_kwargs, @@ -445,12 +438,8 @@ def forward( norm_hidden_states = ( self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) ) - hidden_states = ( - self.pos_embed[:, :seq_length].to(norm_hidden_states.device).to(norm_hidden_states.dtype) - + norm_hidden_states - ) attn_output = self.attn2( - hidden_states, + norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=encoder_attention_mask, **cross_attention_kwargs, diff --git a/src/diffusers/models/transformer_temporal.py b/src/diffusers/models/transformer_temporal.py index 70bd840d9db4..d8f47856d4cb 100644 --- a/src/diffusers/models/transformer_temporal.py +++ b/src/diffusers/models/transformer_temporal.py @@ -298,10 +298,8 @@ def forward( residual = hidden_states - hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width) - hidden_states = hidden_states.permute(0, 2, 1, 3, 4) - hidden_states = self.norm(hidden_states) + hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width) hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel) hidden_states = self.proj_in(hidden_states) diff --git a/src/diffusers/models/unet_motion_blocks.py b/src/diffusers/models/unet_motion_blocks.py index b5bb2cf76aae..f3c67e832678 100644 --- a/src/diffusers/models/unet_motion_blocks.py +++ b/src/diffusers/models/unet_motion_blocks.py @@ -348,6 +348,11 @@ def __init__( num_attention_heads=num_attention_heads, ) self.use_cross_attention = cross_attention_dim is not None + processor_cls = MotionAttnProcessor2_0 if is_torch_version(">=", "2.0.0") else MotionAttnProcessor + + for block in self.temporal_transformer.transformer_blocks: + block.attn1.set_processor(processor_cls(in_channels=in_channels, max_seq_length=max_seq_length)) + block.attn2.set_processor(processor_cls(in_channels=in_channels, max_seq_length=max_seq_length)) def forward(self, hidden_states, encoder_hidden_states=None, num_frames=1, cross_attention_kwargs=None): encoder_hidden_states = encoder_hidden_states if self.use_cross_attention else None From 22c9f7b3e39b5a1cd92996c1885a46d73acbb12a Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 25 Oct 2023 14:19:55 +0000 Subject: [PATCH 20/55] update --- src/diffusers/models/attention.py | 11 ++++++++--- src/diffusers/models/transformer_temporal.py | 8 +++++--- src/diffusers/models/unet_motion_blocks.py | 5 ----- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 9addcbfbc077..6d01fe389849 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -276,7 +276,7 @@ def forward( @maybe_allow_in_graph -class TemporalPosEmbedTransformerBlock(nn.Module): +class TemporalPositionEmbedTransformerBlock(nn.Module): r""" A Transformer block that applies a timestep positional embedding before each attention layer. @@ -327,6 +327,8 @@ def __init__( max_seq_length=32, ): super().__init__() + + self.pos_embed = get_timestep_embedding(max_seq_length, dim)[None, :] self.only_cross_attention = only_cross_attention self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" @@ -423,8 +425,10 @@ def forward( lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + + hidden_states = self.pos_embed[:, :seq_length, :].to(norm_hidden_states.dtype) + norm_hidden_states attn_output = self.attn1( - norm_hidden_states, + hidden_states, encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, attention_mask=attention_mask, **cross_attention_kwargs, @@ -438,8 +442,9 @@ def forward( norm_hidden_states = ( self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) ) + hidden_states = self.pos_embed[:, :seq_length, :].to(norm_hidden_states.dtype) + norm_hidden_states attn_output = self.attn2( - norm_hidden_states, + hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=encoder_attention_mask, **cross_attention_kwargs, diff --git a/src/diffusers/models/transformer_temporal.py b/src/diffusers/models/transformer_temporal.py index d8f47856d4cb..df21169f00c2 100644 --- a/src/diffusers/models/transformer_temporal.py +++ b/src/diffusers/models/transformer_temporal.py @@ -19,7 +19,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput -from .attention import BasicTransformerBlock, TemporalPosEmbedTransformerBlock +from .attention import BasicTransformerBlock, TemporalPositionEmbedTransformerBlock from .modeling_utils import ModelMixin @@ -236,7 +236,7 @@ def __init__( # 3. Define transformers blocks self.transformer_blocks = nn.ModuleList( [ - TemporalPosEmbedTransformerBlock( + TemporalPositionEmbedTransformerBlock( inner_dim, num_attention_heads, attention_head_dim, @@ -298,8 +298,10 @@ def forward( residual = hidden_states - hidden_states = self.norm(hidden_states) hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width) + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + + hidden_states = self.norm(hidden_states) hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel) hidden_states = self.proj_in(hidden_states) diff --git a/src/diffusers/models/unet_motion_blocks.py b/src/diffusers/models/unet_motion_blocks.py index f3c67e832678..b5bb2cf76aae 100644 --- a/src/diffusers/models/unet_motion_blocks.py +++ b/src/diffusers/models/unet_motion_blocks.py @@ -348,11 +348,6 @@ def __init__( num_attention_heads=num_attention_heads, ) self.use_cross_attention = cross_attention_dim is not None - processor_cls = MotionAttnProcessor2_0 if is_torch_version(">=", "2.0.0") else MotionAttnProcessor - - for block in self.temporal_transformer.transformer_blocks: - block.attn1.set_processor(processor_cls(in_channels=in_channels, max_seq_length=max_seq_length)) - block.attn2.set_processor(processor_cls(in_channels=in_channels, max_seq_length=max_seq_length)) def forward(self, hidden_states, encoder_hidden_states=None, num_frames=1, cross_attention_kwargs=None): encoder_hidden_states = encoder_hidden_states if self.use_cross_attention else None From 0e1f7a83f3b672602ff2a180ed5f38b05695b1e8 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 25 Oct 2023 15:00:42 +0000 Subject: [PATCH 21/55] update --- src/diffusers/models/attention.py | 2 +- src/diffusers/models/transformer_temporal.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 6d01fe389849..3266a35f4696 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -328,7 +328,7 @@ def __init__( ): super().__init__() - self.pos_embed = get_timestep_embedding(max_seq_length, dim)[None, :] + self.pos_embed = get_timestep_embedding(torch.arange(max_seq_length), dim)[None, :] self.only_cross_attention = only_cross_attention self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" diff --git a/src/diffusers/models/transformer_temporal.py b/src/diffusers/models/transformer_temporal.py index df21169f00c2..c739f9250e80 100644 --- a/src/diffusers/models/transformer_temporal.py +++ b/src/diffusers/models/transformer_temporal.py @@ -222,6 +222,7 @@ def __init__( activation_fn: str = "geglu", norm_elementwise_affine: bool = True, double_self_attention: bool = True, + max_seq_length: int = 32, ): super().__init__() self.num_attention_heads = num_attention_heads @@ -246,6 +247,7 @@ def __init__( attention_bias=attention_bias, double_self_attention=double_self_attention, norm_elementwise_affine=norm_elementwise_affine, + max_seq_length=max_seq_length, ) for d in range(num_layers) ] From c7e1b14e4ce122a3a4dfa20624f8ec8d2914fc08 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 25 Oct 2023 15:05:13 +0000 Subject: [PATCH 22/55] update --- src/diffusers/models/attention.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 3266a35f4696..bed8d66dd7a9 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -426,7 +426,10 @@ def forward( cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} - hidden_states = self.pos_embed[:, :seq_length, :].to(norm_hidden_states.dtype) + norm_hidden_states + hidden_states = ( + self.pos_embed[:, :seq_length, :].to(norm_hidden_states.dtype).to(norm_hidden_states.device) + + norm_hidden_states + ) attn_output = self.attn1( hidden_states, encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, @@ -442,7 +445,10 @@ def forward( norm_hidden_states = ( self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) ) - hidden_states = self.pos_embed[:, :seq_length, :].to(norm_hidden_states.dtype) + norm_hidden_states + hidden_states = ( + self.pos_embed[:, :seq_length, :].to(norm_hidden_states.dtype).to(norm_hidden_states.device) + + norm_hidden_states + ) attn_output = self.attn2( hidden_states, encoder_hidden_states=encoder_hidden_states, From ee79cf37e9f140882058fd42f22f7de75b86d4e8 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 25 Oct 2023 15:10:53 +0000 Subject: [PATCH 23/55] update --- src/diffusers/models/unet_motion_blocks.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/unet_motion_blocks.py b/src/diffusers/models/unet_motion_blocks.py index b5bb2cf76aae..266d7f5e2a5f 100644 --- a/src/diffusers/models/unet_motion_blocks.py +++ b/src/diffusers/models/unet_motion_blocks.py @@ -338,7 +338,7 @@ def __init__( ) -> None: super().__init__() - self.temporal_transformer = TransformerTemporalMotionModel( + self.temporal_transformer = TransformerTemporalModel( in_channels=in_channels, norm_num_groups=norm_num_groups, cross_attention_dim=cross_attention_dim, @@ -348,6 +348,11 @@ def __init__( num_attention_heads=num_attention_heads, ) self.use_cross_attention = cross_attention_dim is not None + processor_cls = MotionAttnProcessor2_0 if is_torch_version(">=", "2.0.0") else MotionAttnProcessor + + for block in self.temporal_transformer.transformer_blocks: + block.attn1.set_processor(processor_cls(in_channels=in_channels, max_seq_length=max_seq_length)) + block.attn2.set_processor(processor_cls(in_channels=in_channels, max_seq_length=max_seq_length)) def forward(self, hidden_states, encoder_hidden_states=None, num_frames=1, cross_attention_kwargs=None): encoder_hidden_states = encoder_hidden_states if self.use_cross_attention else None From fe3828a3e707e93655182ba7cbac8dd2b7fb2538 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 25 Oct 2023 18:35:28 +0000 Subject: [PATCH 24/55] update --- src/diffusers/models/attention.py | 15 +++++--------- src/diffusers/models/transformer_temporal.py | 7 ++----- src/diffusers/models/unet_motion_blocks.py | 21 ++++++++++++-------- 3 files changed, 20 insertions(+), 23 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index bed8d66dd7a9..bdb20056b489 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -327,13 +327,13 @@ def __init__( max_seq_length=32, ): super().__init__() - - self.pos_embed = get_timestep_embedding(torch.arange(max_seq_length), dim)[None, :] self.only_cross_attention = only_cross_attention self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + self.pos_embed = PositionalEmbedding(dim, max_seq_length=max_seq_length) + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: raise ValueError( f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" @@ -426,10 +426,8 @@ def forward( cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} - hidden_states = ( - self.pos_embed[:, :seq_length, :].to(norm_hidden_states.dtype).to(norm_hidden_states.device) - + norm_hidden_states - ) + hidden_states = self.pos_embed(norm_hidden_states) + # hidden_states = self.pos_embed(norm_hidden_states) attn_output = self.attn1( hidden_states, encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, @@ -445,10 +443,7 @@ def forward( norm_hidden_states = ( self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) ) - hidden_states = ( - self.pos_embed[:, :seq_length, :].to(norm_hidden_states.dtype).to(norm_hidden_states.device) - + norm_hidden_states - ) + hidden_states = self.pos_embed(norm_hidden_states) attn_output = self.attn2( hidden_states, encoder_hidden_states=encoder_hidden_states, diff --git a/src/diffusers/models/transformer_temporal.py b/src/diffusers/models/transformer_temporal.py index c739f9250e80..e3d68ce2a8a2 100644 --- a/src/diffusers/models/transformer_temporal.py +++ b/src/diffusers/models/transformer_temporal.py @@ -247,7 +247,6 @@ def __init__( attention_bias=attention_bias, double_self_attention=double_self_attention, norm_elementwise_affine=norm_elementwise_affine, - max_seq_length=max_seq_length, ) for d in range(num_layers) ] @@ -300,11 +299,9 @@ def forward( residual = hidden_states - hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width) - hidden_states = hidden_states.permute(0, 2, 1, 3, 4) - hidden_states = self.norm(hidden_states) - hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel) + hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width) + hidden_states = hidden_states.permute(0, 3, 4, 1, 2).reshape(batch_size * height * width, num_frames, channel) hidden_states = self.proj_in(hidden_states) diff --git a/src/diffusers/models/unet_motion_blocks.py b/src/diffusers/models/unet_motion_blocks.py index 266d7f5e2a5f..8d3d24b5362f 100644 --- a/src/diffusers/models/unet_motion_blocks.py +++ b/src/diffusers/models/unet_motion_blocks.py @@ -9,7 +9,7 @@ from ..utils.torch_utils import apply_freeu from .attention_processor import Attention from .dual_transformer_2d import DualTransformer2DModel -from .embeddings import PositionalEmbedding +from .embeddings import PositionalEmbedding, get_timestep_embedding from .modeling_utils import ModelMixin from .resnet import Downsample2D, ResnetBlock2D, Upsample2D from .transformer_2d import Transformer2DModel @@ -257,7 +257,17 @@ def __call__( temb=None, scale=1.0, ): - hidden_states = self.pos_embed(hidden_states) + """ + pos_embed = ( + get_timestep_embedding(torch.arange(32), hidden_states.shape[-1]).unsqueeze(0).to(hidden_states.dtype) + ) + """ + pos_embed = PositionalEmbedding(embed_dim=hidden_states.shape[-1], max_seq_length=32) + pos_embed.to("cuda") + + # hidden_states = hidden_states + pos_embed[:, : hidden_states.shape[1]].to(hidden_states.device) + # hidden_states = self.pos_embed(hidden_states) + hidden_states = pos_embed(hidden_states) residual = hidden_states if attn.spatial_norm is not None: @@ -338,7 +348,7 @@ def __init__( ) -> None: super().__init__() - self.temporal_transformer = TransformerTemporalModel( + self.temporal_transformer = TransformerTemporalMotionModel( in_channels=in_channels, norm_num_groups=norm_num_groups, cross_attention_dim=cross_attention_dim, @@ -348,11 +358,6 @@ def __init__( num_attention_heads=num_attention_heads, ) self.use_cross_attention = cross_attention_dim is not None - processor_cls = MotionAttnProcessor2_0 if is_torch_version(">=", "2.0.0") else MotionAttnProcessor - - for block in self.temporal_transformer.transformer_blocks: - block.attn1.set_processor(processor_cls(in_channels=in_channels, max_seq_length=max_seq_length)) - block.attn2.set_processor(processor_cls(in_channels=in_channels, max_seq_length=max_seq_length)) def forward(self, hidden_states, encoder_hidden_states=None, num_frames=1, cross_attention_kwargs=None): encoder_hidden_states = encoder_hidden_states if self.use_cross_attention else None From bcbc2d1507c39950d7d5e072266cf6b4ed2aede1 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Thu, 26 Oct 2023 07:58:46 +0000 Subject: [PATCH 25/55] update --- src/diffusers/models/attention.py | 13 +- src/diffusers/models/transformer_temporal.py | 2 + src/diffusers/models/unet_motion_blocks.py | 144 ++++++++++--------- 3 files changed, 83 insertions(+), 76 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index bdb20056b489..4f8276cf06b0 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -278,7 +278,7 @@ def forward( @maybe_allow_in_graph class TemporalPositionEmbedTransformerBlock(nn.Module): r""" - A Transformer block that applies a timestep positional embedding before each attention layer. + A Transformer block that applies a positional embedding before each attention layer. Parameters: dim (`int`): The number of channels in the input and output. @@ -408,8 +408,6 @@ def forward( cross_attention_kwargs: Dict[str, Any] = None, class_labels: Optional[torch.LongTensor] = None, ) -> torch.FloatTensor: - _, seq_length, channels = hidden_states.shape - # Notice that normalization is always applied before the real computation in the following blocks. # 0. Self-Attention if self.use_ada_layer_norm: @@ -426,10 +424,9 @@ def forward( cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} - hidden_states = self.pos_embed(norm_hidden_states) - # hidden_states = self.pos_embed(norm_hidden_states) + norm_hidden_states = self.pos_embed(norm_hidden_states) attn_output = self.attn1( - hidden_states, + norm_hidden_states, encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, attention_mask=attention_mask, **cross_attention_kwargs, @@ -443,9 +440,9 @@ def forward( norm_hidden_states = ( self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) ) - hidden_states = self.pos_embed(norm_hidden_states) + norm_hidden_states = self.pos_embed(norm_hidden_states) attn_output = self.attn2( - hidden_states, + norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=encoder_attention_mask, **cross_attention_kwargs, diff --git a/src/diffusers/models/transformer_temporal.py b/src/diffusers/models/transformer_temporal.py index e3d68ce2a8a2..ba114596ab57 100644 --- a/src/diffusers/models/transformer_temporal.py +++ b/src/diffusers/models/transformer_temporal.py @@ -253,6 +253,7 @@ def __init__( ) self.proj_out = nn.Linear(inner_dim, in_channels) + self.use_cross_attention = cross_attention_dim is not None def forward( self, @@ -304,6 +305,7 @@ def forward( hidden_states = hidden_states.permute(0, 3, 4, 1, 2).reshape(batch_size * height * width, num_frames, channel) hidden_states = self.proj_in(hidden_states) + encoder_hidden_states = encoder_hidden_states if self.use_cross_attention else None # 2. Blocks for block in self.transformer_blocks: diff --git a/src/diffusers/models/unet_motion_blocks.py b/src/diffusers/models/unet_motion_blocks.py index 8d3d24b5362f..1f2bb8dfe869 100644 --- a/src/diffusers/models/unet_motion_blocks.py +++ b/src/diffusers/models/unet_motion_blocks.py @@ -257,18 +257,7 @@ def __call__( temb=None, scale=1.0, ): - """ - pos_embed = ( - get_timestep_embedding(torch.arange(32), hidden_states.shape[-1]).unsqueeze(0).to(hidden_states.dtype) - ) - """ - pos_embed = PositionalEmbedding(embed_dim=hidden_states.shape[-1], max_seq_length=32) - pos_embed.to("cuda") - - # hidden_states = hidden_states + pos_embed[:, : hidden_states.shape[1]].to(hidden_states.device) - # hidden_states = self.pos_embed(hidden_states) - hidden_states = pos_embed(hidden_states) - + hidden_states = self.pos_embed(hidden_states) residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) @@ -388,7 +377,7 @@ def __init__( for i in range(layers_per_block): self.motion_modules.append( - MotionBlock( + TransformerTemporalMotionModel( in_channels=in_channels, norm_num_groups=norm_num_groups, cross_attention_dim=cross_attention_dim, @@ -508,6 +497,7 @@ def __init__( ): super().__init__() resnets = [] + motion_modules = [] for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels @@ -525,18 +515,20 @@ def __init__( pre_norm=resnet_pre_norm, ) ) + motion_modules.append( + TransformerTemporalMotionModel( + num_attention_heads=motion_num_attention_heads, + in_channels=out_channels, + norm_num_groups=motion_norm_num_groups, + cross_attention_dim=motion_cross_attention_dim, + attention_bias=motion_attention_bias, + activation_fn=motion_activation_fn, + max_seq_length=motion_max_seq_length, + ) + ) self.resnets = nn.ModuleList(resnets) - self.motion_modules = MotionModules( - in_channels=out_channels, - norm_num_groups=motion_norm_num_groups, - cross_attention_dim=motion_cross_attention_dim, - activation_fn=motion_activation_fn, - attention_bias=motion_attention_bias, - num_attention_heads=motion_num_attention_heads, - max_seq_length=motion_max_seq_length, - layers_per_block=num_layers, - ).motion_modules + self.motion_modules = nn.ModuleList(motion_modules) if add_downsample: self.downsamplers = nn.ModuleList( @@ -578,7 +570,7 @@ def custom_forward(*inputs): else: hidden_states = resnet(hidden_states, temb, scale=scale) - hidden_states = motion_module(hidden_states, num_frames=num_frames) + hidden_states = motion_module(hidden_states, num_frames=num_frames)[0] output_states = output_states + (hidden_states,) @@ -625,6 +617,7 @@ def __init__( super().__init__() resnets = [] attentions = [] + motion_modules = [] self.has_cross_attention = True self.num_attention_heads = num_attention_heads @@ -672,18 +665,22 @@ def __init__( norm_num_groups=resnet_groups, ) ) + + motion_modules.append( + TransformerTemporalMotionModel( + num_attention_heads=motion_num_attention_heads, + in_channels=out_channels, + norm_num_groups=motion_norm_num_groups, + cross_attention_dim=motion_cross_attention_dim, + attention_bias=motion_attention_bias, + activation_fn=motion_activation_fn, + max_seq_length=motion_max_seq_length, + ) + ) + self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) - self.motion_modules = MotionModules( - in_channels=out_channels, - norm_num_groups=motion_norm_num_groups, - cross_attention_dim=motion_cross_attention_dim, - activation_fn=motion_activation_fn, - attention_bias=motion_attention_bias, - num_attention_heads=motion_num_attention_heads, - max_seq_length=motion_max_seq_length, - layers_per_block=num_layers, - ).motion_modules + self.motion_modules = nn.ModuleList(motion_modules) if add_downsample: self.downsamplers = nn.ModuleList( @@ -756,7 +753,7 @@ def custom_forward(*inputs): encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, num_frames=num_frames, - ) + )[0] # apply additional residuals to the output of the last pair of resnet and attention blocks if i == len(blocks) - 1 and additional_residuals is not None: @@ -808,6 +805,7 @@ def __init__( super().__init__() resnets = [] attentions = [] + motion_modules = [] self.has_cross_attention = True self.num_attention_heads = num_attention_heads @@ -857,18 +855,21 @@ def __init__( norm_num_groups=resnet_groups, ) ) + motion_modules.append( + TransformerTemporalMotionModel( + num_attention_heads=motion_num_attention_heads, + in_channels=out_channels, + norm_num_groups=motion_norm_num_groups, + cross_attention_dim=motion_cross_attention_dim, + attention_bias=motion_attention_bias, + activation_fn=motion_activation_fn, + max_seq_length=motion_max_seq_length, + ) + ) + self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) - self.motion_modules = MotionModules( - in_channels=out_channels, - norm_num_groups=motion_norm_num_groups, - cross_attention_dim=motion_cross_attention_dim, - activation_fn=motion_activation_fn, - attention_bias=motion_attention_bias, - num_attention_heads=motion_num_attention_heads, - max_seq_length=motion_max_seq_length, - layers_per_block=num_layers, - ).motion_modules + self.motion_modules = nn.ModuleList(motion_modules) if add_upsample: self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) @@ -959,7 +960,7 @@ def custom_forward(*inputs): encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, num_frames=num_frames, - ) + )[0] if self.upsamplers is not None: for upsampler in self.upsamplers: @@ -994,6 +995,7 @@ def __init__( ): super().__init__() resnets = [] + motion_modules = [] for i in range(num_layers): res_skip_channels = in_channels if (i == num_layers - 1) else out_channels @@ -1014,17 +1016,20 @@ def __init__( ) ) + motion_modules.append( + TransformerTemporalMotionModel( + num_attention_heads=motion_num_attention_heads, + in_channels=out_channels, + norm_num_groups=motion_norm_num_groups, + cross_attention_dim=motion_cross_attention_dim, + attention_bias=motion_attention_bias, + activation_fn=motion_activation_fn, + max_seq_length=motion_max_seq_length, + ) + ) + self.resnets = nn.ModuleList(resnets) - self.motion_modules = MotionModules( - in_channels=out_channels, - norm_num_groups=motion_norm_num_groups, - cross_attention_dim=motion_cross_attention_dim, - activation_fn=motion_activation_fn, - attention_bias=motion_attention_bias, - num_attention_heads=motion_num_attention_heads, - max_seq_length=motion_max_seq_length, - layers_per_block=num_layers, - ).motion_modules + self.motion_modules = nn.ModuleList(motion_modules) if add_upsample: self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) @@ -1089,7 +1094,7 @@ def custom_forward(*inputs): else: hidden_states = resnet(hidden_states, temb, scale=scale) - hidden_states = motion_module(hidden_states, num_frames=num_frames) + hidden_states = motion_module(hidden_states, num_frames=num_frames)[0] if self.upsamplers is not None: for upsampler in self.upsamplers: @@ -1147,6 +1152,7 @@ def __init__( ) ] attentions = [] + motion_modules = [] for _ in range(num_layers): if not dual_cross_attention: @@ -1188,19 +1194,21 @@ def __init__( pre_norm=resnet_pre_norm, ) ) + motion_modules.append( + TransformerTemporalMotionModel( + num_attention_heads=motion_num_attention_heads, + in_channels=in_channels, + norm_num_groups=motion_norm_num_groups, + cross_attention_dim=motion_cross_attention_dim, + attention_bias=motion_attention_bias, + activation_fn=motion_activation_fn, + max_seq_length=motion_max_seq_length, + ) + ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) - self.motion_modules = MotionModules( - in_channels=in_channels, - norm_num_groups=motion_norm_num_groups, - cross_attention_dim=motion_cross_attention_dim, - activation_fn=motion_activation_fn, - attention_bias=motion_attention_bias, - num_attention_heads=motion_num_attention_heads, - max_seq_length=motion_max_seq_length, - layers_per_block=num_layers, - ).motion_modules + self.motion_modules = nn.ModuleList(motion_modules) self.gradient_checkpointing = False @@ -1265,7 +1273,7 @@ def custom_forward(*inputs): encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, num_frames=num_frames, - ) + )[0] hidden_states = resnet(hidden_states, temb, scale=lora_scale) return hidden_states From 4df582eef507fde48ea8a8f5a139432025e7bdbf Mon Sep 17 00:00:00 2001 From: DN6 Date: Thu, 26 Oct 2023 13:53:28 +0530 Subject: [PATCH 26/55] update --- src/diffusers/models/unet_motion_blocks.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/unet_motion_blocks.py b/src/diffusers/models/unet_motion_blocks.py index 1f2bb8dfe869..01fd0f06e184 100644 --- a/src/diffusers/models/unet_motion_blocks.py +++ b/src/diffusers/models/unet_motion_blocks.py @@ -384,6 +384,7 @@ def __init__( activation_fn=activation_fn, attention_bias=attention_bias, num_attention_heads=num_attention_heads, + attention_head_dim=in_channels // num_attention_heads, max_seq_length=max_seq_length, ) ) From bf5b65a024600b7af31e9f58012b49089585d9aa Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Thu, 26 Oct 2023 08:49:17 +0000 Subject: [PATCH 27/55] update --- src/diffusers/models/unet_motion_blocks.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/diffusers/models/unet_motion_blocks.py b/src/diffusers/models/unet_motion_blocks.py index 01fd0f06e184..7da680f85d98 100644 --- a/src/diffusers/models/unet_motion_blocks.py +++ b/src/diffusers/models/unet_motion_blocks.py @@ -525,6 +525,7 @@ def __init__( attention_bias=motion_attention_bias, activation_fn=motion_activation_fn, max_seq_length=motion_max_seq_length, + attention_head_dim=out_channels // motion_num_attention_heads, ) ) @@ -676,6 +677,7 @@ def __init__( attention_bias=motion_attention_bias, activation_fn=motion_activation_fn, max_seq_length=motion_max_seq_length, + attention_head_dim=out_channels // motion_num_attention_heads, ) ) @@ -865,6 +867,7 @@ def __init__( attention_bias=motion_attention_bias, activation_fn=motion_activation_fn, max_seq_length=motion_max_seq_length, + attention_head_dim=out_channels // motion_num_attention_heads, ) ) @@ -1026,6 +1029,7 @@ def __init__( attention_bias=motion_attention_bias, activation_fn=motion_activation_fn, max_seq_length=motion_max_seq_length, + attention_head_dim=out_channels // motion_num_attention_heads, ) ) @@ -1198,6 +1202,7 @@ def __init__( motion_modules.append( TransformerTemporalMotionModel( num_attention_heads=motion_num_attention_heads, + attention_head_dim=in_channels // motion_num_attention_heads, in_channels=in_channels, norm_num_groups=motion_norm_num_groups, cross_attention_dim=motion_cross_attention_dim, From 3ba1ba0e18b8e70fe663fbecef8a45e5837c21fc Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Thu, 26 Oct 2023 10:22:57 +0000 Subject: [PATCH 28/55] clean up --- src/diffusers/models/embeddings.py | 8 +- src/diffusers/models/unet_3d_blocks.py | 850 ++++++++++++- src/diffusers/models/unet_motion_blocks.py | 1285 -------------------- src/diffusers/models/unet_motion_model.py | 162 ++- tests/models/test_models_unet_motion.py | 3 +- 5 files changed, 969 insertions(+), 1339 deletions(-) delete mode 100644 src/diffusers/models/unet_motion_blocks.py diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 1b6e43733ce0..ad7b316f9a6f 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -263,12 +263,8 @@ class PositionalEmbedding(nn.Module): def __init__(self, embed_dim: int, max_seq_length: int = 32): super().__init__() - position = torch.arange(max_seq_length).unsqueeze(1) - div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim)) - - pe = torch.zeros(1, max_seq_length, embed_dim) - pe[0, :, 0::2] = torch.sin(position * div_term) - pe[0, :, 1::2] = torch.cos(position * div_term) + pe = get_1d_sincos_pos_embed_from_grid(embed_dim, torch.arange(32)) + pe = torch.tensor(pe, dtype=torch.float32).unsqueeze(0) self.register_buffer("pe", pe) def forward(self, x): diff --git a/src/diffusers/models/unet_3d_blocks.py b/src/diffusers/models/unet_3d_blocks.py index 2d050cd5ca66..31806619b2b0 100644 --- a/src/diffusers/models/unet_3d_blocks.py +++ b/src/diffusers/models/unet_3d_blocks.py @@ -14,11 +14,14 @@ import torch from torch import nn +from typing import Any, Dict, Optional, Tuple from ..utils.torch_utils import apply_freeu +from ..utils import is_torch_version from .resnet import Downsample2D, ResnetBlock2D, TemporalConvLayer, Upsample2D from .transformer_2d import Transformer2DModel -from .transformer_temporal import TransformerTemporalModel +from .dual_transformer_2d import DualTransformer2DModel +from .transformer_temporal import TransformerTemporalModel, TransformerTemporalMotionModel def get_down_block( @@ -39,6 +42,12 @@ def get_down_block( only_cross_attention=False, upcast_attention=False, resnet_time_scale_shift="default", + temporal_norm_num_groups=32, + temporal_cross_attention_dim=None, + temporal_num_attention_heads=8, + temporal_attention_bias=False, + temporal_activation_fn="geglu", + temporal_max_seq_length=32, ): if down_block_type == "DownBlock3D": return DownBlock3D( @@ -74,6 +83,53 @@ def get_down_block( upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, ) + if down_block_type == "DownBlockMotion": + return DownBlockMotion( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + temporal_norm_num_groups=temporal_norm_num_groups, + temporal_cross_attention_dim=temporal_cross_attention_dim, + temporal_num_attention_heads=temporal_num_attention_heads, + temporal_attention_bias=temporal_attention_bias, + temporal_activation_fn=temporal_activation_fn, + temporal_max_seq_length=temporal_max_seq_length, + ) + elif down_block_type == "CrossAttnDownBlockMotion": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockMotion") + return CrossAttnDownBlockMotion( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + temporal_norm_num_groups=temporal_norm_num_groups, + temporal_cross_attention_dim=temporal_cross_attention_dim, + temporal_num_attention_heads=temporal_num_attention_heads, + temporal_attention_bias=temporal_attention_bias, + temporal_activation_fn=temporal_activation_fn, + temporal_max_seq_length=temporal_max_seq_length, + ) + raise ValueError(f"{down_block_type} does not exist.") @@ -96,6 +152,12 @@ def get_up_block( only_cross_attention=False, upcast_attention=False, resnet_time_scale_shift="default", + temporal_norm_num_groups=32, + temporal_cross_attention_dim=None, + temporal_num_attention_heads=8, + temporal_attention_bias=False, + temporal_activation_fn="geglu", + temporal_max_seq_length=32, ): if up_block_type == "UpBlock3D": return UpBlock3D( @@ -133,6 +195,54 @@ def get_up_block( resnet_time_scale_shift=resnet_time_scale_shift, resolution_idx=resolution_idx, ) + if up_block_type == "UpBlockMotion": + return UpBlockMotion( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + resolution_idx=resolution_idx, + temporal_norm_num_groups=temporal_norm_num_groups, + temporal_cross_attention_dim=temporal_cross_attention_dim, + temporal_num_attention_heads=temporal_num_attention_heads, + temporal_attention_bias=temporal_attention_bias, + temporal_activation_fn=temporal_activation_fn, + temporal_max_seq_length=temporal_max_seq_length, + ) + elif up_block_type == "CrossAttnUpBlockMotion": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockMotion") + return CrossAttnUpBlockMotion( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + resolution_idx=resolution_idx, + temporal_norm_num_groups=temporal_norm_num_groups, + temporal_cross_attention_dim=temporal_cross_attention_dim, + temporal_num_attention_heads=temporal_num_attention_heads, + temporal_attention_bias=temporal_attention_bias, + temporal_activation_fn=temporal_activation_fn, + temporal_max_seq_length=temporal_max_seq_length, + ) raise ValueError(f"{up_block_type} does not exist.") @@ -742,15 +852,16 @@ def __init__( output_scale_factor=1.0, add_downsample=True, downsample_padding=1, - motion_norm_num_groups=32, - motion_cross_attention_dim=None, - motion_num_attention_heads=8, - motion_attention_bias=False, - motion_activation_fn="geglu", - motion_max_seq_length=32, + temporal_norm_num_groups=32, + temporal_cross_attention_dim=None, + temporal_num_attention_heads=8, + temporal_attention_bias=False, + temporal_activation_fn="geglu", + temporal_max_seq_length=32, ): super().__init__() resnets = [] + motion_modules = [] for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels @@ -768,18 +879,21 @@ def __init__( pre_norm=resnet_pre_norm, ) ) + motion_modules.append( + TransformerTemporalMotionModel( + num_attention_heads=temporal_num_attention_heads, + in_channels=out_channels, + norm_num_groups=temporal_norm_num_groups, + cross_attention_dim=temporal_cross_attention_dim, + attention_bias=temporal_attention_bias, + activation_fn=temporal_activation_fn, + max_seq_length=temporal_max_seq_length, + attention_head_dim=out_channels // temporal_num_attention_heads, + ) + ) self.resnets = nn.ModuleList(resnets) - self.motion_modules = MotionModules( - in_channels=out_channels, - norm_num_groups=motion_norm_num_groups, - cross_attention_dim=motion_cross_attention_dim, - activation_fn=motion_activation_fn, - attention_bias=motion_attention_bias, - num_attention_heads=motion_num_attention_heads, - max_seq_length=motion_max_seq_length, - layers_per_block=num_layers, - ).motion_modules + self.motion_modules = nn.ModuleList(motion_modules) if add_downsample: self.downsamplers = nn.ModuleList( @@ -821,7 +935,7 @@ def custom_forward(*inputs): else: hidden_states = resnet(hidden_states, temb, scale=scale) - hidden_states = motion_module(hidden_states, num_frames=num_frames) + hidden_states = motion_module(hidden_states, num_frames=num_frames)[0] output_states = output_states + (hidden_states,) @@ -832,3 +946,703 @@ def custom_forward(*inputs): output_states = output_states + (hidden_states,) return hidden_states, output_states + + +class CrossAttnDownBlockMotion(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + attention_type="default", + temporal_norm_num_groups=32, + temporal_cross_attention_dim=None, + temporal_num_attention_heads=8, + temporal_attention_bias=False, + temporal_activation_fn="geglu", + temporal_max_seq_length=32, + ): + super().__init__() + resnets = [] + attentions = [] + motion_modules = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + + motion_modules.append( + TransformerTemporalMotionModel( + num_attention_heads=temporal_num_attention_heads, + in_channels=out_channels, + norm_num_groups=temporal_norm_num_groups, + cross_attention_dim=temporal_cross_attention_dim, + attention_bias=temporal_attention_bias, + activation_fn=temporal_activation_fn, + max_seq_length=temporal_max_seq_length, + attention_head_dim=out_channels // temporal_num_attention_heads, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + temb=None, + encoder_hidden_states=None, + attention_mask=None, + num_frames=1, + encoder_attention_mask=None, + cross_attention_kwargs=None, + additional_residuals=None, + ): + output_states = () + + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + + blocks = list(zip(self.resnets, self.attentions, self.motion_modules)) + for i, (resnet, attn, motion_module) in enumerate(blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + else: + hidden_states = resnet(hidden_states, temb, scale=lora_scale) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = motion_module( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + num_frames=num_frames, + )[0] + + # apply additional residuals to the output of the last pair of resnet and attention blocks + if i == len(blocks) - 1 and additional_residuals is not None: + hidden_states = hidden_states + additional_residuals + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states, scale=lora_scale) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnUpBlockMotion(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + resolution_idx: int = None, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + attention_type="default", + temporal_norm_num_groups=32, + temporal_cross_attention_dim=None, + temporal_num_attention_heads=8, + temporal_attention_bias=False, + temporal_activation_fn="geglu", + temporal_max_seq_length=32, + ): + super().__init__() + resnets = [] + attentions = [] + motion_modules = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + motion_modules.append( + TransformerTemporalMotionModel( + num_attention_heads=temporal_num_attention_heads, + in_channels=out_channels, + norm_num_groups=temporal_norm_num_groups, + cross_attention_dim=temporal_cross_attention_dim, + attention_bias=temporal_attention_bias, + activation_fn=temporal_activation_fn, + max_seq_length=temporal_max_seq_length, + attention_head_dim=out_channels // temporal_num_attention_heads, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + num_frames=1, + ): + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + + blocks = zip(self.resnets, self.attentions, self.motion_modules) + for resnet, attn, motion_module in blocks: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + hidden_states, res_hidden_states = apply_freeu( + self.resolution_idx, + hidden_states, + res_hidden_states, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + else: + hidden_states = resnet(hidden_states, temb, scale=lora_scale) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = motion_module( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + num_frames=num_frames, + )[0] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale) + + return hidden_states + + +class UpBlockMotion(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + resolution_idx: int = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + temporal_norm_num_groups=32, + temporal_cross_attention_dim=None, + temporal_num_attention_heads=8, + temporal_attention_bias=False, + temporal_activation_fn="geglu", + temporal_max_seq_length=32, + ): + super().__init__() + resnets = [] + motion_modules = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + motion_modules.append( + TransformerTemporalMotionModel( + num_attention_heads=temporal_num_attention_heads, + in_channels=out_channels, + norm_num_groups=temporal_norm_num_groups, + cross_attention_dim=temporal_cross_attention_dim, + attention_bias=temporal_attention_bias, + activation_fn=temporal_activation_fn, + max_seq_length=temporal_max_seq_length, + attention_head_dim=out_channels // temporal_num_attention_heads, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0, num_frames=1 + ): + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + + blocks = zip(self.resnets, self.motion_modules) + + for resnet, motion_module in blocks: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + hidden_states, res_hidden_states = apply_freeu( + self.resolution_idx, + hidden_states, + res_hidden_states, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + ) + + else: + hidden_states = resnet(hidden_states, temb, scale=scale) + hidden_states = motion_module(hidden_states, num_frames=num_frames)[0] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size, scale=scale) + + return hidden_states + + +class UNetMidBlockCrossAttnMotion(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + dual_cross_attention=False, + use_linear_projection=False, + upcast_attention=False, + attention_type="default", + temporal_norm_num_groups=32, + temporal_cross_attention_dim=None, + temporal_num_attention_heads=8, + temporal_attention_bias=False, + temporal_activation_fn="geglu", + temporal_max_seq_length=32, + ): + super().__init__() + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + motion_modules = [] + + for _ in range(num_layers): + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + in_channels // num_attention_heads, + in_channels=in_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + in_channels // num_attention_heads, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + motion_modules.append( + TransformerTemporalMotionModel( + num_attention_heads=temporal_num_attention_heads, + attention_head_dim=in_channels // temporal_num_attention_heads, + in_channels=in_channels, + norm_num_groups=temporal_norm_num_groups, + cross_attention_dim=temporal_cross_attention_dim, + attention_bias=temporal_attention_bias, + activation_fn=temporal_activation_fn, + max_seq_length=temporal_max_seq_length, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + num_frames=1, + ) -> torch.FloatTensor: + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale) + + blocks = zip(self.attentions, self.resnets[1:], self.motion_modules) + for attn, resnet, motion_module in blocks: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(motion_module), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + else: + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = motion_module( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + num_frames=num_frames, + )[0] + hidden_states = resnet(hidden_states, temb, scale=lora_scale) + + return hidden_states diff --git a/src/diffusers/models/unet_motion_blocks.py b/src/diffusers/models/unet_motion_blocks.py deleted file mode 100644 index 7da680f85d98..000000000000 --- a/src/diffusers/models/unet_motion_blocks.py +++ /dev/null @@ -1,1285 +0,0 @@ -from typing import Any, Dict, Optional, Tuple - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import is_torch_version -from ..utils.torch_utils import apply_freeu -from .attention_processor import Attention -from .dual_transformer_2d import DualTransformer2DModel -from .embeddings import PositionalEmbedding, get_timestep_embedding -from .modeling_utils import ModelMixin -from .resnet import Downsample2D, ResnetBlock2D, Upsample2D -from .transformer_2d import Transformer2DModel -from .transformer_temporal import TransformerTemporalModel, TransformerTemporalMotionModel - - -def get_down_block( - down_block_type, - num_layers, - in_channels, - out_channels, - temb_channels, - add_downsample, - resnet_eps, - resnet_act_fn, - num_attention_heads, - resnet_groups=None, - cross_attention_dim=None, - downsample_padding=None, - dual_cross_attention=False, - use_linear_projection=True, - only_cross_attention=False, - upcast_attention=False, - resnet_time_scale_shift="default", - motion_norm_num_groups=32, - motion_cross_attention_dim=None, - motion_num_attention_heads=8, - motion_attention_bias=False, - motion_activation_fn="geglu", - motion_max_seq_length=32, -): - if down_block_type == "DownBlockMotion": - return DownBlockMotion( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - downsample_padding=downsample_padding, - resnet_time_scale_shift=resnet_time_scale_shift, - motion_norm_num_groups=motion_norm_num_groups, - motion_cross_attention_dim=motion_cross_attention_dim, - motion_num_attention_heads=motion_num_attention_heads, - motion_attention_bias=motion_attention_bias, - motion_activation_fn=motion_activation_fn, - motion_max_seq_length=motion_max_seq_length, - ) - elif down_block_type == "CrossAttnDownBlockMotion": - if cross_attention_dim is None: - raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockMotion") - return CrossAttnDownBlockMotion( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - downsample_padding=downsample_padding, - cross_attention_dim=cross_attention_dim, - num_attention_heads=num_attention_heads, - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - motion_norm_num_groups=motion_norm_num_groups, - motion_cross_attention_dim=motion_cross_attention_dim, - motion_num_attention_heads=motion_num_attention_heads, - motion_attention_bias=motion_attention_bias, - motion_activation_fn=motion_activation_fn, - motion_max_seq_length=motion_max_seq_length, - ) - raise ValueError(f"{down_block_type} does not exist.") - - -def get_up_block( - up_block_type, - num_layers, - in_channels, - out_channels, - prev_output_channel, - temb_channels, - add_upsample, - resnet_eps, - resnet_act_fn, - num_attention_heads, - resolution_idx=None, - resnet_groups=None, - cross_attention_dim=None, - dual_cross_attention=False, - use_linear_projection=True, - only_cross_attention=False, - upcast_attention=False, - resnet_time_scale_shift="default", - motion_norm_num_groups=32, - motion_cross_attention_dim=None, - motion_num_attention_heads=8, - motion_attention_bias=False, - motion_activation_fn="geglu", - motion_max_seq_length=32, -): - if up_block_type == "UpBlockMotion": - return UpBlockMotion( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - resnet_time_scale_shift=resnet_time_scale_shift, - resolution_idx=resolution_idx, - motion_norm_num_groups=motion_norm_num_groups, - motion_cross_attention_dim=motion_cross_attention_dim, - motion_num_attention_heads=motion_num_attention_heads, - motion_attention_bias=motion_attention_bias, - motion_activation_fn=motion_activation_fn, - motion_max_seq_length=motion_max_seq_length, - ) - elif up_block_type == "CrossAttnUpBlockMotion": - if cross_attention_dim is None: - raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockMotion") - return CrossAttnUpBlockMotion( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - cross_attention_dim=cross_attention_dim, - num_attention_heads=num_attention_heads, - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - resolution_idx=resolution_idx, - motion_norm_num_groups=motion_norm_num_groups, - motion_cross_attention_dim=motion_cross_attention_dim, - motion_num_attention_heads=motion_num_attention_heads, - motion_attention_bias=motion_attention_bias, - motion_activation_fn=motion_activation_fn, - motion_max_seq_length=motion_max_seq_length, - ) - raise ValueError(f"{up_block_type} does not exist.") - - -class MotionAttnProcessor(nn.Module): - r""" - Default processor for performing attention-related computations. - """ - - def __init__(self, in_channels, max_seq_length=32): - super().__init__() - self.pos_embed = PositionalEmbedding(embed_dim=in_channels, max_seq_length=max_seq_length) - - def __call__( - self, - attn: Attention, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - temb=None, - scale=1.0, - ): - residual = hidden_states - hidden_states = self.pos_embed(hidden_states) - - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states, scale=scale) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - key = attn.to_k(encoder_hidden_states, scale=scale) - value = attn.to_v(encoder_hidden_states, scale=scale) - - query = attn.head_to_batch_dim(query) - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) - - attention_probs = attn.get_attention_scores(query, key, attention_mask) - hidden_states = torch.bmm(attention_probs, value) - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states, scale=scale) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return hidden_states - - -class MotionAttnProcessor2_0(nn.Module): - r""" - Attention Processor for performing attention-related computations in the Motion Modules. - """ - - def __init__(self, in_channels, max_seq_length=32): - super().__init__() - self.pos_embed = PositionalEmbedding(embed_dim=in_channels, max_seq_length=max_seq_length) - - def __call__( - self, - attn: Attention, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - temb=None, - scale=1.0, - ): - hidden_states = self.pos_embed(hidden_states) - residual = hidden_states - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - - if attention_mask is not None: - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - # scaled_dot_product_attention expects attention_mask shape to be - # (batch, heads, source_length, target_length) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states, scale=scale) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - key = attn.to_k(encoder_hidden_states, scale=scale) - value = attn.to_v(encoder_hidden_states, scale=scale) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - # linear proj - hidden_states = attn.to_out[0](hidden_states, scale=scale) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return hidden_states - - -class MotionBlock(nn.Module): - def __init__( - self, - in_channels, - norm_num_groups=32, - cross_attention_dim=None, - activation_fn="geglu", - attention_bias=False, - num_attention_heads=8, - max_seq_length=24, - ) -> None: - super().__init__() - - self.temporal_transformer = TransformerTemporalMotionModel( - in_channels=in_channels, - norm_num_groups=norm_num_groups, - cross_attention_dim=cross_attention_dim, - attention_head_dim=in_channels // num_attention_heads, - activation_fn=activation_fn, - attention_bias=attention_bias, - num_attention_heads=num_attention_heads, - ) - self.use_cross_attention = cross_attention_dim is not None - - def forward(self, hidden_states, encoder_hidden_states=None, num_frames=1, cross_attention_kwargs=None): - encoder_hidden_states = encoder_hidden_states if self.use_cross_attention else None - hidden_states = self.temporal_transformer( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - num_frames=num_frames, - ).sample - - return hidden_states - - -class MotionModules(nn.Module): - def __init__( - self, - in_channels, - layers_per_block=2, - num_attention_heads=8, - attention_bias=False, - cross_attention_dim=None, - activation_fn="geglu", - norm_num_groups=32, - max_seq_length=32, - ): - super().__init__() - self.motion_modules = nn.ModuleList([]) - - for i in range(layers_per_block): - self.motion_modules.append( - TransformerTemporalMotionModel( - in_channels=in_channels, - norm_num_groups=norm_num_groups, - cross_attention_dim=cross_attention_dim, - activation_fn=activation_fn, - attention_bias=attention_bias, - num_attention_heads=num_attention_heads, - attention_head_dim=in_channels // num_attention_heads, - max_seq_length=max_seq_length, - ) - ) - - -class MotionAdapter(ModelMixin, ConfigMixin): - @register_to_config - def __init__( - self, - block_out_channels=(320, 640, 1280, 1280), - motion_layers_per_block=2, - motion_mid_block_layers_per_block=1, - motion_num_attention_heads=8, - motion_attention_bias=False, - motion_cross_attention_dim=None, - motion_activation_fn="geglu", - motion_norm_num_groups=32, - motion_max_seq_length=32, - ): - """Container to store AnimateDiff Motion Modules - - Args: - block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): - The tuple of output channels for each UNet block. - layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. - num_attention_heads (`int`, *optional*): - The number of heads to use in each attention layer. - attention_bias (bool, optional, defaults to False): Whether to include bias in attention layers. - cross_attention_dim (int, optional, Defaults to None): Set in order to use cross attention. - activation_fn (str, optional, Defaults to "geglu"): Activation Function. - norm_num_groups (int, optional): _description_. Defaults to 32. - max_seq_length (int, optional): _description_. Defaults to 24. - """ - - super().__init__() - down_blocks = [] - up_blocks = [] - - for i, channel in enumerate(block_out_channels): - output_channel = block_out_channels[i] - down_blocks.append( - MotionModules( - in_channels=output_channel, - norm_num_groups=motion_norm_num_groups, - cross_attention_dim=motion_cross_attention_dim, - activation_fn=motion_activation_fn, - attention_bias=motion_attention_bias, - num_attention_heads=motion_num_attention_heads, - max_seq_length=motion_max_seq_length, - layers_per_block=motion_layers_per_block, - ) - ) - - self.mid_block = MotionModules( - in_channels=block_out_channels[-1], - norm_num_groups=motion_norm_num_groups, - cross_attention_dim=motion_cross_attention_dim, - activation_fn=motion_activation_fn, - attention_bias=motion_attention_bias, - num_attention_heads=motion_num_attention_heads, - layers_per_block=motion_mid_block_layers_per_block, - max_seq_length=motion_max_seq_length, - ) - - reversed_block_out_channels = list(reversed(block_out_channels)) - output_channel = reversed_block_out_channels[0] - for i, channel in enumerate(reversed_block_out_channels): - output_channel = reversed_block_out_channels[i] - up_blocks.append( - MotionModules( - in_channels=output_channel, - norm_num_groups=motion_norm_num_groups, - cross_attention_dim=motion_cross_attention_dim, - activation_fn=motion_activation_fn, - attention_bias=motion_attention_bias, - num_attention_heads=motion_num_attention_heads, - max_seq_length=motion_max_seq_length, - layers_per_block=motion_layers_per_block + 1, - ) - ) - - self.down_blocks = nn.ModuleList(down_blocks) - self.up_blocks = nn.ModuleList(up_blocks) - - def forward(self, sample): - pass - - -class DownBlockMotion(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - output_scale_factor=1.0, - add_downsample=True, - downsample_padding=1, - motion_norm_num_groups=32, - motion_cross_attention_dim=None, - motion_num_attention_heads=8, - motion_attention_bias=False, - motion_activation_fn="geglu", - motion_max_seq_length=32, - ): - super().__init__() - resnets = [] - motion_modules = [] - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - motion_modules.append( - TransformerTemporalMotionModel( - num_attention_heads=motion_num_attention_heads, - in_channels=out_channels, - norm_num_groups=motion_norm_num_groups, - cross_attention_dim=motion_cross_attention_dim, - attention_bias=motion_attention_bias, - activation_fn=motion_activation_fn, - max_seq_length=motion_max_seq_length, - attention_head_dim=out_channels // motion_num_attention_heads, - ) - ) - - self.resnets = nn.ModuleList(resnets) - self.motion_modules = nn.ModuleList(motion_modules) - - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - Downsample2D( - out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" - ) - ] - ) - else: - self.downsamplers = None - - self.gradient_checkpointing = False - - def forward(self, hidden_states, temb=None, scale: float = 1.0, num_frames=1): - output_states = () - - blocks = zip(self.resnets, self.motion_modules) - for resnet, motion_module in blocks: - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, use_reentrant=False - ) - else: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, scale - ) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, num_frames - ) - - else: - hidden_states = resnet(hidden_states, temb, scale=scale) - hidden_states = motion_module(hidden_states, num_frames=num_frames)[0] - - output_states = output_states + (hidden_states,) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states, scale=scale) - - output_states = output_states + (hidden_states,) - - return hidden_states, output_states - - -class CrossAttnDownBlockMotion(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - transformer_layers_per_block: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - num_attention_heads=1, - cross_attention_dim=1280, - output_scale_factor=1.0, - downsample_padding=1, - add_downsample=True, - dual_cross_attention=False, - use_linear_projection=False, - only_cross_attention=False, - upcast_attention=False, - attention_type="default", - motion_norm_num_groups=32, - motion_cross_attention_dim=None, - motion_num_attention_heads=8, - motion_attention_bias=False, - motion_activation_fn="geglu", - motion_max_seq_length=32, - ): - super().__init__() - resnets = [] - attentions = [] - motion_modules = [] - - self.has_cross_attention = True - self.num_attention_heads = num_attention_heads - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - if not dual_cross_attention: - attentions.append( - Transformer2DModel( - num_attention_heads, - out_channels // num_attention_heads, - in_channels=out_channels, - num_layers=transformer_layers_per_block, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - attention_type=attention_type, - ) - ) - else: - attentions.append( - DualTransformer2DModel( - num_attention_heads, - out_channels // num_attention_heads, - in_channels=out_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - ) - ) - - motion_modules.append( - TransformerTemporalMotionModel( - num_attention_heads=motion_num_attention_heads, - in_channels=out_channels, - norm_num_groups=motion_norm_num_groups, - cross_attention_dim=motion_cross_attention_dim, - attention_bias=motion_attention_bias, - activation_fn=motion_activation_fn, - max_seq_length=motion_max_seq_length, - attention_head_dim=out_channels // motion_num_attention_heads, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - self.motion_modules = nn.ModuleList(motion_modules) - - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - Downsample2D( - out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" - ) - ] - ) - else: - self.downsamplers = None - - self.gradient_checkpointing = False - - def forward( - self, - hidden_states, - temb=None, - encoder_hidden_states=None, - attention_mask=None, - num_frames=1, - encoder_attention_mask=None, - cross_attention_kwargs=None, - additional_residuals=None, - ): - output_states = () - - lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 - - blocks = list(zip(self.resnets, self.attentions, self.motion_modules)) - for i, (resnet, attn, motion_module) in enumerate(blocks): - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - **ckpt_kwargs, - ) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] - else: - hidden_states = resnet(hidden_states, temb, scale=lora_scale) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] - hidden_states = motion_module( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - num_frames=num_frames, - )[0] - - # apply additional residuals to the output of the last pair of resnet and attention blocks - if i == len(blocks) - 1 and additional_residuals is not None: - hidden_states = hidden_states + additional_residuals - - output_states = output_states + (hidden_states,) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states, scale=lora_scale) - - output_states = output_states + (hidden_states,) - - return hidden_states, output_states - - -class CrossAttnUpBlockMotion(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - prev_output_channel: int, - temb_channels: int, - resolution_idx: int = None, - dropout: float = 0.0, - num_layers: int = 1, - transformer_layers_per_block: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - num_attention_heads=1, - cross_attention_dim=1280, - output_scale_factor=1.0, - add_upsample=True, - dual_cross_attention=False, - use_linear_projection=False, - only_cross_attention=False, - upcast_attention=False, - attention_type="default", - motion_norm_num_groups=32, - motion_cross_attention_dim=None, - motion_num_attention_heads=8, - motion_attention_bias=False, - motion_activation_fn="geglu", - motion_max_seq_length=32, - ): - super().__init__() - resnets = [] - attentions = [] - motion_modules = [] - - self.has_cross_attention = True - self.num_attention_heads = num_attention_heads - - for i in range(num_layers): - res_skip_channels = in_channels if (i == num_layers - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - resnets.append( - ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - if not dual_cross_attention: - attentions.append( - Transformer2DModel( - num_attention_heads, - out_channels // num_attention_heads, - in_channels=out_channels, - num_layers=transformer_layers_per_block, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - attention_type=attention_type, - ) - ) - else: - attentions.append( - DualTransformer2DModel( - num_attention_heads, - out_channels // num_attention_heads, - in_channels=out_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - ) - ) - motion_modules.append( - TransformerTemporalMotionModel( - num_attention_heads=motion_num_attention_heads, - in_channels=out_channels, - norm_num_groups=motion_norm_num_groups, - cross_attention_dim=motion_cross_attention_dim, - attention_bias=motion_attention_bias, - activation_fn=motion_activation_fn, - max_seq_length=motion_max_seq_length, - attention_head_dim=out_channels // motion_num_attention_heads, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - self.motion_modules = nn.ModuleList(motion_modules) - - if add_upsample: - self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) - else: - self.upsamplers = None - - self.gradient_checkpointing = False - self.resolution_idx = resolution_idx - - def forward( - self, - hidden_states: torch.FloatTensor, - res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], - temb: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - upsample_size: Optional[int] = None, - attention_mask: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - num_frames=1, - ): - lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 - is_freeu_enabled = ( - getattr(self, "s1", None) - and getattr(self, "s2", None) - and getattr(self, "b1", None) - and getattr(self, "b2", None) - ) - - blocks = zip(self.resnets, self.attentions, self.motion_modules) - for resnet, attn, motion_module in blocks: - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - - # FreeU: Only operate on the first two stages - if is_freeu_enabled: - hidden_states, res_hidden_states = apply_freeu( - self.resolution_idx, - hidden_states, - res_hidden_states, - s1=self.s1, - s2=self.s2, - b1=self.b1, - b2=self.b2, - ) - - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - **ckpt_kwargs, - ) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] - else: - hidden_states = resnet(hidden_states, temb, scale=lora_scale) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] - hidden_states = motion_module( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - num_frames=num_frames, - )[0] - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale) - - return hidden_states - - -class UpBlockMotion(nn.Module): - def __init__( - self, - in_channels: int, - prev_output_channel: int, - out_channels: int, - temb_channels: int, - resolution_idx: int = None, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - output_scale_factor=1.0, - add_upsample=True, - motion_norm_num_groups=32, - motion_cross_attention_dim=None, - motion_num_attention_heads=8, - motion_attention_bias=False, - motion_activation_fn="geglu", - motion_max_seq_length=32, - ): - super().__init__() - resnets = [] - motion_modules = [] - - for i in range(num_layers): - res_skip_channels = in_channels if (i == num_layers - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - resnets.append( - ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - motion_modules.append( - TransformerTemporalMotionModel( - num_attention_heads=motion_num_attention_heads, - in_channels=out_channels, - norm_num_groups=motion_norm_num_groups, - cross_attention_dim=motion_cross_attention_dim, - attention_bias=motion_attention_bias, - activation_fn=motion_activation_fn, - max_seq_length=motion_max_seq_length, - attention_head_dim=out_channels // motion_num_attention_heads, - ) - ) - - self.resnets = nn.ModuleList(resnets) - self.motion_modules = nn.ModuleList(motion_modules) - - if add_upsample: - self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) - else: - self.upsamplers = None - - self.gradient_checkpointing = False - self.resolution_idx = resolution_idx - - def forward( - self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0, num_frames=1 - ): - is_freeu_enabled = ( - getattr(self, "s1", None) - and getattr(self, "s2", None) - and getattr(self, "b1", None) - and getattr(self, "b2", None) - ) - - blocks = zip(self.resnets, self.motion_modules) - - for resnet, motion_module in blocks: - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - - # FreeU: Only operate on the first two stages - if is_freeu_enabled: - hidden_states, res_hidden_states = apply_freeu( - self.resolution_idx, - hidden_states, - res_hidden_states, - s1=self.s1, - s2=self.s2, - b1=self.b1, - b2=self.b2, - ) - - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, use_reentrant=False - ) - else: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - ) - - else: - hidden_states = resnet(hidden_states, temb, scale=scale) - hidden_states = motion_module(hidden_states, num_frames=num_frames)[0] - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size, scale=scale) - - return hidden_states - - -class UNetMidBlockCrossAttnMotion(nn.Module): - def __init__( - self, - in_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - transformer_layers_per_block: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - num_attention_heads=1, - output_scale_factor=1.0, - cross_attention_dim=1280, - dual_cross_attention=False, - use_linear_projection=False, - upcast_attention=False, - attention_type="default", - motion_norm_num_groups=32, - motion_cross_attention_dim=None, - motion_num_attention_heads=8, - motion_attention_bias=False, - motion_activation_fn="geglu", - motion_max_seq_length=32, - ): - super().__init__() - - self.has_cross_attention = True - self.num_attention_heads = num_attention_heads - resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) - - # there is always at least one resnet - resnets = [ - ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ] - attentions = [] - motion_modules = [] - - for _ in range(num_layers): - if not dual_cross_attention: - attentions.append( - Transformer2DModel( - num_attention_heads, - in_channels // num_attention_heads, - in_channels=in_channels, - num_layers=transformer_layers_per_block, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - use_linear_projection=use_linear_projection, - upcast_attention=upcast_attention, - attention_type=attention_type, - ) - ) - else: - attentions.append( - DualTransformer2DModel( - num_attention_heads, - in_channels // num_attention_heads, - in_channels=in_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - ) - ) - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - motion_modules.append( - TransformerTemporalMotionModel( - num_attention_heads=motion_num_attention_heads, - attention_head_dim=in_channels // motion_num_attention_heads, - in_channels=in_channels, - norm_num_groups=motion_norm_num_groups, - cross_attention_dim=motion_cross_attention_dim, - attention_bias=motion_attention_bias, - activation_fn=motion_activation_fn, - max_seq_length=motion_max_seq_length, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - self.motion_modules = nn.ModuleList(motion_modules) - - self.gradient_checkpointing = False - - def forward( - self, - hidden_states: torch.FloatTensor, - temb: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - num_frames=1, - ) -> torch.FloatTensor: - lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 - hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale) - - blocks = zip(self.attentions, self.resnets[1:], self.motion_modules) - for attn, resnet, motion_module in blocks: - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(motion_module), - hidden_states, - temb, - **ckpt_kwargs, - ) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - **ckpt_kwargs, - ) - else: - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] - hidden_states = motion_module( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - num_frames=num_frames, - )[0] - hidden_states = resnet(hidden_states, temb, scale=lora_scale) - - return hidden_states diff --git a/src/diffusers/models/unet_motion_model.py b/src/diffusers/models/unet_motion_model.py index 6ee3429ffe42..da1df09b3aee 100644 --- a/src/diffusers/models/unet_motion_model.py +++ b/src/diffusers/models/unet_motion_model.py @@ -31,23 +31,134 @@ from .modeling_utils import ModelMixin from .unet_2d_condition import UNet2DConditionModel from .unet_3d_condition import UNet3DConditionOutput -from .unet_motion_blocks import ( +from .unet_3d_blocks import ( CrossAttnDownBlockMotion, CrossAttnUpBlockMotion, DownBlockMotion, - MotionAdapter, - MotionAttnProcessor, - MotionAttnProcessor2_0, UNetMidBlockCrossAttnMotion, UpBlockMotion, get_down_block, get_up_block, ) +from .transformer_temporal import TransformerTemporalMotionModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name +class MotionModules(nn.Module): + def __init__( + self, + in_channels, + layers_per_block=2, + num_attention_heads=8, + attention_bias=False, + cross_attention_dim=None, + activation_fn="geglu", + norm_num_groups=32, + max_seq_length=32, + ): + super().__init__() + self.motion_modules = nn.ModuleList([]) + + for i in range(layers_per_block): + self.motion_modules.append( + TransformerTemporalMotionModel( + in_channels=in_channels, + norm_num_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + attention_bias=attention_bias, + num_attention_heads=num_attention_heads, + attention_head_dim=in_channels // num_attention_heads, + max_seq_length=max_seq_length, + ) + ) + + +class MotionAdapter(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + block_out_channels=(320, 640, 1280, 1280), + motion_layers_per_block=2, + motion_mid_block_layers_per_block=1, + motion_num_attention_heads=8, + motion_attention_bias=False, + motion_cross_attention_dim=None, + motion_activation_fn="geglu", + motion_norm_num_groups=32, + motion_max_seq_length=32, + ): + """Container to store AnimateDiff Motion Modules + + Args: + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each UNet block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + num_attention_heads (`int`, *optional*): + The number of heads to use in each attention layer. + attention_bias (bool, optional, defaults to False): Whether to include bias in attention layers. + cross_attention_dim (int, optional, Defaults to None): Set in order to use cross attention. + activation_fn (str, optional, Defaults to "geglu"): Activation Function. + norm_num_groups (int, optional): _description_. Defaults to 32. + max_seq_length (int, optional): _description_. Defaults to 24. + """ + + super().__init__() + down_blocks = [] + up_blocks = [] + + for i, channel in enumerate(block_out_channels): + output_channel = block_out_channels[i] + down_blocks.append( + MotionModules( + in_channels=output_channel, + norm_num_groups=motion_norm_num_groups, + cross_attention_dim=motion_cross_attention_dim, + activation_fn=motion_activation_fn, + attention_bias=motion_attention_bias, + num_attention_heads=motion_num_attention_heads, + max_seq_length=motion_max_seq_length, + layers_per_block=motion_layers_per_block, + ) + ) + + self.mid_block = MotionModules( + in_channels=block_out_channels[-1], + norm_num_groups=motion_norm_num_groups, + cross_attention_dim=motion_cross_attention_dim, + activation_fn=motion_activation_fn, + attention_bias=motion_attention_bias, + num_attention_heads=motion_num_attention_heads, + layers_per_block=motion_mid_block_layers_per_block, + max_seq_length=motion_max_seq_length, + ) + + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, channel in enumerate(reversed_block_out_channels): + output_channel = reversed_block_out_channels[i] + up_blocks.append( + MotionModules( + in_channels=output_channel, + norm_num_groups=motion_norm_num_groups, + cross_attention_dim=motion_cross_attention_dim, + activation_fn=motion_activation_fn, + attention_bias=motion_attention_bias, + num_attention_heads=motion_num_attention_heads, + max_seq_length=motion_max_seq_length, + layers_per_block=motion_layers_per_block + 1, + ) + ) + + self.down_blocks = nn.ModuleList(down_blocks) + self.up_blocks = nn.ModuleList(up_blocks) + + def forward(self, sample): + pass + + class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): r""" A modified conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a @@ -175,12 +286,12 @@ def __init__( downsample_padding=downsample_padding, use_linear_projection=use_linear_projection, dual_cross_attention=False, - motion_norm_num_groups=motion_norm_num_groups, - motion_cross_attention_dim=motion_cross_attention_dim, - motion_num_attention_heads=motion_num_attention_heads, - motion_attention_bias=motion_attention_bias, - motion_activation_fn=motion_activation_fn, - motion_max_seq_length=motion_max_seq_length, + temporal_norm_num_groups=motion_norm_num_groups, + temporal_cross_attention_dim=motion_cross_attention_dim, + temporal_num_attention_heads=motion_num_attention_heads, + temporal_attention_bias=motion_attention_bias, + temporal_activation_fn=motion_activation_fn, + temporal_max_seq_length=motion_max_seq_length, ) self.down_blocks.append(down_block) @@ -195,12 +306,12 @@ def __init__( num_attention_heads=num_attention_heads[-1], resnet_groups=norm_num_groups, dual_cross_attention=False, - motion_norm_num_groups=motion_norm_num_groups, - motion_cross_attention_dim=motion_cross_attention_dim, - motion_num_attention_heads=motion_num_attention_heads, - motion_attention_bias=motion_attention_bias, - motion_activation_fn=motion_activation_fn, - motion_max_seq_length=motion_max_seq_length, + temporal_norm_num_groups=motion_norm_num_groups, + temporal_cross_attention_dim=motion_cross_attention_dim, + temporal_num_attention_heads=motion_num_attention_heads, + temporal_attention_bias=motion_attention_bias, + temporal_activation_fn=motion_activation_fn, + temporal_max_seq_length=motion_max_seq_length, ) # count how many layers upsample the images @@ -241,12 +352,12 @@ def __init__( dual_cross_attention=False, resolution_idx=i, use_linear_projection=use_linear_projection, - motion_norm_num_groups=motion_norm_num_groups, - motion_cross_attention_dim=motion_cross_attention_dim, - motion_num_attention_heads=motion_num_attention_heads, - motion_attention_bias=motion_attention_bias, - motion_activation_fn=motion_activation_fn, - motion_max_seq_length=motion_max_seq_length, + temporal_norm_num_groups=motion_norm_num_groups, + temporal_cross_attention_dim=motion_cross_attention_dim, + temporal_num_attention_heads=motion_num_attention_heads, + temporal_attention_bias=motion_attention_bias, + temporal_activation_fn=motion_activation_fn, + temporal_max_seq_length=motion_max_seq_length, ) self.up_blocks.append(up_block) prev_output_channel = output_channel @@ -530,13 +641,6 @@ def set_attn_processor( ) def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): - if hasattr(module, "processor"): - # ignore the custom attention processors from motion modules - if isinstance(module.processor, MotionAttnProcessor) or isinstance( - module.processor, MotionAttnProcessor2_0 - ): - return - if hasattr(module, "set_processor"): if not isinstance(processor, dict): module.set_processor(processor, _remove_lora=_remove_lora) diff --git a/tests/models/test_models_unet_motion.py b/tests/models/test_models_unet_motion.py index af4791a8d541..47ab5bc47e5c 100644 --- a/tests/models/test_models_unet_motion.py +++ b/tests/models/test_models_unet_motion.py @@ -142,7 +142,8 @@ def test_saving_motion_modules(self): output = model(**inputs_dict)[0] output_loaded = model_loaded(**inputs_dict)[0] - assert np.abs(output - output_loaded).max() < 1e-4 + max_diff = (output - output_loaded).abs().max().item() + self.assertLessEqual(max_diff, 1e-4, "Models give different forward passes") @unittest.skipIf( torch_device != "cuda" or not is_xformers_available(), From 4d0b5ecd43a5d685fc4caf5f09dae3b05257835d Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Thu, 26 Oct 2023 11:42:06 +0000 Subject: [PATCH 29/55] update --- src/diffusers/models/attention.py | 2 +- src/diffusers/models/unet_3d_blocks.py | 7 ++++--- src/diffusers/models/unet_motion_model.py | 4 ++-- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 4f8276cf06b0..1ce87547e5c8 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -20,7 +20,7 @@ from ..utils.torch_utils import maybe_allow_in_graph from .activations import get_activation from .attention_processor import Attention -from .embeddings import CombinedTimestepLabelEmbeddings, get_timestep_embedding, PositionalEmbedding +from .embeddings import CombinedTimestepLabelEmbeddings, PositionalEmbedding from .lora import LoRACompatibleLinear diff --git a/src/diffusers/models/unet_3d_blocks.py b/src/diffusers/models/unet_3d_blocks.py index 31806619b2b0..079406263d8c 100644 --- a/src/diffusers/models/unet_3d_blocks.py +++ b/src/diffusers/models/unet_3d_blocks.py @@ -12,15 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Dict, Optional, Tuple + import torch from torch import nn -from typing import Any, Dict, Optional, Tuple -from ..utils.torch_utils import apply_freeu from ..utils import is_torch_version +from ..utils.torch_utils import apply_freeu +from .dual_transformer_2d import DualTransformer2DModel from .resnet import Downsample2D, ResnetBlock2D, TemporalConvLayer, Upsample2D from .transformer_2d import Transformer2DModel -from .dual_transformer_2d import DualTransformer2DModel from .transformer_temporal import TransformerTemporalModel, TransformerTemporalMotionModel diff --git a/src/diffusers/models/unet_motion_model.py b/src/diffusers/models/unet_motion_model.py index da1df09b3aee..7fc68f561427 100644 --- a/src/diffusers/models/unet_motion_model.py +++ b/src/diffusers/models/unet_motion_model.py @@ -29,8 +29,8 @@ ) from .embeddings import TimestepEmbedding, Timesteps from .modeling_utils import ModelMixin +from .transformer_temporal import TransformerTemporalMotionModel from .unet_2d_condition import UNet2DConditionModel -from .unet_3d_condition import UNet3DConditionOutput from .unet_3d_blocks import ( CrossAttnDownBlockMotion, CrossAttnUpBlockMotion, @@ -40,7 +40,7 @@ get_down_block, get_up_block, ) -from .transformer_temporal import TransformerTemporalMotionModel +from .unet_3d_condition import UNet3DConditionOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name From 8be5f1f89293a66e4d188657435894508f08675a Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Thu, 26 Oct 2023 19:14:17 +0000 Subject: [PATCH 30/55] update --- src/diffusers/models/attention.py | 221 ++----------------- src/diffusers/models/transformer_temporal.py | 183 +++------------ src/diffusers/models/unet_3d_blocks.py | 22 +- src/diffusers/models/unet_motion_model.py | 5 +- 4 files changed, 64 insertions(+), 367 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 1ce87547e5c8..76e7073f57c8 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -115,6 +115,8 @@ def __init__( norm_type: str = "layer_norm", final_dropout: bool = False, attention_type: str = "default", + use_positional_embedding: bool = False, + max_seq_length: Optional[int] = None, ): super().__init__() self.only_cross_attention = only_cross_attention @@ -128,6 +130,14 @@ def __init__( f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." ) + if use_positional_embedding and (max_seq_length is None): + raise ValueError("If `use_positional_embedding` is set to `True`, `max_sequence_length` must be defined.") + + if use_positional_embedding: + self.pos_embed = PositionalEmbedding(dim, max_seq_length=max_seq_length) + else: + self.pos_embed = None + # Define 3 blocks. Each block has its own normalization layer. # 1. Self-Attn if self.use_ada_layer_norm: @@ -207,6 +217,9 @@ def forward( else: norm_hidden_states = self.norm1(hidden_states) + if self.pos_embed is not None: + norm_hidden_states = self.pos_embed(norm_hidden_states) + # 1. Retrieve lora scale. lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 @@ -234,6 +247,8 @@ def forward( norm_hidden_states = ( self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) ) + if self.pos_embed is not None: + norm_hidden_states = self.pos_embed(norm_hidden_states) attn_output = self.attn2( norm_hidden_states, @@ -275,212 +290,6 @@ def forward( return hidden_states -@maybe_allow_in_graph -class TemporalPositionEmbedTransformerBlock(nn.Module): - r""" - A Transformer block that applies a positional embedding before each attention layer. - - Parameters: - dim (`int`): The number of channels in the input and output. - num_attention_heads (`int`): The number of heads to use for multi-head attention. - attention_head_dim (`int`): The number of channels in each head. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. - num_embeds_ada_norm (: - obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. - attention_bias (: - obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. - only_cross_attention (`bool`, *optional*): - Whether to use only cross-attention layers. In this case two cross attention layers are used. - double_self_attention (`bool`, *optional*): - Whether to use two self-attention layers. In this case no cross attention layers are used. - upcast_attention (`bool`, *optional*): - Whether to upcast the attention computation to float32. This is useful for mixed precision training. - norm_elementwise_affine (`bool`, *optional*, defaults to `True`): - Whether to use learnable elementwise affine parameters for normalization. - norm_type (`str`, *optional*, defaults to `"layer_norm"`): - The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. - final_dropout (`bool` *optional*, defaults to False): - Whether to apply a final dropout after the last feed-forward layer. - attention_type (`str`, *optional*, defaults to `"default"`): - The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. - """ - - def __init__( - self, - dim: int, - num_attention_heads: int, - attention_head_dim: int, - dropout=0.0, - cross_attention_dim: Optional[int] = None, - activation_fn: str = "geglu", - num_embeds_ada_norm: Optional[int] = None, - attention_bias: bool = False, - only_cross_attention: bool = False, - double_self_attention: bool = False, - upcast_attention: bool = False, - norm_elementwise_affine: bool = True, - norm_type: str = "layer_norm", - final_dropout: bool = False, - attention_type: str = "default", - max_seq_length=32, - ): - super().__init__() - self.only_cross_attention = only_cross_attention - - self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" - self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" - - self.pos_embed = PositionalEmbedding(dim, max_seq_length=max_seq_length) - - if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: - raise ValueError( - f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" - f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." - ) - - # Define 3 blocks. Each block has its own normalization layer. - # 1. Self-Attn - if self.use_ada_layer_norm: - self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) - elif self.use_ada_layer_norm_zero: - self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) - else: - self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) - self.attn1 = Attention( - query_dim=dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - cross_attention_dim=cross_attention_dim if only_cross_attention else None, - upcast_attention=upcast_attention, - ) - - # 2. Cross-Attn - if cross_attention_dim is not None or double_self_attention: - # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. - # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during - # the second cross attention block. - self.norm2 = ( - AdaLayerNorm(dim, num_embeds_ada_norm) - if self.use_ada_layer_norm - else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) - ) - self.attn2 = Attention( - query_dim=dim, - cross_attention_dim=cross_attention_dim if not double_self_attention else None, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - upcast_attention=upcast_attention, - ) # is self-attn if encoder_hidden_states is none - else: - self.norm2 = None - self.attn2 = None - - # 3. Feed-forward - self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) - self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) - - # 4. Fuser - if attention_type == "gated" or attention_type == "gated-text-image": - self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim) - - # let chunk size default to None - self._chunk_size = None - self._chunk_dim = 0 - - def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int): - # Sets chunk feed-forward - self._chunk_size = chunk_size - self._chunk_dim = dim - - def forward( - self, - hidden_states: torch.FloatTensor, - attention_mask: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - timestep: Optional[torch.LongTensor] = None, - cross_attention_kwargs: Dict[str, Any] = None, - class_labels: Optional[torch.LongTensor] = None, - ) -> torch.FloatTensor: - # Notice that normalization is always applied before the real computation in the following blocks. - # 0. Self-Attention - if self.use_ada_layer_norm: - norm_hidden_states = self.norm1(hidden_states, timestep) - elif self.use_ada_layer_norm_zero: - norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( - hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype - ) - else: - norm_hidden_states = self.norm1(hidden_states) - - # 1. Retrieve lora scale. - lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 - - cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} - - norm_hidden_states = self.pos_embed(norm_hidden_states) - attn_output = self.attn1( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, - attention_mask=attention_mask, - **cross_attention_kwargs, - ) - if self.use_ada_layer_norm_zero: - attn_output = gate_msa.unsqueeze(1) * attn_output - hidden_states = attn_output + hidden_states - - # 3. Cross-Attention - if self.attn2 is not None: - norm_hidden_states = ( - self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) - ) - norm_hidden_states = self.pos_embed(norm_hidden_states) - attn_output = self.attn2( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - **cross_attention_kwargs, - ) - hidden_states = attn_output + hidden_states - - # 4. Feed-forward - norm_hidden_states = self.norm3(hidden_states) - - if self.use_ada_layer_norm_zero: - norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] - - if self._chunk_size is not None: - # "feed_forward_chunk_size" can be used to save memory - if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0: - raise ValueError( - f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." - ) - - num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size - ff_output = torch.cat( - [ - self.ff(hid_slice, scale=lora_scale) - for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim) - ], - dim=self._chunk_dim, - ) - else: - ff_output = self.ff(norm_hidden_states, scale=lora_scale) - - if self.use_ada_layer_norm_zero: - ff_output = gate_mlp.unsqueeze(1) * ff_output - - hidden_states = ff_output + hidden_states - - return hidden_states - - class FeedForward(nn.Module): r""" A feed-forward layer. diff --git a/src/diffusers/models/transformer_temporal.py b/src/diffusers/models/transformer_temporal.py index ba114596ab57..288dcee13d39 100644 --- a/src/diffusers/models/transformer_temporal.py +++ b/src/diffusers/models/transformer_temporal.py @@ -19,7 +19,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput -from .attention import BasicTransformerBlock, TemporalPositionEmbedTransformerBlock +from .attention import BasicTransformerBlock from .modeling_utils import ModelMixin @@ -55,6 +55,12 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin): Configure if the `TransformerBlock` attention should contain a bias parameter. double_self_attention (`bool`, *optional*): Configure if each `TransformerBlock` should contain two self-attention layers. + use_positional_embedding (`bool`, *optional*): + Whether to apply positional embeddings before each attention layer + max_seq_length (`int`, *optional*, defaults to 32): + Maximum sequence length for positional embeddings. + apply_framewise_group_norm (`bool`, *optional*, defaults to `False`): + Whether to apply group normalization to each frame individually. """ @register_to_config @@ -63,7 +69,6 @@ def __init__( num_attention_heads: int = 16, attention_head_dim: int = 88, in_channels: Optional[int] = None, - out_channels: Optional[int] = None, num_layers: int = 1, dropout: float = 0.0, norm_num_groups: int = 32, @@ -73,12 +78,17 @@ def __init__( activation_fn: str = "geglu", norm_elementwise_affine: bool = True, double_self_attention: bool = True, + use_positional_embedding: bool = False, + max_seq_length: int = 32, + apply_framewise_group_norm: bool = False, ): super().__init__() self.num_attention_heads = num_attention_heads self.attention_head_dim = attention_head_dim inner_dim = num_attention_heads * attention_head_dim + self.use_cross_attention = cross_attention_dim is not None + self.apply_framewise_group_norm = apply_framewise_group_norm self.in_channels = in_channels self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) @@ -97,6 +107,8 @@ def __init__( attention_bias=attention_bias, double_self_attention=double_self_attention, norm_elementwise_affine=norm_elementwise_affine, + use_positional_embedding=use_positional_embedding, + max_seq_length=max_seq_length, ) for d in range(num_layers) ] @@ -149,160 +161,25 @@ def forward( residual = hidden_states - hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width) - hidden_states = hidden_states.permute(0, 2, 1, 3, 4) - - hidden_states = self.norm(hidden_states) - hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel) - - hidden_states = self.proj_in(hidden_states) - - # 2. Blocks - for block in self.transformer_blocks: - hidden_states = block( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - timestep=timestep, - cross_attention_kwargs=cross_attention_kwargs, - class_labels=class_labels, + # Apply group norm over batch_frames + if not self.apply_framewise_group_norm: + hidden_states = self.norm(hidden_states) + hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width) + hidden_states = hidden_states.permute(0, 3, 4, 1, 2).reshape( + batch_size * height * width, num_frames, channel ) - # 3. Output - hidden_states = self.proj_out(hidden_states) - hidden_states = ( - hidden_states[None, None, :] - .reshape(batch_size, height, width, channel, num_frames) - .permute(0, 3, 4, 1, 2) - .contiguous() - ) - hidden_states = hidden_states.reshape(batch_frames, channel, height, width) - - output = hidden_states + residual - - if not return_dict: - return (output,) - - return TransformerTemporalModelOutput(sample=output) - - -class TransformerTemporalMotionModel(ModelMixin, ConfigMixin): - """ - A Transformer model to learn a Motion Prior for video-like data. - - Parameters: - num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. - attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. - in_channels (`int`, *optional*): - The number of channels in the input and output (specify if the input is **continuous**). - num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. - sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). - This is fixed during training since it is used to learn a number of position embeddings. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. - attention_bias (`bool`, *optional*): - Configure if the `TransformerBlock` attention should contain a bias parameter. - double_self_attention (`bool`, *optional*): - Configure if each `TransformerBlock` should contain two self-attention layers. - """ - - @register_to_config - def __init__( - self, - num_attention_heads: int = 16, - attention_head_dim: int = 88, - in_channels: Optional[int] = None, - out_channels: Optional[int] = None, - num_layers: int = 1, - dropout: float = 0.0, - norm_num_groups: int = 32, - cross_attention_dim: Optional[int] = None, - attention_bias: bool = False, - sample_size: Optional[int] = None, - activation_fn: str = "geglu", - norm_elementwise_affine: bool = True, - double_self_attention: bool = True, - max_seq_length: int = 32, - ): - super().__init__() - self.num_attention_heads = num_attention_heads - self.attention_head_dim = attention_head_dim - inner_dim = num_attention_heads * attention_head_dim + # Apply group norm after separating batch and frames + else: + hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width) + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) - self.in_channels = in_channels - - self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) - self.proj_in = nn.Linear(in_channels, inner_dim) - - # 3. Define transformers blocks - self.transformer_blocks = nn.ModuleList( - [ - TemporalPositionEmbedTransformerBlock( - inner_dim, - num_attention_heads, - attention_head_dim, - dropout=dropout, - cross_attention_dim=cross_attention_dim, - activation_fn=activation_fn, - attention_bias=attention_bias, - double_self_attention=double_self_attention, - norm_elementwise_affine=norm_elementwise_affine, - ) - for d in range(num_layers) - ] - ) - - self.proj_out = nn.Linear(inner_dim, in_channels) - self.use_cross_attention = cross_attention_dim is not None - - def forward( - self, - hidden_states, - encoder_hidden_states=None, - timestep=None, - class_labels=None, - num_frames=1, - cross_attention_kwargs=None, - return_dict: bool = True, - ): - """ - The [`TransformerTemporal`] forward method. - - Args: - hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): - Input hidden_states. - encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): - Conditional embeddings for cross attention layer. If not given, cross-attention defaults to - self-attention. - timestep ( `torch.long`, *optional*): - Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. - class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): - Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in - `AdaLayerZeroNorm`. - num_frames (`int`, *optional*, defaults to 1): - The number of frames to be processed per batch. This is used to reshape the hidden states. - cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain - tuple. - - Returns: - [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`: - If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is - returned, otherwise a `tuple` where the first element is the sample tensor. - """ - # 1. Input - batch_frames, channel, height, width = hidden_states.shape - batch_size = batch_frames // num_frames - - residual = hidden_states + hidden_states = self.norm(hidden_states) + hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape( + batch_size * height * width, num_frames, channel + ) - hidden_states = self.norm(hidden_states) - hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width) - hidden_states = hidden_states.permute(0, 3, 4, 1, 2).reshape(batch_size * height * width, num_frames, channel) + torch.save(hidden_states, "hs-notframewise.pt") hidden_states = self.proj_in(hidden_states) encoder_hidden_states = encoder_hidden_states if self.use_cross_attention else None @@ -321,7 +198,7 @@ def forward( hidden_states = self.proj_out(hidden_states) hidden_states = ( hidden_states[None, None, :] - .reshape(batch_size, height, width, num_frames, channel) + .reshape(batch_size, height, width, channel, num_frames) .permute(0, 3, 4, 1, 2) .contiguous() ) diff --git a/src/diffusers/models/unet_3d_blocks.py b/src/diffusers/models/unet_3d_blocks.py index 079406263d8c..40cdd28ef6cb 100644 --- a/src/diffusers/models/unet_3d_blocks.py +++ b/src/diffusers/models/unet_3d_blocks.py @@ -22,7 +22,7 @@ from .dual_transformer_2d import DualTransformer2DModel from .resnet import Downsample2D, ResnetBlock2D, TemporalConvLayer, Upsample2D from .transformer_2d import Transformer2DModel -from .transformer_temporal import TransformerTemporalModel, TransformerTemporalMotionModel +from .transformer_temporal import TransformerTemporalModel def get_down_block( @@ -881,15 +881,17 @@ def __init__( ) ) motion_modules.append( - TransformerTemporalMotionModel( + TransformerTemporalModel( num_attention_heads=temporal_num_attention_heads, in_channels=out_channels, norm_num_groups=temporal_norm_num_groups, cross_attention_dim=temporal_cross_attention_dim, attention_bias=temporal_attention_bias, activation_fn=temporal_activation_fn, + use_positional_embedding=True, max_seq_length=temporal_max_seq_length, attention_head_dim=out_channels // temporal_num_attention_heads, + apply_framewise_group_norm=True, ) ) @@ -1033,15 +1035,17 @@ def __init__( ) motion_modules.append( - TransformerTemporalMotionModel( + TransformerTemporalModel( num_attention_heads=temporal_num_attention_heads, in_channels=out_channels, norm_num_groups=temporal_norm_num_groups, cross_attention_dim=temporal_cross_attention_dim, attention_bias=temporal_attention_bias, activation_fn=temporal_activation_fn, + use_positional_embedding=True, max_seq_length=temporal_max_seq_length, attention_head_dim=out_channels // temporal_num_attention_heads, + apply_framewise_group_norm=True, ) ) @@ -1223,15 +1227,17 @@ def __init__( ) ) motion_modules.append( - TransformerTemporalMotionModel( + TransformerTemporalModel( num_attention_heads=temporal_num_attention_heads, in_channels=out_channels, norm_num_groups=temporal_norm_num_groups, cross_attention_dim=temporal_cross_attention_dim, attention_bias=temporal_attention_bias, activation_fn=temporal_activation_fn, + use_positional_embedding=True, max_seq_length=temporal_max_seq_length, attention_head_dim=out_channels // temporal_num_attention_heads, + apply_framewise_group_norm=True, ) ) @@ -1385,15 +1391,17 @@ def __init__( ) motion_modules.append( - TransformerTemporalMotionModel( + TransformerTemporalModel( num_attention_heads=temporal_num_attention_heads, in_channels=out_channels, norm_num_groups=temporal_norm_num_groups, cross_attention_dim=temporal_cross_attention_dim, attention_bias=temporal_attention_bias, activation_fn=temporal_activation_fn, + use_positional_embedding=True, max_seq_length=temporal_max_seq_length, attention_head_dim=out_channels // temporal_num_attention_heads, + apply_framewise_group_norm=True, ) ) @@ -1564,15 +1572,17 @@ def __init__( ) ) motion_modules.append( - TransformerTemporalMotionModel( + TransformerTemporalModel( num_attention_heads=temporal_num_attention_heads, attention_head_dim=in_channels // temporal_num_attention_heads, in_channels=in_channels, norm_num_groups=temporal_norm_num_groups, cross_attention_dim=temporal_cross_attention_dim, attention_bias=temporal_attention_bias, + use_positional_embedding=True, activation_fn=temporal_activation_fn, max_seq_length=temporal_max_seq_length, + apply_framewise_group_norm=True, ) ) diff --git a/src/diffusers/models/unet_motion_model.py b/src/diffusers/models/unet_motion_model.py index 7fc68f561427..2c5515c2efb9 100644 --- a/src/diffusers/models/unet_motion_model.py +++ b/src/diffusers/models/unet_motion_model.py @@ -29,7 +29,7 @@ ) from .embeddings import TimestepEmbedding, Timesteps from .modeling_utils import ModelMixin -from .transformer_temporal import TransformerTemporalMotionModel +from .transformer_temporal import TransformerTemporalModel from .unet_2d_condition import UNet2DConditionModel from .unet_3d_blocks import ( CrossAttnDownBlockMotion, @@ -63,7 +63,7 @@ def __init__( for i in range(layers_per_block): self.motion_modules.append( - TransformerTemporalMotionModel( + TransformerTemporalModel( in_channels=in_channels, norm_num_groups=norm_num_groups, cross_attention_dim=cross_attention_dim, @@ -71,6 +71,7 @@ def __init__( attention_bias=attention_bias, num_attention_heads=num_attention_heads, attention_head_dim=in_channels // num_attention_heads, + use_positional_embedding=True, max_seq_length=max_seq_length, ) ) From 313db1dd329205b3afcc5868230c26d180d067c9 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Fri, 27 Oct 2023 06:08:47 +0000 Subject: [PATCH 31/55] update model test --- tests/models/test_models_unet_motion.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/models/test_models_unet_motion.py b/tests/models/test_models_unet_motion.py index 47ab5bc47e5c..4bce4f3968d4 100644 --- a/tests/models/test_models_unet_motion.py +++ b/tests/models/test_models_unet_motion.py @@ -14,15 +14,14 @@ # limitations under the License. import copy -import gc import os import tempfile import unittest -import torch import numpy as np +import torch -from diffusers import UNetMotionModel, UNet2DConditionModel, MotionAdapter +from diffusers import MotionAdapter, UNet2DConditionModel, UNetMotionModel from diffusers.utils import logging from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.testing_utils import ( From e82331e0c3ac85102ded10117a2bb0d3cf0028a0 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 30 Oct 2023 16:14:29 +0000 Subject: [PATCH 32/55] update --- src/diffusers/models/transformer_temporal.py | 2 +- src/diffusers/models/unet_motion_model.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/transformer_temporal.py b/src/diffusers/models/transformer_temporal.py index 288dcee13d39..57697f5b2dc4 100644 --- a/src/diffusers/models/transformer_temporal.py +++ b/src/diffusers/models/transformer_temporal.py @@ -162,7 +162,7 @@ def forward( residual = hidden_states # Apply group norm over batch_frames - if not self.apply_framewise_group_norm: + if self.apply_framewise_group_norm: hidden_states = self.norm(hidden_states) hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width) hidden_states = hidden_states.permute(0, 3, 4, 1, 2).reshape( diff --git a/src/diffusers/models/unet_motion_model.py b/src/diffusers/models/unet_motion_model.py index 2c5515c2efb9..3d9431a4f6fc 100644 --- a/src/diffusers/models/unet_motion_model.py +++ b/src/diffusers/models/unet_motion_model.py @@ -205,6 +205,7 @@ def __init__( motion_attention_bias: bool = False, motion_activation_fn: str = "geglu", motion_max_seq_length: Optional[int] = 32, + motion_apply_framewise_groupnorm: bool = True, ): super().__init__() From 37de1de70f8cdf2a162fba1643d4eb9c2781ad78 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Tue, 31 Oct 2023 06:36:57 +0000 Subject: [PATCH 33/55] update --- docs/source/en/_toctree.yml | 4 ++++ src/diffusers/models/transformer_temporal.py | 22 +++----------------- 2 files changed, 7 insertions(+), 19 deletions(-) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index b8aa71dacbe2..8523cb57ef93 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -188,6 +188,8 @@ title: UNet2DConditionModel - local: api/models/unet3d-cond title: UNet3DConditionModel + - local: api/models/unet-motion + title: UNetMotionModel - local: api/models/vq title: VQModel - local: api/models/autoencoderkl @@ -211,6 +213,8 @@ - local: api/pipelines/alt_diffusion title: AltDiffusion - local: api/pipelines/attend_and_excite + - local: api/pipelienes/animatediff + title: AnimateDiff title: Attend-and-Excite - local: api/pipelines/audio_diffusion title: Audio Diffusion diff --git a/src/diffusers/models/transformer_temporal.py b/src/diffusers/models/transformer_temporal.py index 57697f5b2dc4..f115d412a9b3 100644 --- a/src/diffusers/models/transformer_temporal.py +++ b/src/diffusers/models/transformer_temporal.py @@ -161,25 +161,9 @@ def forward( residual = hidden_states - # Apply group norm over batch_frames - if self.apply_framewise_group_norm: - hidden_states = self.norm(hidden_states) - hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width) - hidden_states = hidden_states.permute(0, 3, 4, 1, 2).reshape( - batch_size * height * width, num_frames, channel - ) - - # Apply group norm after separating batch and frames - else: - hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width) - hidden_states = hidden_states.permute(0, 2, 1, 3, 4) - - hidden_states = self.norm(hidden_states) - hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape( - batch_size * height * width, num_frames, channel - ) - - torch.save(hidden_states, "hs-notframewise.pt") + hidden_states = self.norm(hidden_states) + hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width) + hidden_states = hidden_states.permute(0, 3, 4, 1, 2).reshape(batch_size * height * width, num_frames, channel) hidden_states = self.proj_in(hidden_states) encoder_hidden_states = encoder_hidden_states if self.use_cross_attention else None From 71dc350996b1b1eb21138d0980e7a32b9bbf4e42 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Tue, 31 Oct 2023 09:33:27 +0000 Subject: [PATCH 34/55] update --- src/diffusers/models/transformer_temporal.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/transformer_temporal.py b/src/diffusers/models/transformer_temporal.py index f115d412a9b3..a4e8059306ad 100644 --- a/src/diffusers/models/transformer_temporal.py +++ b/src/diffusers/models/transformer_temporal.py @@ -80,7 +80,6 @@ def __init__( double_self_attention: bool = True, use_positional_embedding: bool = False, max_seq_length: int = 32, - apply_framewise_group_norm: bool = False, ): super().__init__() self.num_attention_heads = num_attention_heads @@ -88,7 +87,6 @@ def __init__( inner_dim = num_attention_heads * attention_head_dim self.use_cross_attention = cross_attention_dim is not None - self.apply_framewise_group_norm = apply_framewise_group_norm self.in_channels = in_channels self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) @@ -161,9 +159,11 @@ def forward( residual = hidden_states - hidden_states = self.norm(hidden_states) hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width) - hidden_states = hidden_states.permute(0, 3, 4, 1, 2).reshape(batch_size * height * width, num_frames, channel) + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + + hidden_states = self.norm(hidden_states) + hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel) hidden_states = self.proj_in(hidden_states) encoder_hidden_states = encoder_hidden_states if self.use_cross_attention else None From 5d65837a460b4233b9f8254cb06b45dbd94d3328 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Tue, 31 Oct 2023 17:24:44 +0000 Subject: [PATCH 35/55] update --- src/diffusers/models/attention.py | 20 +- src/diffusers/models/embeddings.py | 2 +- src/diffusers/models/transformer_temporal.py | 15 +- src/diffusers/models/unet_3d_blocks.py | 101 +++------- src/diffusers/models/unet_3d_condition.py | 72 +------ src/diffusers/models/unet_motion_model.py | 183 ++++++------------ .../animatediff/pipeline_animatediff.py | 12 +- tests/models/test_models_unet_motion.py | 25 --- .../pipelines/animatediff/test_animatediff.py | 146 ++++++++++++-- 9 files changed, 249 insertions(+), 327 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 76e7073f57c8..5671b0a26c9e 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -20,7 +20,7 @@ from ..utils.torch_utils import maybe_allow_in_graph from .activations import get_activation from .attention_processor import Attention -from .embeddings import CombinedTimestepLabelEmbeddings, PositionalEmbedding +from .embeddings import CombinedTimestepLabelEmbeddings, SinusoidalPositionalEmbedding from .lora import LoRACompatibleLinear @@ -96,6 +96,10 @@ class BasicTransformerBlock(nn.Module): Whether to apply a final dropout after the last feed-forward layer. attention_type (`str`, *optional*, defaults to `"default"`): The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. + positional_embeddings (`str`, *optional*, defaults to `None`): + The type of positional embeddings to apply to. + num_positional_embeddings (`int`, *optional*, defaults to `None`): + The maximum number of positional embeddings to apply. """ def __init__( @@ -115,8 +119,8 @@ def __init__( norm_type: str = "layer_norm", final_dropout: bool = False, attention_type: str = "default", - use_positional_embedding: bool = False, - max_seq_length: Optional[int] = None, + positional_embeddings: Optional[str] = None, + num_positional_embeddings: Optional[int] = None, ): super().__init__() self.only_cross_attention = only_cross_attention @@ -130,11 +134,13 @@ def __init__( f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." ) - if use_positional_embedding and (max_seq_length is None): - raise ValueError("If `use_positional_embedding` is set to `True`, `max_sequence_length` must be defined.") + if positional_embeddings and (num_positional_embeddings is None): + raise ValueError( + "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined." + ) - if use_positional_embedding: - self.pos_embed = PositionalEmbedding(dim, max_seq_length=max_seq_length) + if positional_embeddings == "sinusoidal": + self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings) else: self.pos_embed = None diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index ad7b316f9a6f..6b4c749202e8 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -249,7 +249,7 @@ def forward(self, x): return out -class PositionalEmbedding(nn.Module): +class SinusoidalPositionalEmbedding(nn.Module): """Apply positional information to a sequence of embeddings. Takes in a sequence of embeddings with shape (batch_size, seq_length, embed_dim) and adds positional embeddings to diff --git a/src/diffusers/models/transformer_temporal.py b/src/diffusers/models/transformer_temporal.py index a4e8059306ad..c9c0b57ef113 100644 --- a/src/diffusers/models/transformer_temporal.py +++ b/src/diffusers/models/transformer_temporal.py @@ -55,12 +55,6 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin): Configure if the `TransformerBlock` attention should contain a bias parameter. double_self_attention (`bool`, *optional*): Configure if each `TransformerBlock` should contain two self-attention layers. - use_positional_embedding (`bool`, *optional*): - Whether to apply positional embeddings before each attention layer - max_seq_length (`int`, *optional*, defaults to 32): - Maximum sequence length for positional embeddings. - apply_framewise_group_norm (`bool`, *optional*, defaults to `False`): - Whether to apply group normalization to each frame individually. """ @register_to_config @@ -74,12 +68,11 @@ def __init__( norm_num_groups: int = 32, cross_attention_dim: Optional[int] = None, attention_bias: bool = False, - sample_size: Optional[int] = None, activation_fn: str = "geglu", norm_elementwise_affine: bool = True, double_self_attention: bool = True, - use_positional_embedding: bool = False, - max_seq_length: int = 32, + positional_embeddings: Optional[str] = None, + num_positional_embeddings: Optional[int] = None, ): super().__init__() self.num_attention_heads = num_attention_heads @@ -105,8 +98,8 @@ def __init__( attention_bias=attention_bias, double_self_attention=double_self_attention, norm_elementwise_affine=norm_elementwise_affine, - use_positional_embedding=use_positional_embedding, - max_seq_length=max_seq_length, + positional_embeddings=positional_embeddings, + num_positional_embeddings=num_positional_embeddings, ) for d in range(num_layers) ] diff --git a/src/diffusers/models/unet_3d_blocks.py b/src/diffusers/models/unet_3d_blocks.py index 40cdd28ef6cb..699725a12f48 100644 --- a/src/diffusers/models/unet_3d_blocks.py +++ b/src/diffusers/models/unet_3d_blocks.py @@ -43,11 +43,8 @@ def get_down_block( only_cross_attention=False, upcast_attention=False, resnet_time_scale_shift="default", - temporal_norm_num_groups=32, - temporal_cross_attention_dim=None, temporal_num_attention_heads=8, - temporal_attention_bias=False, - temporal_activation_fn="geglu", + temporal_cross_attention_dim=None, temporal_max_seq_length=32, ): if down_block_type == "DownBlock3D": @@ -96,11 +93,8 @@ def get_down_block( resnet_groups=resnet_groups, downsample_padding=downsample_padding, resnet_time_scale_shift=resnet_time_scale_shift, - temporal_norm_num_groups=temporal_norm_num_groups, - temporal_cross_attention_dim=temporal_cross_attention_dim, temporal_num_attention_heads=temporal_num_attention_heads, - temporal_attention_bias=temporal_attention_bias, - temporal_activation_fn=temporal_activation_fn, + temporal_cross_attention_dim=temporal_cross_attention_dim, temporal_max_seq_length=temporal_max_seq_length, ) elif down_block_type == "CrossAttnDownBlockMotion": @@ -123,11 +117,8 @@ def get_down_block( only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, - temporal_norm_num_groups=temporal_norm_num_groups, - temporal_cross_attention_dim=temporal_cross_attention_dim, temporal_num_attention_heads=temporal_num_attention_heads, - temporal_attention_bias=temporal_attention_bias, - temporal_activation_fn=temporal_activation_fn, + temporal_cross_attention_dim=temporal_cross_attention_dim, temporal_max_seq_length=temporal_max_seq_length, ) @@ -153,11 +144,8 @@ def get_up_block( only_cross_attention=False, upcast_attention=False, resnet_time_scale_shift="default", - temporal_norm_num_groups=32, - temporal_cross_attention_dim=None, temporal_num_attention_heads=8, - temporal_attention_bias=False, - temporal_activation_fn="geglu", + temporal_cross_attention_dim=None, temporal_max_seq_length=32, ): if up_block_type == "UpBlock3D": @@ -209,11 +197,8 @@ def get_up_block( resnet_groups=resnet_groups, resnet_time_scale_shift=resnet_time_scale_shift, resolution_idx=resolution_idx, - temporal_norm_num_groups=temporal_norm_num_groups, - temporal_cross_attention_dim=temporal_cross_attention_dim, temporal_num_attention_heads=temporal_num_attention_heads, - temporal_attention_bias=temporal_attention_bias, - temporal_activation_fn=temporal_activation_fn, + temporal_cross_attention_dim=temporal_cross_attention_dim, temporal_max_seq_length=temporal_max_seq_length, ) elif up_block_type == "CrossAttnUpBlockMotion": @@ -237,11 +222,8 @@ def get_up_block( upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, resolution_idx=resolution_idx, - temporal_norm_num_groups=temporal_norm_num_groups, - temporal_cross_attention_dim=temporal_cross_attention_dim, temporal_num_attention_heads=temporal_num_attention_heads, - temporal_attention_bias=temporal_attention_bias, - temporal_activation_fn=temporal_activation_fn, + temporal_cross_attention_dim=temporal_cross_attention_dim, temporal_max_seq_length=temporal_max_seq_length, ) raise ValueError(f"{up_block_type} does not exist.") @@ -853,11 +835,8 @@ def __init__( output_scale_factor=1.0, add_downsample=True, downsample_padding=1, - temporal_norm_num_groups=32, + temporal_num_attention_heads=1, temporal_cross_attention_dim=None, - temporal_num_attention_heads=8, - temporal_attention_bias=False, - temporal_activation_fn="geglu", temporal_max_seq_length=32, ): super().__init__() @@ -884,14 +863,13 @@ def __init__( TransformerTemporalModel( num_attention_heads=temporal_num_attention_heads, in_channels=out_channels, - norm_num_groups=temporal_norm_num_groups, + norm_num_groups=resnet_groups, cross_attention_dim=temporal_cross_attention_dim, - attention_bias=temporal_attention_bias, - activation_fn=temporal_activation_fn, - use_positional_embedding=True, - max_seq_length=temporal_max_seq_length, + attention_bias=False, + activation_fn="geglu", + positional_embeddings="sinusoidal", + num_positional_embeddings=temporal_max_seq_length, attention_head_dim=out_channels // temporal_num_attention_heads, - apply_framewise_group_norm=True, ) ) @@ -975,11 +953,8 @@ def __init__( only_cross_attention=False, upcast_attention=False, attention_type="default", - temporal_norm_num_groups=32, temporal_cross_attention_dim=None, temporal_num_attention_heads=8, - temporal_attention_bias=False, - temporal_activation_fn="geglu", temporal_max_seq_length=32, ): super().__init__() @@ -1038,14 +1013,13 @@ def __init__( TransformerTemporalModel( num_attention_heads=temporal_num_attention_heads, in_channels=out_channels, - norm_num_groups=temporal_norm_num_groups, + norm_num_groups=resnet_groups, cross_attention_dim=temporal_cross_attention_dim, - attention_bias=temporal_attention_bias, - activation_fn=temporal_activation_fn, - use_positional_embedding=True, - max_seq_length=temporal_max_seq_length, + attention_bias=False, + activation_fn="geglu", + positional_embeddings="sinusoidal", + num_positional_embeddings=temporal_max_seq_length, attention_head_dim=out_channels // temporal_num_attention_heads, - apply_framewise_group_norm=True, ) ) @@ -1166,11 +1140,8 @@ def __init__( only_cross_attention=False, upcast_attention=False, attention_type="default", - temporal_norm_num_groups=32, temporal_cross_attention_dim=None, temporal_num_attention_heads=8, - temporal_attention_bias=False, - temporal_activation_fn="geglu", temporal_max_seq_length=32, ): super().__init__() @@ -1230,14 +1201,13 @@ def __init__( TransformerTemporalModel( num_attention_heads=temporal_num_attention_heads, in_channels=out_channels, - norm_num_groups=temporal_norm_num_groups, + norm_num_groups=resnet_groups, cross_attention_dim=temporal_cross_attention_dim, - attention_bias=temporal_attention_bias, - activation_fn=temporal_activation_fn, - use_positional_embedding=True, - max_seq_length=temporal_max_seq_length, + attention_bias=False, + activation_fn="geglu", + positional_embeddings="sinusoidal", + num_positional_embeddings=temporal_max_seq_length, attention_head_dim=out_channels // temporal_num_attention_heads, - apply_framewise_group_norm=True, ) ) @@ -1363,8 +1333,6 @@ def __init__( temporal_norm_num_groups=32, temporal_cross_attention_dim=None, temporal_num_attention_heads=8, - temporal_attention_bias=False, - temporal_activation_fn="geglu", temporal_max_seq_length=32, ): super().__init__() @@ -1396,12 +1364,11 @@ def __init__( in_channels=out_channels, norm_num_groups=temporal_norm_num_groups, cross_attention_dim=temporal_cross_attention_dim, - attention_bias=temporal_attention_bias, - activation_fn=temporal_activation_fn, - use_positional_embedding=True, - max_seq_length=temporal_max_seq_length, + attention_bias=False, + activation_fn="geglu", + positional_embeddings="sinusoidal", + num_positional_embeddings=temporal_max_seq_length, attention_head_dim=out_channels // temporal_num_attention_heads, - apply_framewise_group_norm=True, ) ) @@ -1500,11 +1467,8 @@ def __init__( use_linear_projection=False, upcast_attention=False, attention_type="default", - temporal_norm_num_groups=32, + temporal_num_attention_heads=1, temporal_cross_attention_dim=None, - temporal_num_attention_heads=8, - temporal_attention_bias=False, - temporal_activation_fn="geglu", temporal_max_seq_length=32, ): super().__init__() @@ -1576,13 +1540,12 @@ def __init__( num_attention_heads=temporal_num_attention_heads, attention_head_dim=in_channels // temporal_num_attention_heads, in_channels=in_channels, - norm_num_groups=temporal_norm_num_groups, + norm_num_groups=resnet_groups, cross_attention_dim=temporal_cross_attention_dim, - attention_bias=temporal_attention_bias, - use_positional_embedding=True, - activation_fn=temporal_activation_fn, - max_seq_length=temporal_max_seq_length, - apply_framewise_group_norm=True, + attention_bias=False, + positional_embeddings="sinusoidal", + num_positional_embeddings=temporal_max_seq_length, + activation_fn="geglu", ) ) diff --git a/src/diffusers/models/unet_3d_condition.py b/src/diffusers/models/unet_3d_condition.py index 2ab1d4060e17..14d8589d55e3 100644 --- a/src/diffusers/models/unet_3d_condition.py +++ b/src/diffusers/models/unet_3d_condition.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import torch import torch.nn as nn @@ -287,9 +287,6 @@ def attn_processors(self) -> Dict[str, AttentionProcessor]: processors = {} def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): - if hasattr(module, "get_processor"): - processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) - for sub_name, child in module.named_children(): fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) @@ -300,72 +297,6 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: return processors - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice - def set_attention_slice(self, slice_size): - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module splits the input tensor in slices to compute attention in - several steps. This is useful for saving some memory in exchange for a small decrease in speed. - - Args: - slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): - When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If - `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is - provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` - must be a multiple of `slice_size`. - """ - sliceable_head_dims = [] - - def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): - if hasattr(module, "set_attention_slice"): - sliceable_head_dims.append(module.sliceable_head_dim) - - for child in module.children(): - fn_recursive_retrieve_sliceable_dims(child) - - # retrieve number of attention layers - for module in self.children(): - fn_recursive_retrieve_sliceable_dims(module) - - num_sliceable_layers = len(sliceable_head_dims) - - if slice_size == "auto": - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = [dim // 2 for dim in sliceable_head_dims] - elif slice_size == "max": - # make smallest slice possible - slice_size = num_sliceable_layers * [1] - - slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size - - if len(slice_size) != len(sliceable_head_dims): - raise ValueError( - f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" - f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." - ) - - for i in range(len(slice_size)): - size = slice_size[i] - dim = sliceable_head_dims[i] - if size is not None and size > dim: - raise ValueError(f"size {size} has to be smaller or equal to {dim}.") - - # Recursively walk through all the children. - # Any children which exposes the set_attention_slice method - # gets the message - def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): - if hasattr(module, "set_attention_slice"): - module.set_attention_slice(slice_size.pop()) - - for child in module.children(): - fn_recursive_set_attention_slice(child, slice_size) - - reversed_slice_size = list(reversed(slice_size)) - for module in self.children(): - fn_recursive_set_attention_slice(module, reversed_slice_size) - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor def set_attn_processor( self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False @@ -502,7 +433,6 @@ def forward( sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, - class_labels: Optional[torch.Tensor] = None, timestep_cond: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, diff --git a/src/diffusers/models/unet_motion_model.py b/src/diffusers/models/unet_motion_model.py index 3d9431a4f6fc..c78240de1c84 100644 --- a/src/diffusers/models/unet_motion_model.py +++ b/src/diffusers/models/unet_motion_model.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import torch import torch.nn as nn @@ -30,6 +30,7 @@ from .embeddings import TimestepEmbedding, Timesteps from .modeling_utils import ModelMixin from .transformer_temporal import TransformerTemporalModel +from .unet_2d_blocks import UNetMidBlock2DCrossAttn from .unet_2d_condition import UNet2DConditionModel from .unet_3d_blocks import ( CrossAttnDownBlockMotion, @@ -71,8 +72,8 @@ def __init__( attention_bias=attention_bias, num_attention_heads=num_attention_heads, attention_head_dim=in_channels // num_attention_heads, - use_positional_embedding=True, - max_seq_length=max_seq_length, + positional_embeddings="sinusoidal", + num_positional_embeddings=max_seq_length, ) ) @@ -90,6 +91,7 @@ def __init__( motion_activation_fn="geglu", motion_norm_num_groups=32, motion_max_seq_length=32, + use_motion_mid_block=True, ): """Container to store AnimateDiff Motion Modules @@ -125,16 +127,19 @@ def __init__( ) ) - self.mid_block = MotionModules( - in_channels=block_out_channels[-1], - norm_num_groups=motion_norm_num_groups, - cross_attention_dim=motion_cross_attention_dim, - activation_fn=motion_activation_fn, - attention_bias=motion_attention_bias, - num_attention_heads=motion_num_attention_heads, - layers_per_block=motion_mid_block_layers_per_block, - max_seq_length=motion_max_seq_length, - ) + if use_motion_mid_block: + self.mid_block = MotionModules( + in_channels=block_out_channels[-1], + norm_num_groups=motion_norm_num_groups, + cross_attention_dim=motion_cross_attention_dim, + activation_fn=motion_activation_fn, + attention_bias=motion_attention_bias, + num_attention_heads=motion_num_attention_heads, + layers_per_block=motion_mid_block_layers_per_block, + max_seq_length=motion_max_seq_length, + ) + else: + self.mid_block = None reversed_block_out_channels = list(reversed(block_out_channels)) output_channel = reversed_block_out_channels[0] @@ -199,13 +204,10 @@ def __init__( attention_head_dim: Union[int, Tuple[int]] = 8, use_linear_projection: bool = False, num_attention_heads: Optional[Union[int, Tuple[int]]] = None, - motion_norm_num_groups: Optional[int] = 32, motion_cross_attention_dim: Optional[int] = None, - motion_num_attention_heads: Optional[int] = 8, - motion_attention_bias: bool = False, - motion_activation_fn: str = "geglu", motion_max_seq_length: Optional[int] = 32, - motion_apply_framewise_groupnorm: bool = True, + motion_num_attention_heads: int = 8, + use_motion_mid_block: int = True, ): super().__init__() @@ -288,33 +290,41 @@ def __init__( downsample_padding=downsample_padding, use_linear_projection=use_linear_projection, dual_cross_attention=False, - temporal_norm_num_groups=motion_norm_num_groups, - temporal_cross_attention_dim=motion_cross_attention_dim, temporal_num_attention_heads=motion_num_attention_heads, - temporal_attention_bias=motion_attention_bias, - temporal_activation_fn=motion_activation_fn, + temporal_cross_attention_dim=motion_cross_attention_dim, temporal_max_seq_length=motion_max_seq_length, ) self.down_blocks.append(down_block) # mid - self.mid_block = UNetMidBlockCrossAttnMotion( - in_channels=block_out_channels[-1], - temb_channels=time_embed_dim, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - output_scale_factor=mid_block_scale_factor, - cross_attention_dim=cross_attention_dim, - num_attention_heads=num_attention_heads[-1], - resnet_groups=norm_num_groups, - dual_cross_attention=False, - temporal_norm_num_groups=motion_norm_num_groups, - temporal_cross_attention_dim=motion_cross_attention_dim, - temporal_num_attention_heads=motion_num_attention_heads, - temporal_attention_bias=motion_attention_bias, - temporal_activation_fn=motion_activation_fn, - temporal_max_seq_length=motion_max_seq_length, - ) + if use_motion_mid_block: + self.mid_block = UNetMidBlockCrossAttnMotion( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=False, + temporal_cross_attention_dim=motion_cross_attention_dim, + temporal_num_attention_heads=motion_num_attention_heads, + temporal_max_seq_length=motion_max_seq_length, + ) + + else: + self.mid_block = UNetMidBlock2DCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=False, + ) # count how many layers upsample the images self.num_upsamplers = 0 @@ -354,11 +364,8 @@ def __init__( dual_cross_attention=False, resolution_idx=i, use_linear_projection=use_linear_projection, - temporal_norm_num_groups=motion_norm_num_groups, - temporal_cross_attention_dim=motion_cross_attention_dim, temporal_num_attention_heads=motion_num_attention_heads, - temporal_attention_bias=motion_attention_bias, - temporal_activation_fn=motion_activation_fn, + temporal_cross_attention_dim=motion_cross_attention_dim, temporal_max_seq_length=motion_max_seq_length, ) self.up_blocks.append(up_block) @@ -411,14 +418,10 @@ def from_unet2d( config["up_block_types"] = up_blocks if has_motion_adapter: - config["motion_norm_num_groups"] = motion_adapter.config["motion_norm_num_groups"] config["motion_cross_attention_dim"] = motion_adapter.config["motion_cross_attention_dim"] config["motion_num_attention_heads"] = motion_adapter.config["motion_num_attention_heads"] - config["motion_attention_bias"] = motion_adapter.config["motion_attention_bias"] - config["motion_activation_fn"] = motion_adapter.config["motion_activation_fn"] config["motion_max_seq_length"] = motion_adapter.config["motion_max_seq_length"] - config["motion_layers_per_block"] = motion_adapter.config["motion_layers_per_block"] - config["motion_mid_block_layers_per_block"] = motion_adapter.config["motion_mid_block_layers_per_block"] + config["use_motion_mid_block"] = motion_adapter.config["use_motion_mid_block"] model = cls.from_config(config) @@ -455,6 +458,9 @@ def from_unet2d( if has_motion_adapter: model.load_motion_modules(motion_adapter) + # ensure that the Motion UNet is the same dtype as the UNet2DConditionModel + model.to(unet.dtype) + return model def freeze_unet2d_params(self): @@ -488,7 +494,9 @@ def load_motion_modules(self, motion_adapter: Optional[MotionAdapter]): for i, up_block in enumerate(motion_adapter.up_blocks): self.up_blocks[i].motion_modules.load_state_dict(up_block.motion_modules.state_dict()) - self.mid_block.motion_modules.load_state_dict(motion_adapter.mid_block.motion_modules.state_dict()) + # to support older motion modules that don't have a mid_block + if hasattr(self.mid_block, "motion_modules"): + self.mid_block.motion_modules.load_state_dict(motion_adapter.mid_block.motion_modules.state_dict()) def save_motion_modules( self, @@ -509,13 +517,14 @@ def save_motion_modules( adapter = MotionAdapter( block_out_channels=self.config["block_out_channels"], - motion_norm_num_groups=self.config["motion_norm_num_groups"], - motion_cross_attention_dim=self.config["motion_cross_attention_dim"], + motion_layers_per_block=self.config["layers_per_block"], + motion_norm_num_groups=self.config["norm_num_groups"], motion_num_attention_heads=self.config["motion_num_attention_heads"], - motion_attention_bias=self.config["motion_attention_bias"], - motion_activation_fn=self.config["motion_activation_fn"], + motion_cross_attention_dim=self.config["motion_cross_attention_dim"], + motion_attention_bias=False, + motion_activation_fn="geglu", motion_max_seq_length=self.config["motion_max_seq_length"], - motion_layers_per_block=self.config["layers_per_block"], + use_motion_mid_block=self.config["use_motion_mid_block"], ) adapter.load_state_dict(motion_state_dict) adapter.save_pretrained( @@ -552,72 +561,6 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: return processors - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice - def set_attention_slice(self, slice_size): - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module splits the input tensor in slices to compute attention in - several steps. This is useful for saving some memory in exchange for a small decrease in speed. - - Args: - slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): - When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If - `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is - provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` - must be a multiple of `slice_size`. - """ - sliceable_head_dims = [] - - def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): - if hasattr(module, "set_attention_slice"): - sliceable_head_dims.append(module.sliceable_head_dim) - - for child in module.children(): - fn_recursive_retrieve_sliceable_dims(child) - - # retrieve number of attention layers - for module in self.children(): - fn_recursive_retrieve_sliceable_dims(module) - - num_sliceable_layers = len(sliceable_head_dims) - - if slice_size == "auto": - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = [dim // 2 for dim in sliceable_head_dims] - elif slice_size == "max": - # make smallest slice possible - slice_size = num_sliceable_layers * [1] - - slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size - - if len(slice_size) != len(sliceable_head_dims): - raise ValueError( - f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" - f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." - ) - - for i in range(len(slice_size)): - size = slice_size[i] - dim = sliceable_head_dims[i] - if size is not None and size > dim: - raise ValueError(f"size {size} has to be smaller or equal to {dim}.") - - # Recursively walk through all the children. - # Any children which exposes the set_attention_slice method - # gets the message - def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): - if hasattr(module, "set_attention_slice"): - module.set_attention_slice(slice_size.pop()) - - for child in module.children(): - fn_recursive_set_attention_slice(child, slice_size) - - reversed_slice_size = list(reversed(slice_size)) - for module in self.children(): - fn_recursive_set_attention_slice(module, reversed_slice_size) - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor def set_attn_processor( self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index b722952b903d..9e69cd945790 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -312,6 +312,7 @@ def encode_prompt( def decode_latents(self, latents): batch_size, channels, num_frames, height, width = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) latents = latents / self.vae.config.scaling_factor @@ -323,8 +324,8 @@ def decode_latents(self, latents): output_frames.append(frame) output = torch.cat(output_frames) - video = output[None, :].permute(0, 2, 1, 3, 4) - video = video + _, channels, height, width = output.shape + video = output[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4) return video @@ -468,7 +469,6 @@ def __call__( callback_steps: Optional[int] = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, clip_skip: Optional[int] = None, - **kwargs, ): r""" The call function to the pipeline for generation. @@ -492,8 +492,6 @@ def __call__( negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide what to not include in image generation. If not defined, you need to pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. @@ -540,7 +538,7 @@ def __call__( height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor - num_images_per_prompt = 1 + num_videos_per_prompt = 1 # 1. Check inputs. Raise error if not correct self.check_inputs( @@ -589,7 +587,7 @@ def __call__( # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( - batch_size * num_images_per_prompt, + batch_size * num_videos_per_prompt, num_channels_latents, num_frames, height, diff --git a/tests/models/test_models_unet_motion.py b/tests/models/test_models_unet_motion.py index 4bce4f3968d4..ff1533975613 100644 --- a/tests/models/test_models_unet_motion.py +++ b/tests/models/test_models_unet_motion.py @@ -159,30 +159,6 @@ def test_xformers_enable_works(self): == "XFormersAttnProcessor" ), "xformers is not enabled" - def test_model_attention_slicing(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - init_dict["attention_head_dim"] = 8 - - model = self.model_class(**init_dict) - model.to(torch_device) - model.eval() - - model.set_attention_slice("auto") - with torch.no_grad(): - output = model(**inputs_dict) - assert output is not None - - model.set_attention_slice("max") - with torch.no_grad(): - output = model(**inputs_dict) - assert output is not None - - model.set_attention_slice(2) - with torch.no_grad(): - output = model(**inputs_dict) - assert output is not None - def test_gradient_checkpointing_is_applied(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() model_class_copy = copy.copy(self.model_class) @@ -316,7 +292,6 @@ def test_forward_with_norm_groups(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() init_dict["norm_num_groups"] = 16 - init_dict["motion_norm_num_groups"] = 16 init_dict["block_out_channels"] = (16, 32) model = self.model_class(**init_dict) diff --git a/tests/pipelines/animatediff/test_animatediff.py b/tests/pipelines/animatediff/test_animatediff.py index 6a728bef375d..0d596fe4d4b9 100644 --- a/tests/pipelines/animatediff/test_animatediff.py +++ b/tests/pipelines/animatediff/test_animatediff.py @@ -1,8 +1,9 @@ import gc import unittest - -import torch +import diffusers +from diffusers.utils import logging import numpy as np +import torch from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer from diffusers import ( @@ -13,18 +14,33 @@ UNet2DConditionModel, UNetMotionModel, ) +from diffusers.utils.testing_utils import numpy_cosine_similarity_distance, require_torch_gpu, slow, torch_device from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin -from diffusers.utils.testing_utils import torch_device, slow, require_torch_gpu, numpy_cosine_similarity_distance +from ..test_pipelines_common import PipelineTesterMixin + + +def to_np(tensor): + if isinstance(tensor, torch.Tensor): + tensor = tensor.detach().cpu().numpy() + return tensor -class AnimateDiffPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): + +class AnimateDiffPipelineFastTests(PipelineTesterMixin, unittest.TestCase): pipeline_class = AnimateDiffPipeline params = TEXT_TO_IMAGE_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS - image_params = TEXT_TO_IMAGE_IMAGE_PARAMS - image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback", + "callback_steps", + ] + ) def get_dummy_components(self): torch.manual_seed(0) @@ -42,9 +58,8 @@ def get_dummy_components(self): scheduler = DDIMScheduler( beta_start=0.00085, beta_end=0.012, - beta_schedule="scaled_linear", + beta_schedule="linear", clip_sample=False, - set_alpha_to_one=False, ) torch.manual_seed(0) vae = AutoencoderKL( @@ -69,7 +84,12 @@ def get_dummy_components(self): ) text_encoder = CLIPTextModel(text_encoder_config) tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") - motion_adapter = MotionAdapter(block_out_channels=(4, 8), motion_layers_per_block=2, motion_norm_num_groups=2) + motion_adapter = MotionAdapter( + block_out_channels=(32, 64), + motion_layers_per_block=2, + motion_norm_num_groups=2, + motion_num_attention_heads=4, + ) components = { "unet": unet, @@ -78,8 +98,6 @@ def get_dummy_components(self): "motion_adapter": motion_adapter, "text_encoder": text_encoder, "tokenizer": tokenizer, - "safety_checker": None, - "feature_extractor": None, } return components @@ -94,16 +112,113 @@ def get_dummy_inputs(self, device, seed=0): "generator": generator, "num_inference_steps": 2, "guidance_scale": 7.5, - "output_type": "np", + "output_type": "pt", } return inputs def test_motion_unet_loading(self): components = self.get_dummy_components() - pipe = AnimateDiffPipeline.from_pretrained(**components) + pipe = AnimateDiffPipeline(**components) assert isinstance(pipe.unet, UNetMotionModel) + @unittest.skip("Attention slicing is not enabled in this pipeline") + def test_attention_slicing_forward_pass(self): + pass + + def test_inference_batch_single_identical( + self, + batch_size=2, + expected_max_diff=1e-4, + additional_params_copy_to_batched_inputs=["num_inference_steps"], + ): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for components in pipe.components.values(): + if hasattr(components, "set_default_attn_processor"): + components.set_default_attn_processor() + + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + inputs = self.get_dummy_inputs(torch_device) + # Reset generator in case it is has been used in self.get_dummy_inputs + inputs["generator"] = self.get_generator(0) + + logger = logging.get_logger(pipe.__module__) + logger.setLevel(level=diffusers.logging.FATAL) + + # batchify inputs + batched_inputs = {} + batched_inputs.update(inputs) + + for name in self.batch_params: + if name not in inputs: + continue + + value = inputs[name] + if name == "prompt": + len_prompt = len(value) + batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)] + batched_inputs[name][-1] = 100 * "very long" + + else: + batched_inputs[name] = batch_size * [value] + + if "generator" in inputs: + batched_inputs["generator"] = [self.get_generator(i) for i in range(batch_size)] + + if "batch_size" in inputs: + batched_inputs["batch_size"] = batch_size + + for arg in additional_params_copy_to_batched_inputs: + batched_inputs[arg] = inputs[arg] + + output = pipe(**inputs) + output_batch = pipe(**batched_inputs) + + assert output_batch[0].shape[0] == batch_size + + max_diff = np.abs(to_np(output_batch[0][0]) - to_np(output[0][0])).max() + assert max_diff < expected_max_diff + + @unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices") + def test_to_device(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + + pipe.to("cpu") + # pipeline creates a new motion UNet under the hood. So we need to check the device from pipe.components + model_devices = [ + component.device.type for component in pipe.components.values() if hasattr(component, "device") + ] + self.assertTrue(all(device == "cpu" for device in model_devices)) + + output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0] + self.assertTrue(np.isnan(output_cpu).sum() == 0) + + pipe.to("cuda") + model_devices = [ + component.device.type for component in pipe.components.values() if hasattr(component, "device") + ] + self.assertTrue(all(device == "cuda" for device in model_devices)) + + output_cuda = pipe(**self.get_dummy_inputs("cuda"))[0] + self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0) + + def test_to_dtype(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + + # pipeline creates a new motion UNet under the hood. So we need to check the dtype from pipe.components + model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")] + self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes)) + + pipe.to(torch_dtype=torch.float16) + model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")] + self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes)) + @slow @require_torch_gpu @@ -115,8 +230,7 @@ def tearDown(self): torch.cuda.empty_cache() def test_animatediff(self): - # make sure here that pndm scheduler skips prk - adapter = MotionAdapter.from_pretrained("diffusers/motion-adapter-test") + adapter = MotionAdapter.from_pretrained("dn6/animatediff-test") pipe = AnimateDiffPipeline.from_pretrained("frankjoshua/toonyou_beta6", motion_adapter=adapter) pipe = pipe.to(torch_device) pipe.scheduler = DDIMScheduler( From 2b78f1edb69efd3c01653cdc5c62ba7787570fb7 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 1 Nov 2023 08:33:09 +0000 Subject: [PATCH 36/55] make style --- tests/pipelines/animatediff/test_animatediff.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/pipelines/animatediff/test_animatediff.py b/tests/pipelines/animatediff/test_animatediff.py index 0d596fe4d4b9..50af170a47ce 100644 --- a/tests/pipelines/animatediff/test_animatediff.py +++ b/tests/pipelines/animatediff/test_animatediff.py @@ -1,11 +1,11 @@ import gc import unittest -import diffusers -from diffusers.utils import logging + import numpy as np import torch from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer +import diffusers from diffusers import ( AnimateDiffPipeline, AutoencoderKL, @@ -14,9 +14,10 @@ UNet2DConditionModel, UNetMotionModel, ) +from diffusers.utils import logging from diffusers.utils.testing_utils import numpy_cosine_similarity_distance, require_torch_gpu, slow, torch_device -from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS from ..test_pipelines_common import PipelineTesterMixin From 3f5d8dec4b7a714d5d41f24720371e14ccd3a512 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 1 Nov 2023 10:01:55 +0000 Subject: [PATCH 37/55] update --- src/diffusers/models/transformer_temporal.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/transformer_temporal.py b/src/diffusers/models/transformer_temporal.py index c9c0b57ef113..90de1a92914b 100644 --- a/src/diffusers/models/transformer_temporal.py +++ b/src/diffusers/models/transformer_temporal.py @@ -48,13 +48,15 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin): num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. - sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). - This is fixed during training since it is used to learn a number of position embeddings. activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. attention_bias (`bool`, *optional*): Configure if the `TransformerBlock` attention should contain a bias parameter. double_self_attention (`bool`, *optional*): Configure if each `TransformerBlock` should contain two self-attention layers. + positional_embeddings: (`str`, *optional*): + The type of positional embeddings to apply to the sequence input before passing use. + num_positional_embeddings: (`int`, *optional*): + The maximum length of the sequence over which to apply positional embeddings. """ @register_to_config From d9393799062ac7376832d19e11a1410a21b64806 Mon Sep 17 00:00:00 2001 From: DN6 Date: Wed, 1 Nov 2023 16:35:08 +0530 Subject: [PATCH 38/55] fix embeddings --- src/diffusers/models/embeddings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 6b4c749202e8..f86daec4bf0c 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -263,7 +263,7 @@ class SinusoidalPositionalEmbedding(nn.Module): def __init__(self, embed_dim: int, max_seq_length: int = 32): super().__init__() - pe = get_1d_sincos_pos_embed_from_grid(embed_dim, torch.arange(32)) + pe = get_1d_sincos_pos_embed_from_grid(embed_dim, torch.arange(max_seq_length)) pe = torch.tensor(pe, dtype=torch.float32).unsqueeze(0) self.register_buffer("pe", pe) From 9e6a146ad16e4f6d285c7612aa2decc39bd11773 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 1 Nov 2023 14:47:25 +0000 Subject: [PATCH 39/55] update --- src/diffusers/models/embeddings.py | 7 +++++-- src/diffusers/models/unet_motion_model.py | 19 +++++++++++-------- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index f86daec4bf0c..c5fd450e652e 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -263,8 +263,11 @@ class SinusoidalPositionalEmbedding(nn.Module): def __init__(self, embed_dim: int, max_seq_length: int = 32): super().__init__() - pe = get_1d_sincos_pos_embed_from_grid(embed_dim, torch.arange(max_seq_length)) - pe = torch.tensor(pe, dtype=torch.float32).unsqueeze(0) + position = torch.arange(max_seq_length).unsqueeze(1) + div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim)) + pe = torch.zeros(1, max_seq_length, embed_dim) + pe[0, :, 0::2] = torch.sin(position * div_term) + pe[0, :, 1::2] = torch.cos(position * div_term) self.register_buffer("pe", pe) def forward(self, x): diff --git a/src/diffusers/models/unet_motion_model.py b/src/diffusers/models/unet_motion_model.py index c78240de1c84..fb5c2511daf3 100644 --- a/src/diffusers/models/unet_motion_model.py +++ b/src/diffusers/models/unet_motion_model.py @@ -823,14 +823,17 @@ def forward( # 4. mid if self.mid_block is not None: - sample = self.mid_block( - sample, - emb, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - num_frames=num_frames, - cross_attention_kwargs=cross_attention_kwargs, - ) + inputs = { + "hidden_states": sample, + "temb": emb, + "encoder_hidden_states": encoder_hidden_states, + "attention_mask": attention_mask, + "cross_attention_kwargs": cross_attention_kwargs, + } + if hasattr(self.mid_block, "motion_modules"): + inputs.update({"num_frames": num_frames}) + + sample = self.mid_block(**inputs) if mid_block_additional_residual is not None: sample = sample + mid_block_additional_residual From dc6eb04b4ec4aceb7f59f60b63b2e849f49215d9 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 1 Nov 2023 15:30:34 +0000 Subject: [PATCH 40/55] merge upstream --- src/diffusers/models/attention.py | 2 +- src/diffusers/pipelines/animatediff/pipeline_animatediff.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 3e28f5c550ad..cb2f24a52786 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -20,7 +20,7 @@ from ..utils.torch_utils import maybe_allow_in_graph from .activations import GEGLU, GELU, ApproximateGELU from .attention_processor import Attention -from .embeddings import CombinedTimestepLabelEmbeddings, SinusoidalPositionalEmbedding +from .embeddings import SinusoidalPositionalEmbedding from .lora import LoRACompatibleLinear from .normalization import AdaLayerNorm, AdaLayerNormZero diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index 9e69cd945790..381384fa14bb 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -33,7 +33,7 @@ LMSDiscreteScheduler, PNDMScheduler, ) -from ...utils import BaseOutput, logging, scale_lora_layers, unscale_lora_layers +from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline @@ -176,7 +176,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - if not self.use_peft_backend: + if not USE_PEFT_BACKEND: adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) else: scale_lora_layers(self.text_encoder, lora_scale) @@ -304,7 +304,7 @@ def encode_prompt( negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) - if isinstance(self, LoraLoaderMixin) and self.use_peft_backend: + if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: # Retrieve the original scale by scaling back the LoRA layers unscale_lora_layers(self.text_encoder) From 5f003e5f15b251bb651b60856b4a578bdf9ffff1 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 1 Nov 2023 15:38:02 +0000 Subject: [PATCH 41/55] max fix copies --- src/diffusers/models/unet_3d_condition.py | 3 ++ src/diffusers/models/unet_motion_model.py | 36 ++++++++++--------- src/diffusers/utils/dummy_pt_objects.py | 30 ++++++++++++++++ .../dummy_torch_and_transformers_objects.py | 15 ++++++++ 4 files changed, 68 insertions(+), 16 deletions(-) diff --git a/src/diffusers/models/unet_3d_condition.py b/src/diffusers/models/unet_3d_condition.py index dcce165152f7..b07d2be35aae 100644 --- a/src/diffusers/models/unet_3d_condition.py +++ b/src/diffusers/models/unet_3d_condition.py @@ -287,6 +287,9 @@ def attn_processors(self) -> Dict[str, AttentionProcessor]: processors = {} def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) + for sub_name, child in module.named_children(): fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) diff --git a/src/diffusers/models/unet_motion_model.py b/src/diffusers/models/unet_motion_model.py index fb5c2511daf3..d02088a619e4 100644 --- a/src/diffusers/models/unet_motion_model.py +++ b/src/diffusers/models/unet_motion_model.py @@ -688,7 +688,7 @@ def disable_freeu(self): freeu_keys = {"s1", "s2", "b1", "b2"} for i, upsample_block in enumerate(self.up_blocks): for k in freeu_keys: - if hasattr(upsample_block, k) or getattr(upsample_block, k) is not None: + if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None: setattr(upsample_block, k, None) # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.forward @@ -709,7 +709,7 @@ def forward( Args: sample (`torch.FloatTensor`): - The noisy input tensor with the following shape `(batch, channel, num_frames, height, width`. + The noisy input tensor with the following shape `(batch, num_frames, channel, height, width`. timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. encoder_hidden_states (`torch.FloatTensor`): The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. @@ -775,7 +775,7 @@ def forward( timesteps = timesteps[None].to(sample.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - batch, channels, num_frames, height, width = sample.shape + num_frames = sample.shape[2] timesteps = timesteps.expand(sample.shape[0]) t_emb = self.time_proj(timesteps) @@ -790,9 +790,16 @@ def forward( encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0) # 2. pre-process - sample = sample.permute(0, 2, 1, 3, 4).reshape((batch * num_frames, channels, height, width)) + sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:]) sample = self.conv_in(sample) + sample = self.transformer_in( + sample, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + # 3. down down_block_res_samples = (sample,) for downsample_block in self.down_blocks: @@ -802,8 +809,8 @@ def forward( temb=emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, - cross_attention_kwargs=cross_attention_kwargs, num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames) @@ -823,17 +830,14 @@ def forward( # 4. mid if self.mid_block is not None: - inputs = { - "hidden_states": sample, - "temb": emb, - "encoder_hidden_states": encoder_hidden_states, - "attention_mask": attention_mask, - "cross_attention_kwargs": cross_attention_kwargs, - } - if hasattr(self.mid_block, "motion_modules"): - inputs.update({"num_frames": num_frames}) - - sample = self.mid_block(**inputs) + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) if mid_block_additional_residual is not None: sample = sample + mid_block_additional_residual diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 890f836c73c6..d6d74a89cafb 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -77,6 +77,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class MotionAdapter(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class MultiAdapter(metaclass=DummyObject): _backends = ["torch"] @@ -212,6 +227,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class UNetMotionModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class VQModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 3b5e3ad4e07d..2fd80f321e6b 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -32,6 +32,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class AnimateDiffPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class AudioLDM2Pipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From 6f6f8aa258582aa8408d295e8141d611cdcd62dd Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 1 Nov 2023 16:06:46 +0000 Subject: [PATCH 42/55] fix bug --- src/diffusers/models/unet_motion_model.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/diffusers/models/unet_motion_model.py b/src/diffusers/models/unet_motion_model.py index d02088a619e4..958d067e612b 100644 --- a/src/diffusers/models/unet_motion_model.py +++ b/src/diffusers/models/unet_motion_model.py @@ -691,7 +691,6 @@ def disable_freeu(self): if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None: setattr(upsample_block, k, None) - # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.forward def forward( self, sample: torch.FloatTensor, @@ -793,13 +792,6 @@ def forward( sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:]) sample = self.conv_in(sample) - sample = self.transformer_in( - sample, - num_frames=num_frames, - cross_attention_kwargs=cross_attention_kwargs, - return_dict=False, - )[0] - # 3. down down_block_res_samples = (sample,) for downsample_block in self.down_blocks: From ec8bb6e1195f063127846c9a8d8631b74a3f8ce9 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 1 Nov 2023 16:09:22 +0000 Subject: [PATCH 43/55] fix mistake --- src/diffusers/models/unet_3d_condition.py | 68 ++++++++++++++++++++++- 1 file changed, 67 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/unet_3d_condition.py b/src/diffusers/models/unet_3d_condition.py index b07d2be35aae..8765bc482ecd 100644 --- a/src/diffusers/models/unet_3d_condition.py +++ b/src/diffusers/models/unet_3d_condition.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -275,6 +275,72 @@ def __init__( block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding ) + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_sliceable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_sliceable_dims(module) + + num_sliceable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_sliceable_layers * [1] + + slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + @property # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: From d41f71783b814223ace713915c71623348d59697 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 1 Nov 2023 16:56:37 +0000 Subject: [PATCH 44/55] add docs --- docs/source/en/_toctree.yml | 4 +- docs/source/en/api/models/unet-motion.md | 13 +++++ docs/source/en/api/pipelines/animatediff.md | 63 +++++++++++++++++++++ 3 files changed, 78 insertions(+), 2 deletions(-) create mode 100644 docs/source/en/api/models/unet-motion.md create mode 100644 docs/source/en/api/pipelines/animatediff.md diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index dbe920030cae..7f3c57c6e5bd 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -208,9 +208,9 @@ title: Overview - local: api/pipelines/alt_diffusion title: AltDiffusion - - local: api/pipelines/attend_and_excite - local: api/pipelienes/animatediff title: AnimateDiff + - local: api/pipelines/attend_and_excite title: Attend-and-Excite - local: api/pipelines/audio_diffusion title: Audio Diffusion @@ -396,5 +396,5 @@ title: Utilities - local: api/image_processor title: VAE Image Processor - title: Internal classes + title: Internal classes title: API diff --git a/docs/source/en/api/models/unet-motion.md b/docs/source/en/api/models/unet-motion.md new file mode 100644 index 000000000000..07d4df64c35f --- /dev/null +++ b/docs/source/en/api/models/unet-motion.md @@ -0,0 +1,13 @@ +# UNetMotionModel + +The [UNet](https://huggingface.co/papers/1505.04597) model was originally introduced by Ronneberger et al for biomedical image segmentation, but it is also commonly used in 🤗 Diffusers because it outputs images that are the same size as the input. It is one of the most important components of a diffusion system because it facilitates the actual diffusion process. There are several variants of the UNet model in 🤗 Diffusers, depending on it's number of dimensions and whether it is a conditional model or not. This is a 2D UNet model. + +The abstract from the paper is: + +*There is large consent that successful training of deep networks requires many thousand annotated training samples. In this paper, we present a network and training strategy that relies on the strong use of data augmentation to use the available annotated samples more efficiently. The architecture consists of a contracting path to capture context and a symmetric expanding path that enables precise localization. We show that such a network can be trained end-to-end from very few images and outperforms the prior best method (a sliding-window convolutional network) on the ISBI challenge for segmentation of neuronal structures in electron microscopic stacks. Using the same network trained on transmitted light microscopy images (phase contrast and DIC) we won the ISBI cell tracking challenge 2015 in these categories by a large margin. Moreover, the network is fast. Segmentation of a 512x512 image takes less than a second on a recent GPU. The full implementation (based on Caffe) and the trained networks are available at http://lmb.informatik.uni-freiburg.de/people/ronneber/u-net.* + +## UNetMotionModel +[[autodoc]] UNetMotionModel + +## UNet3DConditionOutput +[[autodoc]] models.unet_3d_condition.UNet3DConditionOutput diff --git a/docs/source/en/api/pipelines/animatediff.md b/docs/source/en/api/pipelines/animatediff.md new file mode 100644 index 000000000000..7793b74ff340 --- /dev/null +++ b/docs/source/en/api/pipelines/animatediff.md @@ -0,0 +1,63 @@ + + +# Text-to-Video Generation with AnimateDiff + +## Overview + +[AnimateDiff: Animate Your Personalized Text-to-Image Diffusion Models without Specific Tuning](https://arxiv.org/abs/2307.04725) by Yuwei Guo, Ceyuan Yang*, Anyi Rao, Yaohui Wang, Yu Qiao, Dahua Lin, Bo Dai + +The abstract of the paper is the following: + +With the advance of text-to-image models (e.g., Stable Diffusion) and corresponding personalization techniques such as DreamBooth and LoRA, everyone can manifest their imagination into high-quality images at an affordable cost. Subsequently, there is a great demand for image animation techniques to further combine generated static images with motion dynamics. In this report, we propose a practical framework to animate most of the existing personalized text-to-image models once and for all, saving efforts in model-specific tuning. At the core of the proposed framework is to insert a newly initialized motion modeling module into the frozen text-to-image model and train it on video clips to distill reasonable motion priors. Once trained, by simply injecting this motion modeling module, all personalized versions derived from the same base T2I readily become text-driven models that produce diverse and personalized animated images. We conduct our evaluation on several public representative personalized text-to-image models across anime pictures and realistic photographs, and demonstrate that our proposed framework helps these models generate temporally smooth animation clips while preserving the domain and diversity of their outputs. Code and pre-trained weights will be publicly available at this https URL . + +## Available Pipelines: + +| Pipeline | Tasks | Demo +|---|---|:---:| +| [AnimateDiffPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/animatediff/pipeline_animatediff.py) | *Text-to-Video Generation with AnimateDiff* | + +## Usage example + +AnimateDiff works with a MotionAdapter checkpoint and a Stable Diffusion model checkpoint. The MotionAdapter is a collection of Motion Modules that are responsible for adding coherent motion across image frames. These modules are applied after the Resnet and Attention blocks in Stable Diffusion UNet. + +In the following we give a simple example of how to use a *MotionAdapter* checkpoint with Diffusers for inference based on StableDiffusion-1.4/1.5. + +```python +import torch +from diffusers import MotionAdapter, AnimateDiffPipeline, DDIMScheduler +from diffusers.utils import export_to_gif + +# Load the motion adapter +adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2") +pipe = AnimateDiffPipeline.from_pretrained("frankjoshua/toonyou_beta6", motion_adapter=adapter) +pipe.scheduler = DDIMScheduler(beta_schedule="linear", steps_offset=1, clip_sample=False) +pipe.enable_model_cpu_offload() + +output = pipe( + prompt="masterpiece, best quality, 1boy, jacket, beard, walking, beanie, sunglasses, from below, looking up, fisheye, upper body, wasteland, sunset, solo focus, cloudy sky, backpack, hands in pockets", + negative_prompt="human, worst quality, low quality, letterboxed", + num_frames=16, + guidance_scale=7.5, + num_inference_steps=25, + generator = torch.Generator("cpu").manual_seed(42) +) +frames = output.frames[0] +export_to_gif(frames, "animation.gif") +``` + +## Available checkpoints + +Motion Adapter checkpoints can be found under [guoyww/animatediff](https://huggingface.co/guoyww/). + +These checkpoints will work with any model based on Stable Diffusion 1.4/1.5 + From 6d81f2aabec5325b3ec69fb0eae4ea9992aeefea Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Thu, 2 Nov 2023 09:04:20 +0000 Subject: [PATCH 45/55] update --- docs/source/en/api/pipelines/animatediff.md | 37 ++++++++++++-- src/diffusers/models/transformer_temporal.py | 3 ++ src/diffusers/models/unet_3d_blocks.py | 9 ---- src/diffusers/models/unet_3d_condition.py | 51 ++++++++++--------- src/diffusers/models/unet_motion_model.py | 48 +++++------------ .../alt_diffusion/pipeline_alt_diffusion.py | 1 - .../pipeline_alt_diffusion_img2img.py | 1 - .../animatediff/pipeline_animatediff.py | 41 ++++++++++++++- tests/models/test_models_unet_motion.py | 5 +- 9 files changed, 114 insertions(+), 82 deletions(-) diff --git a/docs/source/en/api/pipelines/animatediff.md b/docs/source/en/api/pipelines/animatediff.md index 7793b74ff340..377cab1149d3 100644 --- a/docs/source/en/api/pipelines/animatediff.md +++ b/docs/source/en/api/pipelines/animatediff.md @@ -30,7 +30,7 @@ With the advance of text-to-image models (e.g., Stable Diffusion) and correspond AnimateDiff works with a MotionAdapter checkpoint and a Stable Diffusion model checkpoint. The MotionAdapter is a collection of Motion Modules that are responsible for adding coherent motion across image frames. These modules are applied after the Resnet and Attention blocks in Stable Diffusion UNet. -In the following we give a simple example of how to use a *MotionAdapter* checkpoint with Diffusers for inference based on StableDiffusion-1.4/1.5. +The following example demonstrates how to use a *MotionAdapter* checkpoint with Diffusers for inference based on StableDiffusion-1.4/1.5. ```python import torch @@ -39,8 +39,17 @@ from diffusers.utils import export_to_gif # Load the motion adapter adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2") +# load SD 1.5 based finetuned model pipe = AnimateDiffPipeline.from_pretrained("frankjoshua/toonyou_beta6", motion_adapter=adapter) -pipe.scheduler = DDIMScheduler(beta_schedule="linear", steps_offset=1, clip_sample=False) +pipe.scheduler = DDIMScheduler( + beta_schedule="linear", + steps_offset=1, + clip_sample=False, + beta_start=0.00085, + beta_end=0.012, + timestep_spacing="linspace", +) +# enable memory savings pipe.enable_model_cpu_offload() output = pipe( @@ -55,9 +64,27 @@ frames = output.frames[0] export_to_gif(frames, "animation.gif") ``` -## Available checkpoints + + +AnimateDiff tends to work better with finetuned Stable Diffusion models. If you plan on using a scheduler that can clip samples, make sure to disable it by setting `clip_sample=False` in the scheduler as this can also have an adverse effect on generated samples. + + -Motion Adapter checkpoints can be found under [guoyww/animatediff](https://huggingface.co/guoyww/). +## AnimateDiffPipeline +[[autodoc]] AnimateDiffPipeline + - all + - __call__ + - enable_freeu + - disable_freeu + - enable_vae_slicing + - disable_vae_slicing + - enable_vae_tiling + - disable_vae_tiling -These checkpoints will work with any model based on Stable Diffusion 1.4/1.5 +## AnimateDiffPipelineOutput + +[[autodoc]] pipelines.animatediff.AnimateDiffPipelineOutput + +## Available checkpoints +Motion Adapter checkpoints can be found under [guoyww](https://huggingface.co/guoyww/). These checkpoints are meant to work with any model based on Stable Diffusion 1.4/1.5 diff --git a/src/diffusers/models/transformer_temporal.py b/src/diffusers/models/transformer_temporal.py index 618161f1fa0c..efbd20dce64e 100644 --- a/src/diffusers/models/transformer_temporal.py +++ b/src/diffusers/models/transformer_temporal.py @@ -50,6 +50,8 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin): cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. attention_bias (`bool`, *optional*): Configure if the `TransformerBlock` attention should contain a bias parameter. + sample_size: (`int`, *optional*): + This is fixed during training since it is used to learn a number of position embeddings. activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported activation functions. @@ -74,6 +76,7 @@ def __init__( norm_num_groups: int = 32, cross_attention_dim: Optional[int] = None, attention_bias: bool = False, + sample_size: Optional[int] = None, activation_fn: str = "geglu", norm_elementwise_affine: bool = True, double_self_attention: bool = True, diff --git a/src/diffusers/models/unet_3d_blocks.py b/src/diffusers/models/unet_3d_blocks.py index 699725a12f48..7b556fb0cb8d 100644 --- a/src/diffusers/models/unet_3d_blocks.py +++ b/src/diffusers/models/unet_3d_blocks.py @@ -44,7 +44,6 @@ def get_down_block( upcast_attention=False, resnet_time_scale_shift="default", temporal_num_attention_heads=8, - temporal_cross_attention_dim=None, temporal_max_seq_length=32, ): if down_block_type == "DownBlock3D": @@ -94,7 +93,6 @@ def get_down_block( downsample_padding=downsample_padding, resnet_time_scale_shift=resnet_time_scale_shift, temporal_num_attention_heads=temporal_num_attention_heads, - temporal_cross_attention_dim=temporal_cross_attention_dim, temporal_max_seq_length=temporal_max_seq_length, ) elif down_block_type == "CrossAttnDownBlockMotion": @@ -118,7 +116,6 @@ def get_down_block( upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, temporal_num_attention_heads=temporal_num_attention_heads, - temporal_cross_attention_dim=temporal_cross_attention_dim, temporal_max_seq_length=temporal_max_seq_length, ) @@ -198,7 +195,6 @@ def get_up_block( resnet_time_scale_shift=resnet_time_scale_shift, resolution_idx=resolution_idx, temporal_num_attention_heads=temporal_num_attention_heads, - temporal_cross_attention_dim=temporal_cross_attention_dim, temporal_max_seq_length=temporal_max_seq_length, ) elif up_block_type == "CrossAttnUpBlockMotion": @@ -223,7 +219,6 @@ def get_up_block( resnet_time_scale_shift=resnet_time_scale_shift, resolution_idx=resolution_idx, temporal_num_attention_heads=temporal_num_attention_heads, - temporal_cross_attention_dim=temporal_cross_attention_dim, temporal_max_seq_length=temporal_max_seq_length, ) raise ValueError(f"{up_block_type} does not exist.") @@ -1095,8 +1090,6 @@ def custom_forward(*inputs): )[0] hidden_states = motion_module( hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, num_frames=num_frames, )[0] @@ -1301,8 +1294,6 @@ def custom_forward(*inputs): )[0] hidden_states = motion_module( hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, num_frames=num_frames, )[0] diff --git a/src/diffusers/models/unet_3d_condition.py b/src/diffusers/models/unet_3d_condition.py index 8765bc482ecd..7356fb577584 100644 --- a/src/diffusers/models/unet_3d_condition.py +++ b/src/diffusers/models/unet_3d_condition.py @@ -275,6 +275,31 @@ def __init__( block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding ) + @property + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice def set_attention_slice(self, slice_size): r""" @@ -341,31 +366,6 @@ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[i for module in self.children(): fn_recursive_set_attention_slice(module, reversed_slice_size) - @property - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors - def attn_processors(self) -> Dict[str, AttentionProcessor]: - r""" - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. - """ - # set recursively - processors = {} - - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): - if hasattr(module, "get_processor"): - processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) - - for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - - return processors - - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) - - return processors - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor def set_attn_processor( self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False @@ -502,6 +502,7 @@ def forward( sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, timestep_cond: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, diff --git a/src/diffusers/models/unet_motion_model.py b/src/diffusers/models/unet_motion_model.py index 958d067e612b..52bdae5f803c 100644 --- a/src/diffusers/models/unet_motion_model.py +++ b/src/diffusers/models/unet_motion_model.py @@ -86,9 +86,6 @@ def __init__( motion_layers_per_block=2, motion_mid_block_layers_per_block=1, motion_num_attention_heads=8, - motion_attention_bias=False, - motion_cross_attention_dim=None, - motion_activation_fn="geglu", motion_norm_num_groups=32, motion_max_seq_length=32, use_motion_mid_block=True, @@ -118,9 +115,9 @@ def __init__( MotionModules( in_channels=output_channel, norm_num_groups=motion_norm_num_groups, - cross_attention_dim=motion_cross_attention_dim, - activation_fn=motion_activation_fn, - attention_bias=motion_attention_bias, + cross_attention_dim=None, + activation_fn="geglu", + attention_bias=False, num_attention_heads=motion_num_attention_heads, max_seq_length=motion_max_seq_length, layers_per_block=motion_layers_per_block, @@ -131,9 +128,9 @@ def __init__( self.mid_block = MotionModules( in_channels=block_out_channels[-1], norm_num_groups=motion_norm_num_groups, - cross_attention_dim=motion_cross_attention_dim, - activation_fn=motion_activation_fn, - attention_bias=motion_attention_bias, + cross_attention_dim=None, + activation_fn="geglu", + attention_bias=False, num_attention_heads=motion_num_attention_heads, layers_per_block=motion_mid_block_layers_per_block, max_seq_length=motion_max_seq_length, @@ -149,9 +146,9 @@ def __init__( MotionModules( in_channels=output_channel, norm_num_groups=motion_norm_num_groups, - cross_attention_dim=motion_cross_attention_dim, - activation_fn=motion_activation_fn, - attention_bias=motion_attention_bias, + cross_attention_dim=None, + activation_fn="geglu", + attention_bias=False, num_attention_heads=motion_num_attention_heads, max_seq_length=motion_max_seq_length, layers_per_block=motion_layers_per_block + 1, @@ -201,10 +198,8 @@ def __init__( norm_num_groups: Optional[int] = 32, norm_eps: float = 1e-5, cross_attention_dim: int = 1280, - attention_head_dim: Union[int, Tuple[int]] = 8, use_linear_projection: bool = False, - num_attention_heads: Optional[Union[int, Tuple[int]]] = None, - motion_cross_attention_dim: Optional[int] = None, + num_attention_heads: Optional[Union[int, Tuple[int]]] = 8, motion_max_seq_length: Optional[int] = 32, motion_num_attention_heads: int = 8, use_motion_mid_block: int = True, @@ -213,19 +208,6 @@ def __init__( self.sample_size = sample_size - if num_attention_heads is not None: - raise NotImplementedError( - "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." - ) - - # If `num_attention_heads` is not defined (which is the case for most models) - # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. - # The reason for this behavior is to correct for incorrectly named variables that were introduced - # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 - # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking - # which is why we correct for the naming here. - num_attention_heads = num_attention_heads or attention_head_dim - # Check inputs if len(down_block_types) != len(up_block_types): raise ValueError( @@ -291,7 +273,6 @@ def __init__( use_linear_projection=use_linear_projection, dual_cross_attention=False, temporal_num_attention_heads=motion_num_attention_heads, - temporal_cross_attention_dim=motion_cross_attention_dim, temporal_max_seq_length=motion_max_seq_length, ) self.down_blocks.append(down_block) @@ -308,7 +289,6 @@ def __init__( num_attention_heads=num_attention_heads[-1], resnet_groups=norm_num_groups, dual_cross_attention=False, - temporal_cross_attention_dim=motion_cross_attention_dim, temporal_num_attention_heads=motion_num_attention_heads, temporal_max_seq_length=motion_max_seq_length, ) @@ -365,7 +345,6 @@ def __init__( resolution_idx=i, use_linear_projection=use_linear_projection, temporal_num_attention_heads=motion_num_attention_heads, - temporal_cross_attention_dim=motion_cross_attention_dim, temporal_max_seq_length=motion_max_seq_length, ) self.up_blocks.append(up_block) @@ -418,7 +397,6 @@ def from_unet2d( config["up_block_types"] = up_blocks if has_motion_adapter: - config["motion_cross_attention_dim"] = motion_adapter.config["motion_cross_attention_dim"] config["motion_num_attention_heads"] = motion_adapter.config["motion_num_attention_heads"] config["motion_max_seq_length"] = motion_adapter.config["motion_max_seq_length"] config["use_motion_mid_block"] = motion_adapter.config["use_motion_mid_block"] @@ -520,9 +498,6 @@ def save_motion_modules( motion_layers_per_block=self.config["layers_per_block"], motion_norm_num_groups=self.config["norm_num_groups"], motion_num_attention_heads=self.config["motion_num_attention_heads"], - motion_cross_attention_dim=self.config["motion_cross_attention_dim"], - motion_attention_bias=False, - motion_activation_fn="geglu", motion_max_seq_length=self.config["motion_max_seq_length"], use_motion_mid_block=self.config["use_motion_mid_block"], ) @@ -598,6 +573,7 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) + # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking def enable_forward_chunking(self, chunk_size=None, dim=0): """ Sets the attention processor to use [feed forward @@ -627,6 +603,7 @@ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int for module in self.children(): fn_recursive_feed_forward(module, chunk_size, dim) + # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking def disable_forward_chunking(self): def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): if hasattr(module, "set_chunk_feed_forward"): @@ -638,6 +615,7 @@ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int for module in self.children(): fn_recursive_feed_forward(module, None, 0) + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor def set_default_attn_processor(self): """ Disables custom attention processors and sets the default attention implementation. diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index 3c24db1fdc94..bf267f0ff1af 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -106,7 +106,6 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL feature_extractor ([`~transformers.CLIPImageProcessor`]): A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. """ - model_cpu_offload_seq = "text_encoder->unet->vae" _optional_components = ["safety_checker", "feature_extractor"] _exclude_from_cpu_offload = ["safety_checker"] diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py index d9acf9daf2a6..a28c3da2694b 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -134,7 +134,6 @@ class AltDiffusionImg2ImgPipeline( feature_extractor ([`~transformers.CLIPImageProcessor`]): A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. """ - model_cpu_offload_seq = "text_encoder->unet->vae" _optional_components = ["safety_checker", "feature_extractor"] _exclude_from_cpu_offload = ["safety_checker"] diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index 381384fa14bb..24855d4f09a8 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -91,8 +91,10 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLo Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). tokenizer (`CLIPTokenizer`): A [`~transformers.CLIPTokenizer`] to tokenize text. - unet ([`UNet3DConditionModel`]): - A [`UNet3DConditionModel`] to denoise the encoded video latents. + unet ([`UNet2DConditionModel`]): + A [`UNet2DConditionModel`] used to create a UNetMotionModel to denoise the encoded video latents. + motion_adapter ([`MotionAdapter`]): + A [`MotionAdapter`] to be used in combination with `unet` to denoise the encoded video latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. @@ -329,6 +331,39 @@ def decode_latents(self, latents): return video + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497. @@ -357,6 +392,7 @@ def disable_freeu(self): """Disables the FreeU mechanism if enabled.""" self.unet.disable_freeu() + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. @@ -422,6 +458,7 @@ def check_inputs( f" {negative_prompt_embeds.shape}." ) + # Copied from diffusers.pipelines.text_to_video_synthesis.TextToVideoSDPipeline.prepare_latents def prepare_latents( self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None ): diff --git a/tests/models/test_models_unet_motion.py b/tests/models/test_models_unet_motion.py index ff1533975613..60c3399db537 100644 --- a/tests/models/test_models_unet_motion.py +++ b/tests/models/test_models_unet_motion.py @@ -69,7 +69,7 @@ def prepare_init_args_and_inputs_for_common(self): "down_block_types": ("CrossAttnDownBlockMotion", "DownBlockMotion"), "up_block_types": ("UpBlockMotion", "CrossAttnUpBlockMotion"), "cross_attention_dim": 32, - "attention_head_dim": 8, + "num_attention_heads": 4, "out_channels": 4, "in_channels": 4, "layers_per_block": 1, @@ -213,9 +213,6 @@ def test_feed_forward_chunking(self): def test_pickle(self): # enable deterministic behavior for gradient checkpointing init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - init_dict["attention_head_dim"] = (8, 16) - model = self.model_class(**init_dict) model.to(torch_device) From 840f576a7fafb535f51f080009eda395eb488c95 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Thu, 2 Nov 2023 09:21:00 +0000 Subject: [PATCH 46/55] clean up --- docs/source/en/_toctree.yml | 2 +- src/diffusers/pipelines/animatediff/pipeline_animatediff.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 7f3c57c6e5bd..954b81cb4ec5 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -208,7 +208,7 @@ title: Overview - local: api/pipelines/alt_diffusion title: AltDiffusion - - local: api/pipelienes/animatediff + - local: api/pipelines/animatediff title: AnimateDiff - local: api/pipelines/attend_and_excite title: Attend-and-Excite diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index 24855d4f09a8..d8e9128d34f5 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -458,7 +458,7 @@ def check_inputs( f" {negative_prompt_embeds.shape}." ) - # Copied from diffusers.pipelines.text_to_video_synthesis.TextToVideoSDPipeline.prepare_latents + # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents def prepare_latents( self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None ): From a6d025befc73e62c67a3586eec9b9ad53e02b416 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Thu, 2 Nov 2023 09:53:19 +0000 Subject: [PATCH 47/55] update --- src/diffusers/models/unet_motion_model.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/unet_motion_model.py b/src/diffusers/models/unet_motion_model.py index 52bdae5f803c..120ba6cc97de 100644 --- a/src/diffusers/models/unet_motion_model.py +++ b/src/diffusers/models/unet_motion_model.py @@ -371,7 +371,6 @@ def from_unet2d( unet: UNet2DConditionModel, motion_adapter: Optional[MotionAdapter] = None, load_weights: bool = True, - **kwargs, ): has_motion_adapter = motion_adapter is not None @@ -401,6 +400,10 @@ def from_unet2d( config["motion_max_seq_length"] = motion_adapter.config["motion_max_seq_length"] config["use_motion_mid_block"] = motion_adapter.config["use_motion_mid_block"] + # Need this for backwards compatibility with UNet2DConditionModel checkpoints + if not config.get("num_attention_heads"): + config["num_attention_heads"] = config["attention_head_dim"] + model = cls.from_config(config) if not load_weights: From ee51b907cd72a7991bb1b9041fc652c8c4373ad6 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Thu, 2 Nov 2023 09:57:12 +0000 Subject: [PATCH 48/55] clean up --- src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py | 1 + .../pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py | 1 + 2 files changed, 2 insertions(+) diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index bf267f0ff1af..3c24db1fdc94 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -106,6 +106,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL feature_extractor ([`~transformers.CLIPImageProcessor`]): A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. """ + model_cpu_offload_seq = "text_encoder->unet->vae" _optional_components = ["safety_checker", "feature_extractor"] _exclude_from_cpu_offload = ["safety_checker"] diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py index a28c3da2694b..d9acf9daf2a6 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -134,6 +134,7 @@ class AltDiffusionImg2ImgPipeline( feature_extractor ([`~transformers.CLIPImageProcessor`]): A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. """ + model_cpu_offload_seq = "text_encoder->unet->vae" _optional_components = ["safety_checker", "feature_extractor"] _exclude_from_cpu_offload = ["safety_checker"] From dfa52fbc81255cc39123d14bfb14898a05db9505 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Thu, 2 Nov 2023 10:21:44 +0000 Subject: [PATCH 49/55] clean up --- src/diffusers/models/unet_motion_model.py | 33 +++++++++++++++-------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/src/diffusers/models/unet_motion_model.py b/src/diffusers/models/unet_motion_model.py index 120ba6cc97de..c0381f791bba 100644 --- a/src/diffusers/models/unet_motion_model.py +++ b/src/diffusers/models/unet_motion_model.py @@ -463,9 +463,10 @@ def freeze_unet2d_params(self): for param in motion_modules.parameters(): param.requires_grad = True - motion_modules = self.mid_block.motion_modules - for param in motion_modules.parameters(): - param.requires_grad = True + if hasattr(self.mid_block, "motion_modules"): + motion_modules = self.mid_block.motion_modules + for param in motion_modules.parameters(): + param.requires_grad = True return @@ -803,14 +804,24 @@ def forward( # 4. mid if self.mid_block is not None: - sample = self.mid_block( - sample, - emb, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - num_frames=num_frames, - cross_attention_kwargs=cross_attention_kwargs, - ) + # To support older versions of motion modules that don't have a mid_block + if hasattr(self.mid_block, "motion_modules"): + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) if mid_block_additional_residual is not None: sample = sample + mid_block_additional_residual From c24c97b70bd0371f34cb4ba58883742551c07295 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Thu, 2 Nov 2023 10:30:17 +0000 Subject: [PATCH 50/55] fix docstrings --- src/diffusers/models/unet_motion_model.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/src/diffusers/models/unet_motion_model.py b/src/diffusers/models/unet_motion_model.py index c0381f791bba..b3f5e251d3a9 100644 --- a/src/diffusers/models/unet_motion_model.py +++ b/src/diffusers/models/unet_motion_model.py @@ -95,14 +95,18 @@ def __init__( Args: block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): The tuple of output channels for each UNet block. - layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. - num_attention_heads (`int`, *optional*): - The number of heads to use in each attention layer. - attention_bias (bool, optional, defaults to False): Whether to include bias in attention layers. - cross_attention_dim (int, optional, Defaults to None): Set in order to use cross attention. - activation_fn (str, optional, Defaults to "geglu"): Activation Function. - norm_num_groups (int, optional): _description_. Defaults to 32. - max_seq_length (int, optional): _description_. Defaults to 24. + motion_layers_per_block (`int`, *optional*, defaults to 2): + The number of motion layers per UNet block. + motion_mid_block_layers_per_block (`int`, *optional*, defaults to 1): + The number of motion layers in the middle UNet block. + motion_num_attention_heads (`int`, *optional*, defaults to 8): + The number of heads to use in each attention layer of the motion module. + motion_norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups to use in each group normalization layer of the motion module. + motion_max_seq_length (`int`, *optional*, defaults to 32): + The maximum sequence length to use in the motion module. + use_motion_mid_block (`bool`, *optional*, defaults to True): + Whether to use a motion module in the middle of the UNet. """ super().__init__() @@ -686,7 +690,7 @@ def forward( return_dict: bool = True, ) -> Union[UNet3DConditionOutput, Tuple]: r""" - The [`UNet3DConditionModel`] forward method. + The [`UNetMotionModel`] forward method. Args: sample (`torch.FloatTensor`): From ef893c4b384469de3b252536dfed0865f582890c Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Thu, 2 Nov 2023 10:35:04 +0000 Subject: [PATCH 51/55] fix docstrings --- src/diffusers/models/unet_motion_model.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/diffusers/models/unet_motion_model.py b/src/diffusers/models/unet_motion_model.py index b3f5e251d3a9..5d528a34ec96 100644 --- a/src/diffusers/models/unet_motion_model.py +++ b/src/diffusers/models/unet_motion_model.py @@ -698,8 +698,6 @@ def forward( timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. encoder_hidden_states (`torch.FloatTensor`): The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. - class_labels (`torch.Tensor`, *optional*, defaults to `None`): - Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`): Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed through the `self.time_embedding` layer to obtain the timestep embeddings. @@ -718,8 +716,6 @@ def forward( return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.unet_3d_condition.UNet3DConditionOutput`] instead of a plain tuple. - cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. Returns: [`~models.unet_3d_condition.UNet3DConditionOutput`] or `tuple`: From a2e38cce48abc9515f614293ef2e5d15c934242b Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Thu, 2 Nov 2023 12:33:38 +0000 Subject: [PATCH 52/55] update --- docs/source/en/api/pipelines/animatediff.md | 40 +++++++++++----- src/diffusers/models/transformer_temporal.py | 1 + .../animatediff/pipeline_animatediff.py | 46 +++++++++++-------- 3 files changed, 56 insertions(+), 31 deletions(-) diff --git a/docs/source/en/api/pipelines/animatediff.md b/docs/source/en/api/pipelines/animatediff.md index 377cab1149d3..ff621c60221d 100644 --- a/docs/source/en/api/pipelines/animatediff.md +++ b/docs/source/en/api/pipelines/animatediff.md @@ -40,30 +40,48 @@ from diffusers.utils import export_to_gif # Load the motion adapter adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2") # load SD 1.5 based finetuned model -pipe = AnimateDiffPipeline.from_pretrained("frankjoshua/toonyou_beta6", motion_adapter=adapter) -pipe.scheduler = DDIMScheduler( - beta_schedule="linear", - steps_offset=1, - clip_sample=False, - beta_start=0.00085, - beta_end=0.012, - timestep_spacing="linspace", +model_id = "SG161222/Realistic_Vision_V5.1_noVAE" +pipe = AnimateDiffPipeline.from_pretrained(model_id, motion_adapter=adapter) +scheduler = DDIMScheduler.from_pretrained( + model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1 ) +pipe.scheduler = scheduler + # enable memory savings +pipe.enable_vae_slicing() pipe.enable_model_cpu_offload() output = pipe( - prompt="masterpiece, best quality, 1boy, jacket, beard, walking, beanie, sunglasses, from below, looking up, fisheye, upper body, wasteland, sunset, solo focus, cloudy sky, backpack, hands in pockets", - negative_prompt="human, worst quality, low quality, letterboxed", + prompt=( + "masterpiece, bestquality, highlydetailed, ultradetailed, sunset, " + "orange sky, warm lighting, fishing boats, ocean waves seagulls, " + "rippling water, wharf, silhouette, serene atmosphere, dusk, evening glow, " + "golden hour, coastal landscape, seaside scenery" + ), + negative_prompt="bad quality, worse quality", num_frames=16, guidance_scale=7.5, num_inference_steps=25, - generator = torch.Generator("cpu").manual_seed(42) + generator=torch.Generator("cpu").manual_seed(42), ) frames = output.frames[0] export_to_gif(frames, "animation.gif") ``` +Here are some sample outputs: + + + + + +
+ masterpiece, bestquality, sunset. +
+ masterpiece, bestquality, sunset +
+ AnimateDiff tends to work better with finetuned Stable Diffusion models. If you plan on using a scheduler that can clip samples, make sure to disable it by setting `clip_sample=False` in the scheduler as this can also have an adverse effect on generated samples. diff --git a/src/diffusers/models/transformer_temporal.py b/src/diffusers/models/transformer_temporal.py index efbd20dce64e..eb0e1cde4af9 100644 --- a/src/diffusers/models/transformer_temporal.py +++ b/src/diffusers/models/transformer_temporal.py @@ -71,6 +71,7 @@ def __init__( num_attention_heads: int = 16, attention_head_dim: int = 88, in_channels: Optional[int] = None, + out_channels: Optional[int] = None, num_layers: int = 1, dropout: float = 0.0, norm_num_groups: int = 32, diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index d8e9128d34f5..650b447cd23a 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -131,11 +131,12 @@ def __init__( self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt def encode_prompt( self, prompt, device, - num_videos_per_prompt, + num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, prompt_embeds: Optional[torch.FloatTensor] = None, @@ -151,7 +152,7 @@ def encode_prompt( prompt to be encoded device: (`torch.device`): torch device - num_videos_per_prompt (`int`): + num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not @@ -249,8 +250,8 @@ def encode_prompt( bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1) + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: @@ -303,32 +304,37 @@ def encode_prompt( negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder) + unscale_lora_layers(self.text_encoder, lora_scale) return prompt_embeds, negative_prompt_embeds + # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents def decode_latents(self, latents): - batch_size, channels, num_frames, height, width = latents.shape + latents = 1 / self.vae.config.scaling_factor * latents + batch_size, channels, num_frames, height, width = latents.shape latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) - latents = latents / self.vae.config.scaling_factor - - output_frames = [] - - # decode frame by frame to avoid OOM - for frame_idx in range(latents.shape[0]): - frame = self.vae.decode(latents[frame_idx].unsqueeze(0), return_dict=False)[0] - output_frames.append(frame) - - output = torch.cat(output_frames) - _, channels, height, width = output.shape - video = output[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4) + image = self.vae.decode(latents).sample + video = ( + image[None, :] + .reshape( + ( + batch_size, + num_frames, + -1, + ) + + image.shape[2:] + ) + .permute(0, 2, 1, 3, 4) + ) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + video = video.float() return video # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing From 0d6f5be46d14aab64027909f70725b89ed844dee Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Thu, 2 Nov 2023 12:37:54 +0000 Subject: [PATCH 53/55] update --- src/diffusers/models/transformer_temporal.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/diffusers/models/transformer_temporal.py b/src/diffusers/models/transformer_temporal.py index eb0e1cde4af9..0b328858fac8 100644 --- a/src/diffusers/models/transformer_temporal.py +++ b/src/diffusers/models/transformer_temporal.py @@ -89,7 +89,6 @@ def __init__( self.attention_head_dim = attention_head_dim inner_dim = num_attention_heads * attention_head_dim - self.use_cross_attention = cross_attention_dim is not None self.in_channels = in_channels self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) @@ -169,7 +168,6 @@ def forward( hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel) hidden_states = self.proj_in(hidden_states) - encoder_hidden_states = encoder_hidden_states if self.use_cross_attention else None # 2. Blocks for block in self.transformer_blocks: From beb1646b1c0ba72ef8b733ef4b629459c6317482 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Thu, 2 Nov 2023 13:06:31 +0000 Subject: [PATCH 54/55] clean up --- src/diffusers/models/transformer_temporal.py | 2 +- src/diffusers/models/unet_3d_blocks.py | 2 -- .../pipelines/animatediff/test_animatediff.py | 29 ++++++++++++++----- 3 files changed, 22 insertions(+), 11 deletions(-) diff --git a/src/diffusers/models/transformer_temporal.py b/src/diffusers/models/transformer_temporal.py index 0b328858fac8..c84c766cc7b0 100644 --- a/src/diffusers/models/transformer_temporal.py +++ b/src/diffusers/models/transformer_temporal.py @@ -50,7 +50,7 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin): cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. attention_bias (`bool`, *optional*): Configure if the `TransformerBlock` attention should contain a bias parameter. - sample_size: (`int`, *optional*): + sample_size: (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). This is fixed during training since it is used to learn a number of position embeddings. activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported diff --git a/src/diffusers/models/unet_3d_blocks.py b/src/diffusers/models/unet_3d_blocks.py index 7b556fb0cb8d..e8e42cf5615f 100644 --- a/src/diffusers/models/unet_3d_blocks.py +++ b/src/diffusers/models/unet_3d_blocks.py @@ -1604,8 +1604,6 @@ def custom_forward(*inputs): )[0] hidden_states = motion_module( hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, num_frames=num_frames, )[0] hidden_states = resnet(hidden_states, temb, scale=lora_scale) diff --git a/tests/pipelines/animatediff/test_animatediff.py b/tests/pipelines/animatediff/test_animatediff.py index 50af170a47ce..baba8ba4d655 100644 --- a/tests/pipelines/animatediff/test_animatediff.py +++ b/tests/pipelines/animatediff/test_animatediff.py @@ -231,7 +231,7 @@ def tearDown(self): torch.cuda.empty_cache() def test_animatediff(self): - adapter = MotionAdapter.from_pretrained("dn6/animatediff-test") + adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2") pipe = AnimateDiffPipeline.from_pretrained("frankjoshua/toonyou_beta6", motion_adapter=adapter) pipe = pipe.to(torch_device) pipe.scheduler = DDIMScheduler( @@ -241,26 +241,39 @@ def test_animatediff(self): steps_offset=1, clip_sample=False, ) + pipe.enable_vae_slicing() + pipe.enable_model_cpu_offload() pipe.set_progress_bar_config(disable=None) prompt = "night, b&w photo of old house, post apocalypse, forest, storm weather, wind, rocks, 8k uhd, dslr, soft lighting, high quality, film grain" negative_prompt = "bad quality, worse quality" - generator = torch.manual_seed(0) + generator = torch.Generator("cpu").manual_seed(0) output = pipe( prompt, negative_prompt=negative_prompt, num_frames=16, generator=generator, guidance_scale=7.5, - num_inference_steps=20, + num_inference_steps=3, output_type="np", ) - image = output.images + image = output.frames[0] + assert image.shape == (16, 512, 512, 3) image_slice = image[0, -3:, -3:, -1] - assert image.shape == (1, 512, 512, 3) - expected_slice = np.array([0.1010, 0.0800, 0.0794, 0.0885, 0.0843, 0.0762, 0.0769, 0.0729, 0.0586]) - - assert numpy_cosine_similarity_distance(image_slice.flatten() - expected_slice.flatten()) < 1e-2 + expected_slice = np.array( + [ + 0.11357737, + 0.11285847, + 0.11180121, + 0.11084166, + 0.11414117, + 0.09785956, + 0.10742754, + 0.10510018, + 0.08045256, + ] + ) + assert numpy_cosine_similarity_distance(image_slice.flatten(), expected_slice.flatten()) < 1e-3 From 88e76c6aeeecf326d62b4ad5630d7da19d144f5e Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Thu, 2 Nov 2023 13:21:17 +0000 Subject: [PATCH 55/55] update --- src/diffusers/models/transformer_temporal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformer_temporal.py b/src/diffusers/models/transformer_temporal.py index c84c766cc7b0..2e053d70eaa7 100644 --- a/src/diffusers/models/transformer_temporal.py +++ b/src/diffusers/models/transformer_temporal.py @@ -50,7 +50,7 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin): cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. attention_bias (`bool`, *optional*): Configure if the `TransformerBlock` attention should contain a bias parameter. - sample_size: (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). + sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). This is fixed during training since it is used to learn a number of position embeddings. activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported