
[FEATURE] Support resizing rectangular pos_embed #2190

Open
gau-nernst opened this issue May 31, 2024 · 2 comments
Labels: enhancement (New feature or request)

Comments

@gau-nernst
Contributor

Is your feature request related to a problem? Please describe.

Currently, when changing the ViT img size from a rectangular original size, resample_abs_pos_embed() does not work correctly because it does not know the original rectangular size and assumes a square:

v = resample_abs_pos_embed(
    v,
    new_size=model.patch_embed.grid_size,
    num_prefix_tokens=num_prefix_tokens,
    interpolation=interpolation,
    antialias=antialias,
    verbose=True,
)

Since old_size is not passed here, the fallback inside resample_abs_pos_embed() assumes a square grid:

if old_size is None:
    hw = int(math.sqrt(num_pos_tokens - num_prefix_tokens))
    old_size = hw, hw
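To make the failure concrete, here is a small illustration (assuming the 64x8 patch grid of the AudioMAE model mentioned below; the numbers are mine, not from timm):

import math

# What the square fallback infers for a rectangular checkpoint:
# AudioMAE uses a 1024x128 input with patch 16, i.e. a 64x8 grid = 512 pos tokens.
num_prefix_tokens = 0
num_pos_tokens = 64 * 8
hw = int(math.sqrt(num_pos_tokens - num_prefix_tokens))
print((hw, hw))  # (22, 22) -> 484 tokens, which does not even match the 512 in the checkpoint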

Describe the solution you'd like

It should work out of the box.

Describe alternatives you've considered

Manually resize it.
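A minimal sketch of that manual workaround, assuming the checkpoint's pos_embed is resampled with an explicit rectangular old_size before loading (the function and its arguments are illustrative, not timm code):

from timm.layers import resample_abs_pos_embed

def resize_checkpoint_pos_embed(state_dict, model, old_grid=(64, 8)):
    # Resample the checkpoint's pos_embed from its true rectangular grid
    # (e.g. 64x8 for AudioMAE) to the grid of the newly created model.
    state_dict['pos_embed'] = resample_abs_pos_embed(
        state_dict['pos_embed'],
        new_size=model.patch_embed.grid_size,
        old_size=old_grid,        # explicit, instead of the square fallback
        num_prefix_tokens=1,      # 1 if the model keeps a class token
    )
    return state_dict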

Additional context

Apparently, dynamic img size also will not work when the original img size is rectangular:

if self.dynamic_img_size:
    B, H, W, C = x.shape
    pos_embed = resample_abs_pos_embed(
        self.pos_embed,
        (H, W),
        num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
    )

This is a rare problem since most image ViTs use square inputs. The particular model I'm using is my previously ported AudioMAE (https://huggingface.co/gaunernst/vit_base_patch16_1024_128.audiomae_as2m), which uses a rectangular input (mel-spectrogram).

I understand it is not so straightforward to support this, since once the model is created (with the updated image size), the original image size is lost. Some hacks can probably bypass this, but they are not so nice:

  1. Propagate the original image size to the _load_weights() function.
  2. Create a model with the original image size and load weights as usual, then add a new method like .set_img_size() which updates the internal img_size attribute and resamples the pos embed (see the sketch after this list).
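For reference, a minimal sketch of what option 2 could look like; this is not timm's actual API, just an illustration using resample_abs_pos_embed and the usual VisionTransformer / PatchEmbed attributes:

import torch
from timm.layers import resample_abs_pos_embed

def set_img_size(model, img_size):
    # Hypothetical helper: resize a ViT after creation, while the original
    # (possibly rectangular) grid is still known via patch_embed.grid_size.
    old_grid = model.patch_embed.grid_size      # e.g. (64, 8)
    ph, pw = model.patch_embed.patch_size
    new_grid = (img_size[0] // ph, img_size[1] // pw)
    pos_embed = resample_abs_pos_embed(
        model.pos_embed,
        new_size=new_grid,
        old_size=old_grid,                      # keep the true rectangle
        num_prefix_tokens=0 if model.no_embed_class else model.num_prefix_tokens,
    )
    model.pos_embed = torch.nn.Parameter(pos_embed)
    # patch_embed's own img_size / grid_size bookkeeping would also need
    # updating; omitted here since those internals may vary across timm versions.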

Perhaps an easier solution is to fix dynamic img size to pass the original img size (which I tested locally and it works):

if self.dynamic_img_size:
    B, H, W, C = x.shape
    pos_embed = resample_abs_pos_embed(
        self.pos_embed,
        (H, W),
        self.patch_embed.grid_size,  # pass the original (possibly rectangular) grid as old_size
        num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
    )
gau-nernst added the enhancement (New feature or request) label on May 31, 2024
@rwightman
Collaborator

@gau-nernst I have thought about this, and yeah, it's not done because doing it at weight load time (where it's done now) adds more complexity than I'd like, especially given the benefit/demand.

That said, I need a set_img_size()-like fn for another project, so it will be supported through that mechanism. I was planning to do something similar to the swin_v2_cr implementation (https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/swin_transformer_v2_cr.py#L324-L337) but tweak the naming a bit...

I believe dynamic resize can be made to work; it needs the 'old size' arg to be non-default...

@rwightman
Collaborator

@gau-nernst on PR #2225 there is a first pass at implementing a set_input_size fn... currently it should mostly work for models from vision_transformer.py, vision_transformer_hybrid.py, and swin_transformer.py (v1).
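A hedged usage sketch of that mechanism (the exact signature on the PR may differ; the img_size keyword and loading the AudioMAE checkpoint via the hf_hub: prefix are assumptions here):

import timm

# Create the model at its pretrained (rectangular) size, then resize afterwards;
# the pos_embed should be resampled from the known original grid internally.
model = timm.create_model(
    'hf_hub:gaunernst/vit_base_patch16_1024_128.audiomae_as2m',
    pretrained=True,
)
model.set_input_size(img_size=(512, 128))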
