Beautiful Doc string added into the UNetMidBlock2D class. #5389

Merged
25 commits merged on Oct 18, 2023

Changes from all commits (25 commits)
8425cd4
I added a new doc string to the class. This is more flexible to under…
hi-sushanta Oct 5, 2023
8a0e77d
Merge branch 'main' into doc_string
hi-sushanta Oct 5, 2023
cf3a816
Merge branch 'main' into doc_string
hi-sushanta Oct 5, 2023
a0cd96f
Merge branch 'main' into doc_string
hi-sushanta Oct 7, 2023
746a8e8
Update src/diffusers/models/unet_2d_blocks.py
hi-sushanta Oct 7, 2023
6e56886
Update src/diffusers/models/unet_2d_blocks.py
hi-sushanta Oct 7, 2023
627fd9f
Update unet_2d_blocks.py
hi-sushanta Oct 7, 2023
ae4f7f2
Update unet_2d_blocks.py
hi-sushanta Oct 8, 2023
f0bea43
Update unet_2d_blocks.py
hi-sushanta Oct 8, 2023
872a4a5
Merge branch 'main' into doc_string
hi-sushanta Oct 8, 2023
3546f6d
I run the black command to reformat style in the code
hi-sushanta Oct 9, 2023
12534f4
Merge branch 'main' into doc_string
hi-sushanta Oct 9, 2023
48afb4b
Merge branch 'main' into doc_string
hi-sushanta Oct 10, 2023
0bb06b6
Merge pull request #1 from hi-sushanta/doc_string
hi-sushanta Oct 11, 2023
01a9fc9
Update unet_2d_blocks.py
hi-sushanta Oct 13, 2023
04e6efb
Merge branch 'huggingface:main' into main
hi-sushanta Oct 13, 2023
eb9611d
Update unet_2d_blocks.py
hi-sushanta Oct 13, 2023
298881a
Merge branch 'main' into doc-string
hi-sushanta Oct 16, 2023
73dcca6
Merge branch 'main' into doc-string
sayakpaul Oct 17, 2023
680e596
Update unet_2d_blocks.py
hi-sushanta Oct 17, 2023
5db0ce9
Update unet_2d_blocks.py
hi-sushanta Oct 17, 2023
ea792c9
Update unet_2d_blocks.py
hi-sushanta Oct 17, 2023
af87b15
Merge branch 'main' into doc-string
hi-sushanta Oct 18, 2023
87a74ce
I am using black package for reformated my file
hi-sushanta Oct 18, 2023
f99aac0
Merge branch 'main' into doc-string
yiyixuxu Oct 18, 2023
23 changes: 23 additions & 0 deletions src/diffusers/models/unet_2d_blocks.py
Expand Up @@ -503,6 +503,29 @@ def forward(self, x):


class UNetMidBlock2D(nn.Module):
"""
A 2D UNet mid-block [`UNetMidBlock2D`] with multiple residual blocks and optional attention blocks.

Args:
in_channels (`int`): The number of input channels.
temb_channels (`int`): The number of temporal embedding channels.
dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
    resnet_eps (`float`, *optional*, defaults to 1e-6): The epsilon value for the resnet blocks.
    resnet_time_scale_shift (`str`, *optional*, defaults to `default`): The type of normalization to apply to the time embeddings, which can help improve model performance on tasks with long-range temporal dependencies.
resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks.
resnet_groups (`int`, *optional*, defaults to 32): The number of groups to use in the group normalization layers of the resnet blocks.
attn_groups (`Optional[int]`, *optional*, defaults to None): The number of groups for the attention blocks.
resnet_pre_norm (`bool`, *optional*, defaults to `True`): Whether to use pre-normalization for the resnet blocks.
add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks.
attention_head_dim (`int`, *optional*, defaults to 1): Dimension of a single attention head. The number of attention heads is determined based on this value and the number of input channels.
output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.

Returns:
`torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size, in_channels, height, width)`.

"""

def __init__(
self,
in_channels: int,
Expand Down