Beautiful Doc string added into the UNetMidBlock2D class. (#5389)
* I added a new docstring to the class. It makes it easier for other developers to understand what the class does and where it is used.

* Update src/diffusers/models/unet_2d_blocks.py

This change was suggested by the maintainer.

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update src/diffusers/models/unet_2d_blocks.py

Add suggested text

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update unet_2d_blocks.py

I changed the "Parameter" heading to "Args".

* Update unet_2d_blocks.py

Set proper indentation in this file.

* Update unet_2d_blocks.py

A small change to the resnet_act_fn argument line.

* I ran the black command to reformat the code style

* Update unet_2d_blocks.py

Added a docstring similar to the ones in the original diffusion repository.

* Update unet_2d_blocks.py

Added a beautiful docstring to the UNetMidBlock2D class.

* Update unet_2d_blocks.py

I replaced the definitions of the resnet_time_scale_shift and resnet_groups parameters.

* Update unet_2d_blocks.py

I removed the additional sentences from the resnet_groups argument.

* Update unet_2d_blocks.py

I replaced my definition of the attention_head_dim parameter with the maintainer's definition.

* I used the black package to reformat my file

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
3 people authored Oct 18, 2023
1 parent 9ad0530 commit 36a0bac
Showing 1 changed file with 23 additions and 0 deletions.
23 changes: 23 additions & 0 deletions src/diffusers/models/unet_2d_blocks.py
@@ -503,6 +503,29 @@ def forward(self, x):


class UNetMidBlock2D(nn.Module):
    """
    A 2D UNet mid-block [`UNetMidBlock2D`] with multiple residual blocks and optional attention blocks.

    Args:
        in_channels (`int`): The number of input channels.
        temb_channels (`int`): The number of temporal embedding channels.
        dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
        num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
        resnet_eps (`float`, *optional*, defaults to 1e-6): The epsilon value for the resnet blocks.
        resnet_time_scale_shift (`str`, *optional*, defaults to `default`): The type of normalization to apply to the time embeddings. This can help to improve the performance of the model on tasks with long-range temporal dependencies.
        resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks.
        resnet_groups (`int`, *optional*, defaults to 32): The number of groups to use in the group normalization layers of the resnet blocks.
        attn_groups (`Optional[int]`, *optional*, defaults to None): The number of groups for the attention blocks.
        resnet_pre_norm (`bool`, *optional*, defaults to `True`): Whether to use pre-normalization for the resnet blocks.
        add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks.
        attention_head_dim (`int`, *optional*, defaults to 1): Dimension of a single attention head. The number of attention heads is determined based on this value and the number of input channels.
        output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.

    Returns:
        `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size, in_channels, height, width)`.
    """

    def __init__(
        self,
        in_channels: int,
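
The docstring says the number of attention heads follows from `attention_head_dim` and `in_channels`, and that `resnet_groups` sets the group count for group normalization. The sketch below illustrates how those arguments interact. It is plain Python, not the library's code: `MidBlockConfig` and its two methods are hypothetical names, the integer division used for the head count is an assumption read from the docstring wording, and the divisibility check reflects the general requirement of group normalization that channels split evenly into groups. The defaults are copied from the docstring.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class MidBlockConfig:
    """Hypothetical container mirroring the documented arguments and defaults."""

    in_channels: int
    temb_channels: int
    dropout: float = 0.0
    num_layers: int = 1
    resnet_eps: float = 1e-6
    resnet_time_scale_shift: str = "default"
    resnet_act_fn: str = "swish"
    resnet_groups: int = 32
    attn_groups: Optional[int] = None
    resnet_pre_norm: bool = True
    add_attention: bool = True
    attention_head_dim: int = 1
    output_scale_factor: float = 1.0

    def num_attention_heads(self) -> int:
        # Assumption: the head count is the channel count divided by the
        # per-head dimension, rounded down.
        return self.in_channels // self.attention_head_dim

    def check_group_norm(self) -> None:
        # Group normalization splits channels into equal-sized groups, so the
        # channel count has to be a multiple of the group count.
        if self.in_channels % self.resnet_groups != 0:
            raise ValueError(
                f"in_channels={self.in_channels} is not divisible by "
                f"resnet_groups={self.resnet_groups}"
            )


# Example values chosen for illustration only.
cfg = MidBlockConfig(in_channels=512, temb_channels=1280, attention_head_dim=64)
cfg.check_group_norm()            # 512 is a multiple of 32, so this passes
heads = cfg.num_attention_heads() # 512 // 64
```

With the default `attention_head_dim` of 1, this sketch gives one head per input channel, which is why callers usually pass a larger value.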