Skip to content

Commit

Permalink
support v2
Browse files Browse the repository at this point in the history
  • Loading branch information
guoyww committed Sep 10, 2023
1 parent 1b50d64 commit 1089219
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 6 deletions.
24 changes: 22 additions & 2 deletions animatediff/models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,17 @@ def forward(self, x):
return x


class InflatedGroupNorm(nn.GroupNorm):
    """GroupNorm applied frame-wise to a 5-D video tensor.

    Input is laid out as (batch, channels, frames, height, width).  The
    frame axis is folded into the batch axis so the standard 2-D
    ``nn.GroupNorm`` normalizes each frame independently, then the
    original layout is restored.
    """

    def forward(self, x):
        batch, channels, frames, height, width = x.shape

        # (b, c, f, h, w) -> (b*f, c, h, w): present frames as ordinary
        # image samples to the parent GroupNorm.
        folded = x.permute(0, 2, 1, 3, 4).reshape(
            batch * frames, channels, height, width
        )
        normed = super().forward(folded)

        # (b*f, c, h, w) -> (b, c, f, h, w): undo the fold.
        return normed.reshape(
            batch, frames, channels, height, width
        ).permute(0, 2, 1, 3, 4)


class Upsample3D(nn.Module):
def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
super().__init__()
Expand Down Expand Up @@ -112,6 +123,7 @@ def __init__(
time_embedding_norm="default",
output_scale_factor=1.0,
use_in_shortcut=None,
use_inflated_groupnorm=None,
):
super().__init__()
self.pre_norm = pre_norm
Expand All @@ -126,7 +138,11 @@ def __init__(
if groups_out is None:
groups_out = groups

self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
assert use_inflated_groupnorm != None
if use_inflated_groupnorm:
self.norm1 = InflatedGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
else:
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

Expand All @@ -142,7 +158,11 @@ def __init__(
else:
self.time_emb_proj = None

self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
if use_inflated_groupnorm:
self.norm2 = InflatedGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
else:
self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)

self.dropout = torch.nn.Dropout(dropout)
self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

Expand Down
14 changes: 11 additions & 3 deletions animatediff/models/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
get_down_block,
get_up_block,
)
from .resnet import InflatedConv3d
from .resnet import InflatedConv3d, InflatedGroupNorm


