Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Qwen2-VL's multimodal RoPE implementation #384

Merged
merged 3 commits into from
Nov 19, 2024

Conversation

li-plus
Copy link
Collaborator

@li-plus li-plus commented Nov 15, 2024

Summary

Support Qwen2-VL's multimodal RoPE kernel. See original implementation here: https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L203-L245

Finished the TODO left in #175. Complete feature request #165.

Testing Done

  • Hardware Type: A800-SXM4-80GB
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

@tyler-romero
Copy link
Collaborator

Very nice!

Can you update the convergence tests for qwen2_vl to include your RoPE implementation as well? Right now there is a line excluding rope for qwen2_vl models specifically.

Also are you available to add a benchmark for your kernel? https://github.com/linkedin/Liger-Kernel/blob/main/docs/CONTRIBUTING.md#adding-a-new-kernel

@ByronHsu
Copy link
Collaborator

Thank you for making the non-trivial contribution!

  1. Can we rebase Qwen2-VL Bug / Incompatibility Fixes #388 to properly test convergence?
  2. I notice there is another apply_rotary_pos_emb_vision, should we implement and patch that too?
  3. Are you on wechat? Very impressed by your contribution and want to discuss more with you. My id is wxid_nn8pbmlh9ae712

@li-plus
Copy link
Collaborator Author

li-plus commented Nov 18, 2024

Very nice!

Can you update the convergence tests for qwen2_vl to include your RoPE implementation as well? Right now there is a line excluding rope for qwen2_vl models specifically.

Also are you available to add a benchmark for your kernel? https://github.com/linkedin/Liger-Kernel/blob/main/docs/CONTRIBUTING.md#adding-a-new-kernel

@tyler-romero The Qwen2VL convergence tests are already fixed in #388. I've updated them to enable M-RoPE kernel injection for Qwen2VL. Also added benchmark scripts for M-RoPE in the latest commit. Benchmark results on A800 are visualized below:

image image

You may test it on A100 and update all_benchmark_data.csv. I don't have A100 at hand.

@li-plus
Copy link
Collaborator Author

li-plus commented Nov 18, 2024

Thank you for making the non-trivial contribution!

  1. Can we rebase Qwen2-VL Bug / Incompatibility Fixes #388 to properly test convergence?
  2. I notice there is another apply_rotary_pos_emb_vision, should we implement and patch that too?
  3. Are you on wechat? Very impressed by your contribution and want to discuss more with you. My id is wxid_nn8pbmlh9ae712

@ByronHsu Hi,

  1. I've already rebased onto the latest master including this commit Qwen2-VL Bug / Incompatibility Fixes #388 and enabled RoPE for Qwen2VL.
  2. The current apply_rotary_pos_emb_vision implementation is inefficient since it recomputes cos & sin for q & k. It's better to optimize it in modeling_qwen2_vl.py in upstream transformers to apply RoPE on both q & k at the same time. Then we can reuse the RoPE triton kernel of llama / mistral.
  3. Thanks for the invitation. I'm on wechat but I couldn't find your account based on wxid. Maybe you could email me the QR code?

@ByronHsu ByronHsu merged commit cc5561e into linkedin:main Nov 19, 2024
1 check passed
@ByronHsu
Copy link
Collaborator

Thank you @li-plus !!

@li-plus li-plus deleted the qwen2vl-mrope branch November 19, 2024 06:00
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants