Added support to create asymmetrical U-Net structures #5400
Thanks for working on this. The changes seem clean and simple to me.
@patrickvonplaten WDYT?
Updated the docstring for `UNet2DConditionModel` to include `reverse_transformer_layers_per_block` and updated the check for a nested-list `transformer_layers_per_block`.
Hmm, I would need to see a clear use case for when asymmetrical U-Nets make sense. I'm a bit hesitant to adapt the Unet2DCondition class for use cases that have not been verified to work yet.
Do we have a checkpoint that shows good results with an asymmetrical U-Net?
The checkpoint we're using for the SDXL distillation is asymmetrical. We did put those results up on the Slack channel.
Cool, yes this design also works for me. Could we add a quick test here? :-)
No problem! May I ask what exactly you have in mind? Instantiating one or two UNets from configs of the new type? Could I get a few examples of tests written for similar PRs?
Hi @Gothos! You can check out this test suite to get an idea. I'd instantiate an asymmetric UNet in that test suite, run dummy inference, and perform assertions like the ones already in that suite. Let me know if that makes sense.
Yes, thanks! I'll get this done as quickly as I can.
@sayakpaul I've written a basic test asserting that the input shape equals the output shape for an asymmetrical UNet. Please let me know if anything's wrong or if any other tests are required.
@sayakpaul I've done the needful.
Running the commands above to enforce consistency seems to have attempted to make "UNetMidBlockFlat" a viable mid block type under the versatile diffusion pipeline in
We need to enforce repository consistency, sadly. Is it possible to copy over the implementation in
Sure! I've just done it and rerun the makes twice to make sure nothing popped up.
Good work!
@@ -587,7 +587,7 @@ def __init__(
         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)
 
-    def forward(self, hidden_states, temb=None):
+    def forward(self, hidden_states, temb=None, *args, **kwargs):
-    def forward(self, hidden_states, temb=None, *args, **kwargs):
+    def forward(self, hidden_states, temb=None):
Is this really needed?
Can we remove this? I don't see exactly how it is required here.
```python
# 4. mid
if self.mid_block is not None:
    sample = self.mid_block(
        sample,
        emb,
        encoder_hidden_states=encoder_hidden_states,
        attention_mask=attention_mask,
        cross_attention_kwargs=cross_attention_kwargs,
        encoder_attention_mask=encoder_attention_mask,
    )
```
These lines make the `*args, **kwargs` necessary, because the mid block is always called with the conditioning keyword arguments. I'll follow the style used elsewhere and check the `has_cross_attention` attribute instead.
    if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
        sample = self.mid_block(
            sample,
            emb,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            cross_attention_kwargs=cross_attention_kwargs,
            encoder_attention_mask=encoder_attention_mask,
        )
    else:
        sample = self.mid_block(
            sample,
            emb,
        )
Would following your suggestion above and then doing this work, @patrickvonplaten?
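For readers following the thread, here is a minimal, framework-free sketch of the dispatch being proposed. The two block classes and the `run_mid_block` helper are hypothetical stand-ins invented for illustration, not code from the PR; only the `has_cross_attention` attribute name and the keyword names come from the snippet above.

```python
class PlainMidBlock:
    """Hypothetical stand-in for a mid block without cross-attention."""

    def __call__(self, sample, emb):
        return ("plain", sample, emb)


class CrossAttnMidBlock:
    """Hypothetical stand-in for a mid block that consumes conditioning inputs."""

    has_cross_attention = True

    def __call__(
        self,
        sample,
        emb,
        encoder_hidden_states=None,
        attention_mask=None,
        cross_attention_kwargs=None,
        encoder_attention_mask=None,
    ):
        return ("cross", sample, emb, encoder_hidden_states)


def run_mid_block(mid_block, sample, emb, **conditioning):
    """Pass conditioning kwargs only to blocks that declare cross-attention."""
    if mid_block is None:
        return sample
    # getattr with a False default covers both "attribute missing" and "attribute False".
    if getattr(mid_block, "has_cross_attention", False):
        return mid_block(sample, emb, **conditioning)
    return mid_block(sample, emb)
```

With this shape, a plain block keeps the narrow `forward(hidden_states, temb)` signature and never sees the extra keyword arguments, which is why the `*args, **kwargs` catch-all is no longer needed.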
This sounds good to me!
All right, I'll commit the changes. Just wanted to check once.
@patrickvonplaten could you give this a look too?
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
Thanks a lot for working on this! PR is good to go from my end.
* Added args, kwargs to ```U
* Add UNetMidBlock2D as a supported mid block type
* Fix extra init input for UNetMidBlock2D, change allowed types for Mid-block init
* Update unet_2d_condition.py
* Update unet_2d_condition.py
* Update unet_2d_condition.py
* Update unet_2d_condition.py
* Update unet_2d_condition.py
* Update unet_2d_condition.py
* Update unet_2d_condition.py
* Update unet_2d_condition.py
* Update unet_2d_blocks.py
* Update unet_2d_blocks.py
* Update unet_2d_blocks.py
* Update unet_2d_condition.py
* Update unet_2d_blocks.py
* Updated docstring, increased check strictness
  Updated the docstring for `UNet2DConditionModel` to include `reverse_transformer_layers_per_block` and updated checking for nested list type `transformer_layers_per_block`
* Add basic shape-check test for asymmetrical unets
* Update src/diffusers/models/unet_2d_blocks.py
  Removed blank line
  Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
* Update unet_2d_condition.py
  Remove blank space
* Update unet_2d_condition.py
  Changed docstring for `mid_block_type`
* Fixed docstring and wrong default value
* Reformat with black
* Reformat with necessary commands
* Add UNetMidBlockFlat to versatile_diffusion/modeling_text_unet.py to ensure consistency
* Removed args, kwargs, use on mid-block type
* Make fix-copies
* Update src/diffusers/models/unet_2d_condition.py
  Wrap into single line
  Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
* make fix-copies
---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
What does this PR do?
* Adds `UNetMidBlock2D` as a supported `mid_block_type` in the `UNet2DConditionModel` class.
* Allows specifying `transformer_layers_per_block` for every layer of the `UNet2DConditionModel` class.

Support is somewhat rigid at the moment: using this functionality requires specifying the `reverse_transformer_layers_per_block` attribute in a config (set to `None` by default).

Please comment on any style changes or bugs you find in the code. I'll fix them as soon as I can.
@sayakpaul
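As a rough illustration of the constraint described above, the sketch below shows what a per-layer config and its validation might look like. It is not the code merged in this PR: `check_transformer_layers` is a hypothetical helper, and the layer counts in the example configs are made-up values. It assumes that a nested list in `transformer_layers_per_block` means per-layer counts within a block, and that such a nested list must be accompanied by `reverse_transformer_layers_per_block`.

```python
def check_transformer_layers(
    transformer_layers_per_block,
    reverse_transformer_layers_per_block=None,
):
    """Hypothetical validator: per-layer (nested) counts need an explicit decoder spec."""
    if reverse_transformer_layers_per_block is not None:
        return  # the decoder side is spelled out, nothing to infer
    if isinstance(transformer_layers_per_block, (list, tuple)):
        for entry in transformer_layers_per_block:
            if isinstance(entry, (list, tuple)):
                raise ValueError(
                    "reverse_transformer_layers_per_block must be provided "
                    "when transformer_layers_per_block is a nested list."
                )


# Illustrative configs only; the numbers are assumptions, not taken from a real checkpoint.
symmetric_config = {"transformer_layers_per_block": [1, 2, 10]}
asymmetric_config = {
    "transformer_layers_per_block": [[3, 2], 1],
    "reverse_transformer_layers_per_block": [[3, 4], 1],
}

check_transformer_layers(**symmetric_config)   # flat list: accepted
check_transformer_layers(**asymmetric_config)  # nested list with reverse spec: accepted
```

Failing early on a nested list without a reverse spec keeps the error at construction time, instead of leaving the up blocks to guess how to mirror a per-layer layout.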