Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support to create asymmetrical U-Net structures #5400

Merged
merged 36 commits into from
Oct 20, 2023
Merged

Added support to create asymmetrical U-Net structures #5400

merged 36 commits into from
Oct 20, 2023

Conversation

Gothos
Copy link
Contributor

@Gothos Gothos commented Oct 15, 2023

What does this PR do?

  1. Adds UNetMidBlock2D as a supported mid_block_type in the UNet2DConditionModel class.
  2. Adds basic support for specifying transformer_layers_per_block for every layer of the UNet2DConditionModel class.
    Support is somewhat rigid at the moment, using this functionality requires specifying the reverse_transformer_layers_per_block attribute in a config (set to None by default).
    Please comment on any style changes/bugs you find in the code, I'll fix it as soon as I can.

@sayakpaul

Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for working on this. The changes seem clean and simple to me.
@patrickvonplaten WDYT?

Updated the docstring for ```UNet2DConditionModel``` to include ```reverse_transformer_layers_per_block``` and updated checking for nested list type ```transformer_layers_per_block```
Copy link
Contributor

@patrickvonplaten patrickvonplaten left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm I would need to see a clear use case for when asymmetrical U-Nets make sense. I'm a bit hesitant to adapt the Unet2DCondition class for use case that have not been verified to work yet.

Do we have a checkpoint that shows good results with an asymmetrical U-Net?

@Gothos
Copy link
Contributor Author

Gothos commented Oct 16, 2023

The checkpoint we're using for the SDXL distillation is asymmetrical. We did put those results up on the slack channel.
The PR shouldn't affect anything existing at the current stage significantly, I'll be more than happy to do any needful changes quick.

Copy link
Contributor

@patrickvonplaten patrickvonplaten left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool, yes this design also works for me. Could we add a quick test here? :-)

@Gothos
Copy link
Contributor Author

Gothos commented Oct 16, 2023

No problem! May I ask what exactly you have in mind? Instantiating 1-2 UNet from configs of the new type? Can I get a few examples for tests written for similar PRs?

@sayakpaul
Copy link
Member

Hi @Gothos!

You can check out this test suite to get an idea:
https://github.com/huggingface/diffusers/blob/main/tests/models/test_models_unet_2d_condition.py

I'd instantiate an asymmetric UNet in that test suite and run dummy inference and perform assertions as you'd see for some of the tests being done in the test suite mentioned above.

Let me know if that makes sense?

@Gothos
Copy link
Contributor Author

Gothos commented Oct 16, 2023

Yes, thanks! I'll get this done as quickly as I can.

@Gothos
Copy link
Contributor Author

Gothos commented Oct 17, 2023

@sayakpaul I've written a basic test asserting input shape equals the output shape for an asymmetrical UNet. Please let me know if anything's wrong/ any other tests are required.

@Gothos
Copy link
Contributor Author

Gothos commented Oct 18, 2023

@sayakpaul I've done the needful.

@Gothos
Copy link
Contributor Author

Gothos commented Oct 18, 2023

Running the commands above to enforce consistency seems to have attempted to make "UNetMidBlockFlat" a viable mid block type under the versatile diffusion pipeline in modeling_text_unet.py. However this class is not implemented in the file and this causes an error.

@sayakpaul
Copy link
Member

We need to enforce repository consistency, sadly.

Is it possible to copy over the implementation in modeling_text_unet.py?

@Gothos
Copy link
Contributor Author

Gothos commented Oct 18, 2023

We need to enforce repository consistency, sadly.

Is it possible to copy over the implementation in modeling_text_unet.py?

Sure! I've just done it and rerun the makes twice to make sure nothing popped up.

Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good work!

@@ -587,7 +587,7 @@ def __init__(
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)

def forward(self, hidden_states, temb=None):
def forward(self, hidden_states, temb=None, *args, **kwargs):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def forward(self, hidden_states, temb=None, *args, **kwargs):
def forward(self, hidden_states, temb=None):

Is this really needed?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we remove this, I don't see exactly how this is required here

Copy link
Contributor Author

@Gothos Gothos Oct 19, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

   ``` # 4. mid
    if self.mid_block is not None:
        sample = self.mid_block(
            sample,
            emb,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            cross_attention_kwargs=cross_attention_kwargs,
            encoder_attention_mask=encoder_attention_mask,
        )```

These lines make the *args, **kwargs necessary. I'll go by the style elsewhere and use the has attention attribute instead

Copy link
Contributor Author

@Gothos Gothos Oct 19, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

            if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
                sample = self.mid_block(
                    sample,
                    emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                    encoder_attention_mask=encoder_attention_mask,
                )
            else:
                sample = self.mid_block(
                    sample,
                    emb,
                )

Would following your suggestion above and then doing this work, @patrickvonplaten?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This sounds good to me!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All right, I'll commit the changes. Just wanted to check once.

