In the provided example checkpoint, max_attn_resolution is set to 16. During encoding, the image passes through down-blocks at 64x64, 32x32, 16x16, and 16x16, so cross-attention is added twice (after each of the 16x16 down-blocks). During decoding, however, the image passes through 16x16, 32x32, 64x64, and 64x64, so cross-attention is added only once. Is this expected behavior (resulting in an asymmetric encoder/decoder structure)?
Hi, thank you for your interest in our work!
You are correct: there are two cross-attention blocks in the encoder but only one in the decoder. This wasn't an intentional design choice. The cross-attention mechanism was originally meant to be applied to multi-scale features, but I set max_attn_resolution to 16 mainly to save memory. Despite this, the current architecture performs well in practice. I will run experiments with more cross-attention blocks (e.g., setting max_attn_resolution to 32) to see whether that further improves performance. Thank you for bringing this to my attention!
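For anyone else reading this, here is a minimal sketch (not the repository's actual code; the resolution schedules are taken from the description above) of how a single max_attn_resolution threshold produces the asymmetry: the encoder's downsampling path visits 16x16 twice, while the decoder's upsampling path visits it only once.

```python
def attn_flags(resolutions, max_attn_resolution=16):
    """Return True for each block whose feature resolution is small
    enough to receive a cross-attention layer."""
    return [res <= max_attn_resolution for res in resolutions]

# Encoder downsampling path: 64 -> 32 -> 16 -> 16
encoder_res = [64, 32, 16, 16]
# Decoder upsampling path: 16 -> 32 -> 64 -> 64
decoder_res = [16, 32, 64, 64]

print(attn_flags(encoder_res))  # [False, False, True, True]  -> 2 cross-attn blocks
print(attn_flags(decoder_res))  # [True, False, False, False] -> 1 cross-attn block
```

Raising the threshold to 32 would, under this scheme, add cross-attention at the 32x32 blocks on both paths (3 in the encoder, 2 in the decoder), at the cost of attention over 4x as many spatial positions at that scale.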