
Add AttentionProcessor support #949

Closed

Conversation

@rockerBOO (Contributor) commented Nov 14, 2023

Background: I was working on adding DAAM attention-mapping heatmaps to the sampling pipeline (and possibly to the training process), both for my own purposes and to propose to you. To do this I needed to add AttentionProcessor-like support, plus many small changes so that DAAM can use its Diffusers interface to work with the Kohya original_unet and sampling pipelines.

Diffusers made some changes and added an AttentionProcessor, which abstracts out the attention mechanism and lets you plug in your own processor. The DAAM library hooks into this processor to turn cross-attention maps into heatmaps: https://github.com/castorini/daam/blob/v0.1.0/daam/trace.py#L281-L282

This also pulls in some additional code from the Diffusers Attention class to support the attention manipulations that DAAM requires.
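As a rough sketch of the pattern (toy dimensions and illustrative class names only, not the actual sd-scripts or Diffusers code), a pluggable processor with a DAAM-style heatmap hook might look like:

```python
import torch
import torch.nn as nn


class AttnProcessor:
    """Default processor: plain scaled dot-product (cross-)attention."""

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
        # Fall back to self-attention when no context is given.
        context = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        q, k, v = attn.to_q(hidden_states), attn.to_k(context), attn.to_v(context)
        scores = q @ k.transpose(-1, -2) * attn.scale
        if attention_mask is not None:
            scores = scores + attention_mask
        return attn.to_out(scores.softmax(dim=-1) @ v)


class HeatmapProcessor(AttnProcessor):
    """DAAM-style processor that also records attention probabilities."""

    def __init__(self):
        self.maps = []

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
        context = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        q, k = attn.to_q(hidden_states), attn.to_k(context)
        probs = (q @ k.transpose(-1, -2) * attn.scale).softmax(dim=-1)
        self.maps.append(probs.detach())  # token-to-position weights for heatmaps
        return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask)


class CrossAttention(nn.Module):
    def __init__(self, dim, context_dim):
        super().__init__()
        self.scale = dim ** -0.5
        self.to_q = nn.Linear(dim, dim)
        self.to_k = nn.Linear(context_dim, dim)
        self.to_v = nn.Linear(context_dim, dim)
        self.to_out = nn.Linear(dim, dim)
        self.processor = AttnProcessor()  # swappable: no monkey patching needed

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
        return self.processor(self, hidden_states, encoder_hidden_states, attention_mask)
```

Swapping `attn.processor = HeatmapProcessor()` is then all a library like DAAM needs to observe the attention weights.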

I added the dropout layer back in because DAAM looks for the last module, and its absence caused an error. We could make it nn.Identity instead if that's faster.

The encoder_hidden_states argument is used on the Diffusers side and, I believe, corresponds to the context term, so I mapped it over. This lets other libraries pass either the Diffusers-style variables or the original context and have both work the same way, since a lot of arguments flow through **kwargs-like mappings.
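For illustration, the name translation could be sketched as a small helper (a hypothetical function; in the PR the mapping lives inside the attention forward itself):

```python
def normalize_attn_kwargs(**kwargs):
    """Hypothetical helper: translate CompVis-style argument names
    (x, context, mask) to the Diffusers-style names, so callers can
    use either convention interchangeably."""
    aliases = {
        "x": "hidden_states",
        "context": "encoder_hidden_states",
        "mask": "attention_mask",
    }
    # Unknown names pass through untouched.
    return {aliases.get(name, name): value for name, value in kwargs.items()}
```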

Ultimately I made these changes to limit any further refactoring toward a Diffusers-like architecture and to better allow sd-scripts to work with Diffusers. I know this is a lot of changes and additions; I can pare it down if necessary. Thank you!

@ThereforeGames

Interesting! What steps are needed to integrate attention maps into the training process? I'm willing to run some tests on my datasets.

@rockerBOO (Contributor, Author) commented Nov 18, 2023

@ThereforeGames It works but is still a little rough. If you can follow the steps below, you could probably try it out.

This branch for the sampling: https://github.com/rockerBOO/sd-scripts/tree/daam_sampling
This PR for the DAAM library: rockerBOO/daam#2

In the sample file:
a photo of a woman smiling, waving with her hands --d 42 --da woman,hands,smiling

(Attached images: the sampled output plus attention heatmaps for "smile", "waving", and "woman".)

@ThereforeGames

Thanks - I imagine even those rough heatmaps could have a benefit on the model's understanding of object relationships. Will have to give this a try!

@kohya-ss (Owner)

Hello! Sorry for the late reply. This is a great suggestion, and I am very interested in your ideas!

Unfortunately I don't have time to go into the details, but let me write down my basic thoughts first.

I am very sorry, but overall I am not a big fan of Diffusers' design policy: Diffusers seems to be trying to be a generic library with very broad support, and that adds a lot of complexity to their code.
For example, their unet_2d_condition.py is over 1,000 lines of code, even though it does not work without other modules.
Also, their updates are frequent and often breaking, which is not easy to keep up with.

I like to keep the sd-scripts repository as simple as possible.
(Yes, I know that some of the source code is very long. I mean architectural simplicity.)

But, of course, I also understand the effectiveness of AttentionProcessor.

So, I think we need to carefully consider how to implement a way to keep the repository simple while interoperable with Diffusers.

Please give me time to think about how to address this issue.

@rockerBOO (Contributor, Author)

Thank you, @kohya-ss. I appreciate your kind words.

I was in the middle of working with the DAAM library to get this working, and the changes I initially proposed were just enough to do that. After working with the library and adapting it to a CompVis-like codebase, I found it assumed too much about what exists. I have since reworked the library completely to better support different kinds of libraries, so most of the changes I initially proposed will no longer be necessary.

If we look at the current sd-scripts version, the forward already implements the attention processing inline. So my proposal has shrunk significantly: just add a self.processor and pass the values into it. This lets us hook into the attention flow without resorting to monkey patching. I added the additional argument names from HF Diffusers so calls translate appropriately; without them, a call using those names would error.

The current proposal would be adequate for my needs. At minimum, the forward needs to support the HF Diffusers argument names:

  • hidden_states
  • attention_mask
  • encoder_hidden_states

Thank you for taking the time to look at this proposal.

@kohya-ss (Owner) commented Nov 19, 2023

Thank you for the update! I think the implementation is very nice and simple.

However, I still wonder if it might be advantageous to leave the existing code as is and only call the AttnProcessor if and when it is set, so that some compatibility testing against the existing code can be omitted.

In addition, if we did that, the name translation might not be necessary.
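For illustration, that opt-in shape could look roughly like this (a toy module; the original attention math is replaced by a stand-in projection, and all names are assumptions, not the actual PR code):

```python
import torch
import torch.nn as nn


class CrossAttention(nn.Module):
    """Sketch of the opt-in variant: the module keeps its original
    CompVis-style forward and only delegates when a processor is set."""

    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)  # stand-in for the existing attention math
        self.processor = None            # unset by default: behavior unchanged

    def forward(self, x, context=None, mask=None):
        if self.processor is not None:
            # Name translation to Diffusers-style arguments happens
            # only on this opt-in path, so existing callers are untouched.
            return self.processor(self, hidden_states=x,
                                  encoder_hidden_states=context,
                                  attention_mask=mask)
        return self.proj(x)  # original code path, exercised by all existing tests
```

Only users who assign a processor would exercise the new path, which is what limits the compatibility-testing surface.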

Of course, I think the PR is already great and ready to merge. However, I would like to carefully test the behavior of this code for backward compatibility.

@rockerBOO (Contributor, Author)

However, I still wonder if it might be advantageous to leave the existing code as is and only call the AttnProcessor if and when it is set, so that some compatibility testing against the existing code can be omitted.

I went through and simplified it to the approach you hinted at here. See #961

I believe this is the simplest approach, and it captures all the properties for the forward arguments (for cases where people just pass **kwargs to the forward). It's up to you which would be more impactful, but I think #961 is better in every way. It also reduces the amount of testing we will need to do, since the new behavior is limited to those who set the processor.

No worries either way. I'm happy with the process we went through to pare this down.

Thank you again for your time.

@rockerBOO changed the base branch from main to dev on November 20, 2023, 02:03
@kohya-ss (Owner)

Thank you for updating, and sorry again for the delay. I agree that #961 seems to be the simplest approach.

Is it OK to merge #961 and close this without merging?

@rockerBOO (Contributor, Author)

Thank you for updating, and sorry again for the delay. I agree that #961 seems to be the simplest approach.

Is it OK to merge #961 and close this without merging?

Yes, merging #961 and closing this one. I appreciate it and no worries. :)

@rockerBOO closed this Nov 25, 2023