logger = logging.get_logger(__name__) # pylint: disable=invalid-name
Expand Down Expand Up @@ -77,6 +77,8 @@ def __init__(
upcast_attention: bool = False,
resnet_time_scale_shift: str = "default",

use_inflated_groupnorm=False,

# Additional
use_motion_module = False,
motion_module_resolutions = ( 1,2,4,8 ),
Expand All @@ -88,7 +90,7 @@ def __init__(
unet_use_temporal_attention = None,
):
super().__init__()

self.sample_size = sample_size
time_embed_dim = block_out_channels[0] * 4

Expand Down Expand Up @@ -150,6 +152,7 @@ def __init__(

unet_use_cross_frame_attention=unet_use_cross_frame_attention,
unet_use_temporal_attention=unet_use_temporal_attention,
use_inflated_groupnorm=use_inflated_groupnorm,

use_motion_module=use_motion_module and (res in motion_module_resolutions) and (not motion_module_decoder_only),
motion_module_type=motion_module_type,
Expand All @@ -175,6 +178,7 @@ def __init__(

unet_use_cross_frame_attention=unet_use_cross_frame_attention,
unet_use_temporal_attention=unet_use_temporal_attention,
use_inflated_groupnorm=use_inflated_groupnorm,

use_motion_module=use_motion_module and motion_module_mid_block,
motion_module_type=motion_module_type,
Expand Down Expand Up @@ -227,6 +231,7 @@ def __init__(

unet_use_cross_frame_attention=unet_use_cross_frame_attention,
unet_use_temporal_attention=unet_use_temporal_attention,
use_inflated_groupnorm=use_inflated_groupnorm,

use_motion_module=use_motion_module and (res in motion_module_resolutions),
motion_module_type=motion_module_type,
Expand All @@ -236,7 +241,10 @@ def __init__(
prev_output_channel = output_channel

# out
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
if use_inflated_groupnorm:
self.conv_norm_out = InflatedGroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
else:
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
self.conv_act = nn.SiLU()
self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)

Expand Down
29 changes: 28 additions & 1 deletion animatediff/models/unet_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ def get_down_block(

unet_use_cross_frame_attention=None,
unet_use_temporal_attention=None,

use_inflated_groupnorm=None,

use_motion_module=None,

motion_module_type=None,
Expand All @@ -50,6 +51,8 @@ def get_down_block(
downsample_padding=downsample_padding,
resnet_time_scale_shift=resnet_time_scale_shift,

use_inflated_groupnorm=use_inflated_groupnorm,

use_motion_module=use_motion_module,
motion_module_type=motion_module_type,
motion_module_kwargs=motion_module_kwargs,
Expand Down Expand Up @@ -77,6 +80,7 @@ def get_down_block(

unet_use_cross_frame_attention=unet_use_cross_frame_attention,
unet_use_temporal_attention=unet_use_temporal_attention,
use_inflated_groupnorm=use_inflated_groupnorm,

use_motion_module=use_motion_module,
motion_module_type=motion_module_type,
Expand Down Expand Up @@ -106,6 +110,7 @@ def get_up_block(

unet_use_cross_frame_attention=None,
unet_use_temporal_attention=None,
use_inflated_groupnorm=None,

use_motion_module=None,
motion_module_type=None,
Expand All @@ -125,6 +130,8 @@ def get_up_block(
resnet_groups=resnet_groups,
resnet_time_scale_shift=resnet_time_scale_shift,

use_inflated_groupnorm=use_inflated_groupnorm,

use_motion_module=use_motion_module,
motion_module_type=motion_module_type,
motion_module_kwargs=motion_module_kwargs,
Expand Down Expand Up @@ -152,6 +159,7 @@ def get_up_block(

unet_use_cross_frame_attention=unet_use_cross_frame_attention,
unet_use_temporal_attention=unet_use_temporal_attention,
use_inflated_groupnorm=use_inflated_groupnorm,

use_motion_module=use_motion_module,
motion_module_type=motion_module_type,
Expand Down Expand Up @@ -181,6 +189,7 @@ def __init__(

unet_use_cross_frame_attention=None,
unet_use_temporal_attention=None,
use_inflated_groupnorm=None,

use_motion_module=None,

Expand All @@ -206,6 +215,8 @@ def __init__(
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,

use_inflated_groupnorm=use_inflated_groupnorm,
)
]
attentions = []
Expand Down Expand Up @@ -248,6 +259,8 @@ def __init__(
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,

use_inflated_groupnorm=use_inflated_groupnorm,
)
)

Expand Down Expand Up @@ -290,6 +303,7 @@ def __init__(

unet_use_cross_frame_attention=None,
unet_use_temporal_attention=None,
use_inflated_groupnorm=None,

use_motion_module=None,

Expand Down Expand Up @@ -318,6 +332,8 @@ def __init__(
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,

use_inflated_groupnorm=use_inflated_groupnorm,
)
)
if dual_cross_attention:
Expand Down Expand Up @@ -421,6 +437,8 @@ def __init__(
output_scale_factor=1.0,
add_downsample=True,
downsample_padding=1,

use_inflated_groupnorm=None,

use_motion_module=None,
motion_module_type=None,
Expand All @@ -444,6 +462,8 @@ def __init__(
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,

use_inflated_groupnorm=use_inflated_groupnorm,
)
)
motion_modules.append(
Expand Down Expand Up @@ -526,6 +546,7 @@ def __init__(

unet_use_cross_frame_attention=None,
unet_use_temporal_attention=None,
use_inflated_groupnorm=None,

use_motion_module=None,

Expand Down Expand Up @@ -556,6 +577,8 @@ def __init__(
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,

use_inflated_groupnorm=use_inflated_groupnorm,
)
)
if dual_cross_attention:
Expand Down Expand Up @@ -661,6 +684,8 @@ def __init__(
output_scale_factor=1.0,
add_upsample=True,

use_inflated_groupnorm=None,

use_motion_module=None,
motion_module_type=None,
motion_module_kwargs=None,
Expand All @@ -685,6 +710,8 @@ def __init__(
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,

use_inflated_groupnorm=use_inflated_groupnorm,
)
)
motion_modules.append(
Expand Down

0 comments on commit 1089219

Please sign in to comment.