Added support to create asymmetrical U-Net structures #5400

Merged Oct 20, 2023 · 36 commits (changes shown from 31 commits)

Commits
b8dca9f
Added args, kwargs to U…
Gothos Oct 15, 2023
577d1e2
Add UNetMidBlock2D as a supported mid block type
Gothos Oct 15, 2023
2c8d804
Fix extra init input for UNetMidBlock2D, change allowed types for Mid…
Gothos Oct 15, 2023
e07c615
Update unet_2d_condition.py
Gothos Oct 15, 2023
8344b2d
Update unet_2d_condition.py
Gothos Oct 15, 2023
025b1ec
Update unet_2d_condition.py
Gothos Oct 15, 2023
6230214
Update unet_2d_condition.py
Gothos Oct 15, 2023
80b891b
Update unet_2d_condition.py
Gothos Oct 15, 2023
c4e4d40
Update unet_2d_condition.py
Gothos Oct 15, 2023
6176fd5
Update unet_2d_condition.py
Gothos Oct 15, 2023
73737f8
Update unet_2d_condition.py
Gothos Oct 15, 2023
bc53a32
Update unet_2d_blocks.py
Gothos Oct 15, 2023
314730b
Update unet_2d_blocks.py
Gothos Oct 15, 2023
545a0b0
Update unet_2d_blocks.py
Gothos Oct 15, 2023
bf77268
Update unet_2d_condition.py
Gothos Oct 15, 2023
1998c17
Update unet_2d_blocks.py
Gothos Oct 15, 2023
81d682e
Updated docstring, increased check strictness
Gothos Oct 16, 2023
93ce7d3
Add basic shape-check test for asymmetrical unets
Gothos Oct 17, 2023
e6b937c
Update src/diffusers/models/unet_2d_blocks.py
Gothos Oct 17, 2023
3cf5e20
Merge branch 'main' into main
sayakpaul Oct 17, 2023
0e9f6f4
Update unet_2d_condition.py
Gothos Oct 17, 2023
a91d085
Update unet_2d_condition.py
Gothos Oct 17, 2023
d75f203
Fixed docstring and wrong default value
Gothos Oct 17, 2023
4cf629b
Merge branch 'main' into main
Gothos Oct 17, 2023
97c47b6
Reformat with black
Gothos Oct 17, 2023
7695566
Merge branch 'main' into main
Gothos Oct 17, 2023
b3bf008
Merge branch 'main' into main
Gothos Oct 18, 2023
196ab3e
Reformat with necessary commands
Gothos Oct 18, 2023
a65ec27
Merge branch 'huggingface:main' into main
Gothos Oct 18, 2023
bbebc23
Add UNetMidBlockFlat to versatile_diffusion/modeling_text_unet.py to …
Gothos Oct 18, 2023
b6c0ff6
Merge branch 'main' into main
sayakpaul Oct 19, 2023
0a73108
Removed args, kwargs, use on mid-block type
Gothos Oct 19, 2023
950d969
Merge remote-tracking branch 'refs/remotes/origin/main'
Gothos Oct 19, 2023
bfbf85c
Make fix-copies
Gothos Oct 19, 2023
8389a85
Update src/diffusers/models/unet_2d_condition.py
Gothos Oct 20, 2023
252ca85
make fix-copies
Gothos Oct 20, 2023
27 changes: 18 additions & 9 deletions src/diffusers/models/unet_2d_blocks.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, Optional, Tuple, Union

import numpy as np
import torch
@@ -617,7 +617,7 @@ def __init__(
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)

def forward(self, hidden_states, temb=None):
def forward(self, hidden_states, temb=None, *args, **kwargs):
Contributor:

Suggested change:
-    def forward(self, hidden_states, temb=None, *args, **kwargs):
+    def forward(self, hidden_states, temb=None):

Is this really needed?

Contributor:

Can we remove this? I don't see exactly how this is required here.

Contributor Author (@Gothos, Oct 19, 2023):
```python
# 4. mid
if self.mid_block is not None:
    sample = self.mid_block(
        sample,
        emb,
        encoder_hidden_states=encoder_hidden_states,
        attention_mask=attention_mask,
        cross_attention_kwargs=cross_attention_kwargs,
        encoder_attention_mask=encoder_attention_mask,
    )
```

These lines make the `*args, **kwargs` necessary. I'll go by the style used elsewhere and use the `has_cross_attention` attribute instead.

Contributor Author (@Gothos, Oct 19, 2023):
            if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
                sample = self.mid_block(
                    sample,
                    emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                    encoder_attention_mask=encoder_attention_mask,
                )
            else:
                sample = self.mid_block(
                    sample,
                    emb,
                )

Would following your suggestion above and then doing this work, @patrickvonplaten?

Member:
This sounds good to me!

Contributor Author:
All right, I'll commit the changes. Just wanted to check once.

hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if attn is not None:
@@ -634,7 +634,7 @@ def __init__(
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
transformer_layers_per_block: int = 1,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
@@ -654,6 +654,10 @@
self.num_attention_heads = num_attention_heads
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)

# support for variable transformer layers per block
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * num_layers

# there is always at least one resnet
resnets = [
ResnetBlock2D(
@@ -671,14 +675,14 @@
]
attentions = []

for _ in range(num_layers):
for i in range(num_layers):
if not dual_cross_attention:
attentions.append(
Transformer2DModel(
num_attention_heads,
in_channels // num_attention_heads,
in_channels=in_channels,
num_layers=transformer_layers_per_block,
num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
@@ -1018,7 +1022,7 @@ def __init__(
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
transformer_layers_per_block: int = 1,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
@@ -1041,6 +1045,8 @@

self.has_cross_attention = True
self.num_attention_heads = num_attention_heads
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * num_layers

for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
@@ -1064,7 +1070,7 @@
num_attention_heads,
out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=transformer_layers_per_block,
num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
@@ -2167,7 +2173,7 @@ def __init__(
resolution_idx: int = None,
dropout: float = 0.0,
num_layers: int = 1,
transformer_layers_per_block: int = 1,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
@@ -2190,6 +2196,9 @@
self.has_cross_attention = True
self.num_attention_heads = num_attention_heads

if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * num_layers

for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
@@ -2214,7 +2223,7 @@
num_attention_heads,
out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=transformer_layers_per_block,
num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
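The pattern repeated across these blocks is that `transformer_layers_per_block` may now be either an int or a per-layer tuple, and the int form is expanded before the layer-building loop so that `transformer_layers_per_block[i]` always works. A standalone sketch of that normalization (the helper name and the explicit length check are mine, not part of the diff):

```python
from typing import List, Tuple, Union

def expand_transformer_layers(
    transformer_layers_per_block: Union[int, Tuple[int, ...]],
    num_layers: int,
) -> List[int]:
    # A bare int means "the same transformer depth for every layer in this
    # block"; expand it so the per-layer loop can always index position i.
    if isinstance(transformer_layers_per_block, int):
        return [transformer_layers_per_block] * num_layers
    if len(transformer_layers_per_block) != num_layers:
        raise ValueError("Expected one transformer-depth entry per layer.")
    return list(transformer_layers_per_block)

print(expand_transformer_layers(2, 3))           # [2, 2, 2]
print(expand_transformer_layers((1, 2, 10), 3))  # [1, 2, 10]
```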
42 changes: 35 additions & 7 deletions src/diffusers/models/unet_2d_condition.py
@@ -43,6 +43,7 @@
)
from .modeling_utils import ModelMixin
from .unet_2d_blocks import (
UNetMidBlock2D,
UNetMidBlock2DCrossAttn,
UNetMidBlock2DSimpleCrossAttn,
get_down_block,
@@ -86,7 +87,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
The tuple of downsample blocks to use.
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or
Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
`UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
The tuple of upsample blocks to use.
@@ -105,10 +106,15 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
The dimension of the cross attention features.
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None):
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
encoder_hid_dim (`int`, *optional*, defaults to None):
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
dimension to `cross_attention_dim`.
@@ -142,9 +148,9 @@ class conditioning with `class_embed_type` equal to `None`.
The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
time_cond_proj_dim (`int`, *optional*, defaults to `None`):
The dimension of `cond_proj` layer in the timestep embedding.
conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. conv_out_kernel (`int`,
*optional*, default to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`,
*optional*): The dimension of the `class_labels` input when
`class_embed_type="projection"`. Required when `class_embed_type="projection"`.
class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
embeddings with the class embeddings.
@@ -184,7 +190,8 @@ def __init__(
norm_num_groups: Optional[int] = 32,
norm_eps: float = 1e-5,
cross_attention_dim: Union[int, Tuple[int]] = 1280,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
encoder_hid_dim: Optional[int] = None,
encoder_hid_dim_type: Optional[str] = None,
attention_head_dim: Union[int, Tuple[int]] = 8,
@@ -265,6 +272,10 @@ def __init__(
raise ValueError(
f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
)
if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
for layer_number_per_block in transformer_layers_per_block:
if isinstance(layer_number_per_block, list):
raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")

# input
conv_in_padding = (conv_in_kernel - 1) // 2
@@ -500,6 +511,19 @@ def __init__(
only_cross_attention=mid_block_only_cross_attention,
cross_attention_norm=cross_attention_norm,
)
elif mid_block_type == "UNetMidBlock2D":
self.mid_block = UNetMidBlock2D(
in_channels=block_out_channels[-1],
temb_channels=blocks_time_embed_dim,
dropout=dropout,
num_layers=0,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_groups=norm_num_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
add_attention=False,
)
elif mid_block_type is None:
self.mid_block = None
else:
@@ -513,7 +537,11 @@ def __init__(
reversed_num_attention_heads = list(reversed(num_attention_heads))
reversed_layers_per_block = list(reversed(layers_per_block))
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
reversed_transformer_layers_per_block = (
list(reversed(transformer_layers_per_block))
if reverse_transformer_layers_per_block is None
else reverse_transformer_layers_per_block
)
only_cross_attention = list(reversed(only_cross_attention))

output_channel = reversed_block_out_channels[0]
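Taken together, the `unet_2d_condition.py` changes allow a nested `transformer_layers_per_block`, add `reverse_transformer_layers_per_block` so the up path no longer has to mirror the down path, and expose the attention-free `UNetMidBlock2D` as a mid block type. A hedged end-to-end sketch of building such an asymmetrical UNet follows; all sizes and depths are illustrative (not taken from this PR or its tests) and assume a diffusers version that includes this PR:

```python
import torch
from diffusers import UNet2DConditionModel

# Asymmetrical configuration: per-layer transformer depths on the way down,
# a different (non-mirrored) schedule on the way up, and an attention-free
# mid block.
unet = UNet2DConditionModel(
    sample_size=32,
    block_out_channels=(32, 64, 128),
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"),
    layers_per_block=2,
    # One tuple per down block, one entry per layer in that block.
    transformer_layers_per_block=((1, 1), (1, 2), (2, 4)),
    # Required once transformer_layers_per_block is nested; up blocks have
    # layers_per_block + 1 layers, hence three entries per tuple here.
    reverse_transformer_layers_per_block=((4, 2, 2), (2, 1, 1), (1, 1, 1)),
    mid_block_type="UNetMidBlock2D",  # new attention-free mid block type
    cross_attention_dim=64,
)

sample = torch.randn(1, 4, 32, 32)
timestep = torch.tensor([1])
encoder_hidden_states = torch.randn(1, 16, 64)

out = unet(sample, timestep, encoder_hidden_states).sample
print(out.shape)  # expected: torch.Size([1, 4, 32, 32])
```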
Expand Up @@ -324,7 +324,7 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
if "disable_self_attentions" in unet_params:
config["only_cross_attention"] = unet_params.disable_self_attentions

if "num_classes" in unet_params and type(unet_params.num_classes) == int:
if "num_classes" in unet_params and isinstance(unet_params.num_classes, int):
config["num_class_embeds"] = unet_params.num_classes

if controlnet:
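As an aside on the last hunk: `isinstance(x, int)` and `type(x) == int` are not strictly interchangeable, since `isinstance` also accepts subclasses (notably `bool`). A quick illustration of general Python behavior, not code from the PR:

```python
# type(x) == int is an exact-type check; isinstance also accepts subclasses.
num_classes = 1000
print(type(num_classes) == int)      # True
print(isinstance(num_classes, int))  # True

flag = True
print(type(flag) == int)             # False
print(isinstance(flag, int))         # True (bool is a subclass of int)
```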