@sayakpaul
Copy link
Member

@patrickvonplaten could you give this a look too?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for working on this! PR is good to go from my end.

Gothos and others added 2 commits October 20, 2023 12:22
Wrap into single line

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
@patrickvonplaten patrickvonplaten merged commit 8dba180 into huggingface:main Oct 20, 2023
11 checks passed
mhetrerajat pushed a commit to mhetrerajat/diffusers that referenced this pull request Oct 23, 2023
* Added args, kwargs to ```U

* Add UNetMidBlock2D as a supported mid block type

* Fix extra init input for UNetMidBlock2D, change allowed types for Mid-block init

* Update unet_2d_condition.py

* Update unet_2d_condition.py

* Update unet_2d_condition.py

* Update unet_2d_condition.py

* Update unet_2d_condition.py

* Update unet_2d_condition.py

* Update unet_2d_condition.py

* Update unet_2d_condition.py

* Update unet_2d_blocks.py

* Update unet_2d_blocks.py

* Update unet_2d_blocks.py

* Update unet_2d_condition.py

* Update unet_2d_blocks.py

* Updated docstring, increased check strictness

Updated the docstring for ```UNet2DConditionModel``` to include ```reverse_transformer_layers_per_block``` and updated checking for nested list type ```transformer_layers_per_block```

* Add basic shape-check test for asymmetrical unets

* Update src/diffusers/models/unet_2d_blocks.py

Removed blank line

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update unet_2d_condition.py

Remove blank space

* Update unet_2d_condition.py

Changed docstring for `mid_block_type`

* Fixed docstring and wrong default value

* Reformat with black

* Reformat with necessary commands

* Add UNetMidBlockFlat to versatile_diffusion/modeling_text_unet.py to ensure consistency

* Removed args, kwargs, use on mid-block type

* Make fix-copies

* Update src/diffusers/models/unet_2d_condition.py

Wrap into single line

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* make fix-copies

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
* Added args, kwargs to ```U

* Add UNetMidBlock2D as a supported mid block type

* Fix extra init input for UNetMidBlock2D, change allowed types for Mid-block init

* Update unet_2d_condition.py

* Update unet_2d_condition.py

* Update unet_2d_condition.py

* Update unet_2d_condition.py

* Update unet_2d_condition.py

* Update unet_2d_condition.py

* Update unet_2d_condition.py

* Update unet_2d_condition.py

* Update unet_2d_blocks.py

* Update unet_2d_blocks.py

* Update unet_2d_blocks.py

* Update unet_2d_condition.py

* Update unet_2d_blocks.py

* Updated docstring, increased check strictness

Updated the docstring for ```UNet2DConditionModel``` to include ```reverse_transformer_layers_per_block``` and updated checking for nested list type ```transformer_layers_per_block```

* Add basic shape-check test for asymmetrical unets

* Update src/diffusers/models/unet_2d_blocks.py

Removed blank line

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update unet_2d_condition.py

Remove blank space

* Update unet_2d_condition.py

Changed docstring for `mid_block_type`

* Fixed docstring and wrong default value

* Reformat with black

* Reformat with necessary commands

* Add UNetMidBlockFlat to versatile_diffusion/modeling_text_unet.py to ensure consistency

* Removed args, kwargs, use on mid-block type

* Make fix-copies

* Update src/diffusers/models/unet_2d_condition.py

Wrap into single line

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* make fix-copies

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
* Added args, kwargs to ```U

* Add UNetMidBlock2D as a supported mid block type

* Fix extra init input for UNetMidBlock2D, change allowed types for Mid-block init

* Update unet_2d_condition.py

* Update unet_2d_condition.py

* Update unet_2d_condition.py

* Update unet_2d_condition.py

* Update unet_2d_condition.py

* Update unet_2d_condition.py

* Update unet_2d_condition.py

* Update unet_2d_condition.py

* Update unet_2d_blocks.py

* Update unet_2d_blocks.py

* Update unet_2d_blocks.py

* Update unet_2d_condition.py

* Update unet_2d_blocks.py

* Updated docstring, increased check strictness

Updated the docstring for ```UNet2DConditionModel``` to include ```reverse_transformer_layers_per_block``` and updated checking for nested list type ```transformer_layers_per_block```

* Add basic shape-check test for asymmetrical unets

* Update src/diffusers/models/unet_2d_blocks.py

Removed blank line

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update unet_2d_condition.py

Remove blank space

* Update unet_2d_condition.py

Changed docstring for `mid_block_type`

* Fixed docstring and wrong default value

* Reformat with black

* Reformat with necessary commands

* Add UNetMidBlockFlat to versatile_diffusion/modeling_text_unet.py to ensure consistency

* Removed args, kwargs, use on mid-block type

* Make fix-copies

* Update src/diffusers/models/unet_2d_condition.py

Wrap into single line

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* make fix-copies

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants