
[Model] SiglipVisionModel ported from transformers #6942

Merged
merged 24 commits into from
Aug 5, 2024

Conversation

ChristopherCho
Contributor

@ChristopherCho ChristopherCho commented Jul 30, 2024

This PR implements SiglipVisionModel for VLMs.

  • Some pre-trained SiglipVisionModel checkpoints cannot use vLLM's Attention layer.
    Therefore, I implemented alternative attention layers for the cases where vLLM's Attention is not applicable.
  • I tried the vllm_flash_attn backend, but it fails with a CUDA error.
    Thus, only the basic attention mechanism works properly for now.
  • Modified PaliGemma to use the implemented SiglipVisionModel.

FIX #6941
FIX #7144


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run the fastcheck CI, which consists of a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of the default ones by unblocking the steps in your fastcheck build in the Buildkite UI.

Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge).

To run full CI, you can do one of these:

  • Comment /ready on the PR
  • Add ready label to the PR
  • Enable auto-merge.

🚀

@DarkLight1337
Member

DarkLight1337 commented Jul 30, 2024

Just saw this. Thanks for the implementation, I'll leave the review to @ywang96 since he worked on PaliGemma.

@ywang96 ywang96 self-assigned this Jul 30, 2024
@jeejeelee
Contributor

The attention calculation can use the MEA (Memory Efficient Attention) ops from xformers; see: https://facebookresearch.github.io/xformers/components/ops.html

@ChristopherCho
Contributor Author

@jeejeelee

The attention calculation can use the MEA (Memory Efficient Attention) ops from xformers; see: https://facebookresearch.github.io/xformers/components/ops.html

Thanks! I added xformers MEA and torch sdpa to give various options.
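
For reference, here is a minimal sketch of how the two alternative backends can be called interchangeably over the image-patch sequence (the tensor layout and the use_xformers flag are illustrative assumptions, not this PR's actual code):

import torch
import torch.nn.functional as F
import xformers.ops as xops


def vit_attention(q, k, v, use_xformers=True):
    # q, k, v: [batch, seq_len, num_heads, head_dim] (assumed layout).
    if use_xformers:
        # xformers memory-efficient attention expects [B, M, H, K].
        return xops.memory_efficient_attention(q, k, v)
    # torch SDPA expects [B, H, M, K], so transpose in and out.
    out = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2))
    return out.transpose(1, 2)

Both paths compute plain (non-paged) multi-head attention over the full patch sequence, which is all the ViT encoder needs here.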

@jeejeelee
Contributor

Hi, thank you for your excellent work. I'd like to know if your implementation supports the following model:

import timm

model = timm.create_model(
    "vit_so400m_patch14_siglip_384.webli",
    pretrained=False,
    num_classes=0,
    dynamic_img_size=True,
    dynamic_img_pad=True,
)

@ChristopherCho
Contributor Author

@jeejeelee
Hi, I believe the pre-trained SigLIP model vit_so400m_patch14_siglip_384.webli is the same as the one on Hugging Face.
Also, as far as I know, the pre-trained PaliGemma model uses the same pre-trained SigLIP vision encoder.

When I loaded the pre-trained PaliGemma model with the following code, I could successfully load it and run inference.

import requests
from PIL import Image
from vllm import LLM, SamplingParams

model_id = "google/paligemma-3b-mix-224"
prompt = "What is on the flower?"
image_file = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg?download=true"
image = Image.open(requests.get(image_file, stream=True).raw)

llm = LLM(model=model_id)
sampling_params = SamplingParams(
    temperature=0.0
)

input_dict = {
    "prompt": prompt,
    "multi_modal_data": {
        "image": image,
    }
}
outputs = llm.generate(input_dict, sampling_params)

print(outputs[0].outputs[0].text)

So, since the pre-trained PaliGemma model works fine, the pre-trained SigLIP vision model should work as well.
But there are some caveats:

  1. Since the current version of the code only supports VLMs, it does not contain any of the textual parts of the SigLIP model (e.g. SiglipTextTransformer, SiglipTextModel). Therefore, the full SigLIP model cannot be loaded with this code.
  2. The aforementioned SigLIP model is exactly the case in which vLLM's attention cannot be used: the head size is 72 (the model's hidden size is 1152 with 16 heads), which is not among the supported head sizes.
    Since I implemented fallback logic for this case, it works fine but does not use vLLM's paged attention (see the sketch below).
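
A minimal sketch of the kind of head-size check behind this fallback (the list of supported sizes below is an assumption for illustration; vLLM's attention backend exposes the authoritative list, e.g. via PagedAttention.get_supported_head_sizes()):

# Illustrative only -- not this PR's actual code.
# Decide whether vLLM's paged attention can be used for the SigLIP ViT,
# falling back to a plain attention implementation otherwise.
ASSUMED_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]  # illustrative list

hidden_size = 1152
num_heads = 16
head_dim = hidden_size // num_heads  # 72 for the so400m SigLIP checkpoint

if head_dim in ASSUMED_SUPPORTED_HEAD_SIZES:
    backend = "vllm_paged_attention"
else:
    backend = "basic_attention_fallback"  # the path taken when head_dim == 72

print(head_dim, backend)  # -> 72 basic_attention_fallback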

@ywang96
Member

ywang96 commented Aug 1, 2024


Hey @ChristopherCho! Thank you very much for this PR and I really appreciate it - will review this tonight!

Member

@ywang96 ywang96 left a comment

Hey @ChristopherCho - Thank you very much for the PR!

I took a first pass and left some comments. Mostly, I'm wondering whether we should really use vLLM's attention module in the ViT when it's only run once per sequence; I suggest simply using the attention modules from transformers for now.

Comment on lines 262 to 268
self.attn = Attention(
    self.num_heads,
    self.head_dim,
    self.scale,
    cache_config=cache_config,
    quant_config=quant_config,
)
Member

Currently, for CLIPVisionModel, we don't use vLLM's internal Attention since the ViT encoder only runs once per sequence at prefill time, so I don't think there's much value in leveraging a KV cache for this.

Have you seen a significant performance speedup using vLLM Attention compared to transformers Attention? If not, I think we'd rather just use the one from transformers for simplicity for now, since this is not the major bottleneck in the whole inference pipeline.

Contributor Author

@ChristopherCho ChristopherCho Aug 5, 2024

Indeed, there were no significant improvements from using vLLM Attention.
I believe it is for the reason you mentioned: it does not leverage the advantages of a KV cache.
I removed the vLLM Attention part but kept the history at bb570c3 for future reference.

Comment on lines 311 to 310
# We only do sharding for language model and
# not vision model for now.
use_default_weight_loading = True
Member

Please revert this for now - if we're going to apply TP to the vision tower, we should do it in a separate PR together with CLIPVisionModel. Ideally, we should not apply an infrastructure change to only one model.

Contributor Author

Reverted via dee55d0

Comment on lines 441 to 446
SIGLIP_ATTENTION_CLASSES = {
    "eager": SiglipAttention,
    "flash_attention_2": SiglipFlashAttention2,
    "sdpa": SiglipSdpaAttention,
    "xformers": SiglipxFormersAttention,
}
Member

I really appreciate that you went out and implemented these (regardless of whether we're going to use them or not)!

(Two further review comments on vllm/model_executor/models/siglip.py were marked outdated and resolved.)
@ywang96 ywang96 mentioned this pull request Aug 2, 2024
ChristopherCho and others added 2 commits August 5, 2024 10:01
Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
@ywang96
Member

ywang96 commented Aug 5, 2024

@ChristopherCho Now that we've merged #7020, I think there's some benefit to enabling TP for this model, given that SigLIP accounts for 400M of the 3B parameters in the PaliGemma model.

However, could you take a look at the implementation of attention here instead of using the vLLM attention? The latter creates a KV cache that I don't think we should be using for a ViT that only runs at the prefill stage.
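
For the TP follow-up, a rough sketch of what parallelizing the SigLIP attention projections with vLLM's TP linear layers might look like (the constructor arguments are assumptions based on how other vLLM models use these layers, and this requires the tensor-parallel runtime to already be initialized):

from vllm.model_executor.layers.linear import (QKVParallelLinear,
                                               RowParallelLinear)

hidden_size = 1152
num_heads = 16
head_dim = hidden_size // num_heads

# Fused, column-parallel QKV projection, sharded across TP ranks.
qkv_proj = QKVParallelLinear(
    hidden_size=hidden_size,
    head_size=head_dim,
    total_num_heads=num_heads,
    bias=True,
)

# Row-parallel output projection, all-reduced across TP ranks.
out_proj = RowParallelLinear(
    input_size=hidden_size,
    output_size=hidden_size,
    bias=True,
)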

@ChristopherCho
Contributor Author

@ywang96 I've removed vLLM Attention for now and utilized the basic attention mechanism, which is just the same as the SiglipAttention.

By the way, I've found that you mentioned the xformers MEA which is implemented here.
Would it be better to utilize this as the baseline attention mechanism? I've tested both the basic attention mechanism and the MEA and found that the MEA was a little bit slower in my environment.

@ywang96
Member

ywang96 commented Aug 5, 2024

@ywang96 I've removed vLLM Attention for now and utilized the basic attention mechanism, which is just the same as the SiglipAttention.

By the way, I've found that you mentioned the xformers MEA which is implemented here. Would it be better to utilize this as the baseline attention mechanism? I've tested both the basic attention mechanism and the MEA and found that the MEA was a little bit slower in my environment.

Let's use the default MHA implementation for this PR: I think if you use MEA, then we would necessarily need to apply TP to the attention block (since it's using the vLLM TP layers). We can leave a TODO here and do the TP in a later PR!

@ChristopherCho
Contributor Author

@ywang96

Okay, I've changed the code to use transformers' SiglipAttention for now, but kept the TP versions as a TODO in this PR.
By the way, thank you for your comments!
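
For context, a minimal sketch of reusing the attention module shipped with transformers for the vision encoder (the config values are illustrative; the real model builds them from the checkpoint's SiglipVisionConfig):

import torch
from transformers import SiglipVisionConfig
from transformers.models.siglip.modeling_siglip import SiglipAttention

# Illustrative config: hidden size 1152 with 16 heads gives head_dim == 72.
config = SiglipVisionConfig(hidden_size=1152, num_attention_heads=16)
attn = SiglipAttention(config)

# One image's patch embeddings, e.g. 729 patches for a 384x384 image with
# patch size 14; plain MHA over the full sequence, no KV cache involved.
hidden_states = torch.randn(1, 729, config.hidden_size)
attn_output, _ = attn(hidden_states)
print(attn_output.shape)  # torch.Size([1, 729, 1152])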

@ywang96
Member

ywang96 commented Aug 5, 2024

/ready

@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Aug 5, 2024
@ywang96
Member

ywang96 commented Aug 5, 2024

Overall LGTM! I will just need to test this PR locally myself to make sure everything works fine!

Member

@ywang96 ywang96 left a comment

@ChristopherCho I've made a few more changes to this PR afterwards and verified it works with both TP=1 and TP>1. Thank you again for working on this!

@ywang96 ywang96 enabled auto-merge (squash) August 5, 2024 05:17
@ywang96 ywang96 merged commit c0d8f16 into vllm-project:main Aug 5, 2024
67 checks passed
sfc-gh-mkeralapura pushed a commit to sfc-gh-mkeralapura/vllm that referenced this pull request Aug 12, 2024
kylesayrs pushed a commit to neuralmagic/vllm that referenced this pull request Aug 17, 2024
Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
KuntaiDu pushed a commit to KuntaiDu/vllm that referenced this pull request Nov 20, 2024
Labels
ready ONLY add when PR is ready to merge/full CI is needed