Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Patch for CLIP #21

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

FrancescoSaverioZuppichini

Hi 👋

Thanks for the amazing work on ToMe. I am trying to create a patch for CLIP. The main issue is that I cannot use PyTorch optimized attention implementation because I cannot edit the source code to weight the attention matrix with the log of the size. This results in slower forward pass.

This is the code I've used to benchmark the patch

import torch

from tome.patch.clip import ToMeAttention, apply_patch

torch.manual_seed(0)

import clip
import torch
from torch.utils import benchmark

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
import time


def profile_model(fn, min_run_time=2):
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    # warmup
    for _ in range(4):
        fn()
    res = benchmark.Timer(
        stmt="fn()", globals={"fn": fn}, label="profile", sub_label="", description=""
    ).blocked_autorange(min_run_time=min_run_time)
    torch.cuda.synchronize()
    memory = torch.cuda.max_memory_allocated() / 2**20
    memory = f"Memory used: {memory} MB"
    print(res)
    print(memory)


with torch.no_grad():
    x = torch.randn((1, 3, 224, 224)).cuda()
    model.visual = model.visual.float()

    profile_model(lambda: model.visual(x))

    apply_patch(model.visual)
    model.visual.r = 8

    profile_model(lambda: model.visual(x))

Resulting in

Original CLIP

profile
  Median: 4.86 ms
  IQR:    0.18 ms (4.80 to 4.98)
  5 measurements, 100 runs per measurement, 1 thread
Memory used: 507.9794921875 MB

ToMe CLIP

profile
  Median: 9.85 ms
  IQR:    0.44 ms (9.51 to 9.95)
  21 measurements, 10 runs per measurement, 1 thread

Any idea how to use ToMe when nn.MultiHead or anything else (like stuff from xformers) is used?

From my benchmarks it looks like somebody would be better off to just use the build in implementation in torch than to use ToMe

Thanks a lot,

Fra

@facebook-github-bot
Copy link

Hi @FrancescoSaverioZuppichini!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!

@facebook-github-bot
Copy link

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!

1 similar comment
@facebook-github-bot
Copy link

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!

@dbolya
Copy link
Contributor

dbolya commented Mar 12, 2023

Hi thanks for trying this out. Yeah the native timm implementation for Attention is not very efficient.

There are two things we need from attn:

  1. The key values (k) to use as metrics for merging
  2. To add sizes to the pre (or post) softmax attn matrix

The problem, as you've discovered, is that getting / setting those values from nn.MultiHeadAttn is very much non-trivial, since it computes K and applies the softmax internally.

For 1, we have the in_proj_weight matrix (which includes W_k), so we could just re-compute the keys ourselves (which would take extra time, but it's not the end of the world). For 2, there may be a hacky way to add what we want to the pre-softmax matrix by modifying the biases (i.e., in_proj_bias or k_bias, idk which) on the fly.

Neither 1 or 2 are strict requirements, though, but of course they make the result more accurate.

First, I would try what happens if you use nn.MultiHeadAttn without 1 or 2. Simply call the normal attention module (not ToMeAttn) and then use the input x as the metric (instead of k). Once that gives a speed improvement, then we can start thinking of making it more accurate by incorporating 1 or 2.

@AlexKoff88
Copy link

Hi, I used this implementation partially fixed some issues and addressed performance problem that comes from an unnecessary permutation of tensor dimensions. I applied it to OpenCLIP models. Please find my implementation here. Any feedback is appreciated.

@FrancescoSaverioZuppichini
Copy link
Author

Hi, I used this implementation partially fixed some issues and addressed performance problem that comes from an unnecessary permutation of tensor dimensions. I applied it to OpenCLIP models. Please find my implementation here. Any feedback is appreciated.

is it faster?

@AlexKoff88
Copy link

Hi, I used this implementation partially fixed some issues and addressed performance problem that comes from an unnecessary permutation of tensor dimensions. I applied it to OpenCLIP models. Please find my implementation here. Any feedback is appreciated.

is it faster?

yes, it is especially if you run inference on CPU where there is not such much compute.

@dbolya
Copy link
Contributor

dbolya commented May 10, 2023

Hi @AlexKoff88, sounds great! Do you know how the zero-shot accuracy of your implementation compares? E.g., on imagenet val.

@AlexKoff88
Copy link

Hi, it depends on the number of tokens you merge from block to block. I found that the accuracy degrades significantly on COCO Captions if I want to achieve 2x speedup. My hope is that some lightweight fine-tuning can help here. Working on it.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants