class InflatedGroupNorm(nn.GroupNorm):
    """GroupNorm that normalises every frame of a video tensor separately.

    For a 5-D input laid out as ``(b, c, f, h, w)`` the frame axis is folded
    into the batch axis, the regular ``nn.GroupNorm`` forward pass is run on
    the resulting ``(b * f, c, h, w)`` tensor, and the original layout is
    restored afterwards. Statistics are therefore computed per frame rather
    than jointly over all frames of a clip.

    Constructor arguments are exactly those of ``nn.GroupNorm``.
    """

    def forward(self, x):
        """Apply group normalisation frame by frame.

        Args:
            x: Input tensor. A 5-D tensor is treated as ``(b, c, f, h, w)``.
                Any other rank is forwarded unchanged to ``nn.GroupNorm``,
                because there is no frame axis to fold.

        Returns:
            A tensor with the same shape as ``x``.
        """
        # Without a frame axis, per-frame normalisation is just GroupNorm.
        # Previously such inputs failed with a pattern-mismatch error.
        if x.dim() != 5:
            return super().forward(x)

        batch, channels, video_length, height, width = x.shape

        # "b c f h w -> (b f) c h w": move frames next to batch, then merge.
        x = x.permute(0, 2, 1, 3, 4).reshape(
            batch * video_length, channels, height, width
        )
        x = super().forward(x)
        # "(b f) c h w -> b c f h w": split the merged axis, restore layout.
        x = x.reshape(batch, video_length, channels, height, width).permute(
            0, 2, 1, 3, 4
        )

        return x
9d67e8ae..18aa9551 100644 --- a/animatediff/models/unet.py +++ b/animatediff/models/unet.py @@ -24,7 +24,7 @@ get_down_block, get_up_block, ) -from .resnet import InflatedConv3d +from .resnet import InflatedConv3d, InflatedGroupNorm logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -77,6 +77,8 @@ def __init__( upcast_attention: bool = False, resnet_time_scale_shift: str = "default", + use_inflated_groupnorm=False, + # Additional use_motion_module = False, motion_module_resolutions = ( 1,2,4,8 ), @@ -88,7 +90,7 @@ def __init__( unet_use_temporal_attention = None, ): super().__init__() - + self.sample_size = sample_size time_embed_dim = block_out_channels[0] * 4 @@ -150,6 +152,7 @@ def __init__( unet_use_cross_frame_attention=unet_use_cross_frame_attention, unet_use_temporal_attention=unet_use_temporal_attention, + use_inflated_groupnorm=use_inflated_groupnorm, use_motion_module=use_motion_module and (res in motion_module_resolutions) and (not motion_module_decoder_only), motion_module_type=motion_module_type, @@ -175,6 +178,7 @@ def __init__( unet_use_cross_frame_attention=unet_use_cross_frame_attention, unet_use_temporal_attention=unet_use_temporal_attention, + use_inflated_groupnorm=use_inflated_groupnorm, use_motion_module=use_motion_module and motion_module_mid_block, motion_module_type=motion_module_type, @@ -227,6 +231,7 @@ def __init__( unet_use_cross_frame_attention=unet_use_cross_frame_attention, unet_use_temporal_attention=unet_use_temporal_attention, + use_inflated_groupnorm=use_inflated_groupnorm, use_motion_module=use_motion_module and (res in motion_module_resolutions), motion_module_type=motion_module_type, @@ -236,7 +241,10 @@ def __init__( prev_output_channel = output_channel # out - self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) + if use_inflated_groupnorm: + self.conv_norm_out = InflatedGroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, 
eps=norm_eps) + else: + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) self.conv_act = nn.SiLU() self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1) diff --git a/animatediff/models/unet_blocks.py b/animatediff/models/unet_blocks.py index 8a17f201..711ad6cc 100644 --- a/animatediff/models/unet_blocks.py +++ b/animatediff/models/unet_blocks.py @@ -30,7 +30,8 @@ def get_down_block( unet_use_cross_frame_attention=None, unet_use_temporal_attention=None, - + use_inflated_groupnorm=None, + use_motion_module=None, motion_module_type=None, @@ -50,6 +51,8 @@ def get_down_block( downsample_padding=downsample_padding, resnet_time_scale_shift=resnet_time_scale_shift, + use_inflated_groupnorm=use_inflated_groupnorm, + use_motion_module=use_motion_module, motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, @@ -77,6 +80,7 @@ def get_down_block( unet_use_cross_frame_attention=unet_use_cross_frame_attention, unet_use_temporal_attention=unet_use_temporal_attention, + use_inflated_groupnorm=use_inflated_groupnorm, use_motion_module=use_motion_module, motion_module_type=motion_module_type, @@ -106,6 +110,7 @@ def get_up_block( unet_use_cross_frame_attention=None, unet_use_temporal_attention=None, + use_inflated_groupnorm=None, use_motion_module=None, motion_module_type=None, @@ -125,6 +130,8 @@ def get_up_block( resnet_groups=resnet_groups, resnet_time_scale_shift=resnet_time_scale_shift, + use_inflated_groupnorm=use_inflated_groupnorm, + use_motion_module=use_motion_module, motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, @@ -152,6 +159,7 @@ def get_up_block( unet_use_cross_frame_attention=unet_use_cross_frame_attention, unet_use_temporal_attention=unet_use_temporal_attention, + use_inflated_groupnorm=use_inflated_groupnorm, use_motion_module=use_motion_module, motion_module_type=motion_module_type, @@ -181,6 +189,7 
@@ def __init__( unet_use_cross_frame_attention=None, unet_use_temporal_attention=None, + use_inflated_groupnorm=None, use_motion_module=None, @@ -206,6 +215,8 @@ def __init__( non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, + + use_inflated_groupnorm=use_inflated_groupnorm, ) ] attentions = [] @@ -248,6 +259,8 @@ def __init__( non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, + + use_inflated_groupnorm=use_inflated_groupnorm, ) ) @@ -290,6 +303,7 @@ def __init__( unet_use_cross_frame_attention=None, unet_use_temporal_attention=None, + use_inflated_groupnorm=None, use_motion_module=None, @@ -318,6 +332,8 @@ def __init__( non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, + + use_inflated_groupnorm=use_inflated_groupnorm, ) ) if dual_cross_attention: @@ -421,6 +437,8 @@ def __init__( output_scale_factor=1.0, add_downsample=True, downsample_padding=1, + + use_inflated_groupnorm=None, use_motion_module=None, motion_module_type=None, @@ -444,6 +462,8 @@ def __init__( non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, + + use_inflated_groupnorm=use_inflated_groupnorm, ) ) motion_modules.append( @@ -526,6 +546,7 @@ def __init__( unet_use_cross_frame_attention=None, unet_use_temporal_attention=None, + use_inflated_groupnorm=None, use_motion_module=None, @@ -556,6 +577,8 @@ def __init__( non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, + + use_inflated_groupnorm=use_inflated_groupnorm, ) ) if dual_cross_attention: @@ -661,6 +684,8 @@ def __init__( output_scale_factor=1.0, add_upsample=True, + use_inflated_groupnorm=None, + use_motion_module=None, motion_module_type=None, motion_module_kwargs=None, @@ -685,6 +710,8 @@ def __init__( non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, + + 
use_inflated_groupnorm=use_inflated_groupnorm, ) ) motion_modules.append(