From b5d8cbf90d338ea2eda4e2e1863dcf0722599197 Mon Sep 17 00:00:00 2001
From: Tyler Romero <tyler.alexander.romero@gmail.com>
Date: Sun, 8 Sep 2024 14:14:45 -0700
Subject: [PATCH] Monkeypatch for Qwen2-VL (#175)

## Summary
Adds a monkeypatch for the recently published
[Qwen2-VL](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct).
HF `transformers` modeling code:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py

Feature Request: https://github.com/linkedin/Liger-Kernel/issues/165

## Details
Qwen2-VL support in `transformers` is currently available on `main` but has not
yet been included in a published release.
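
Once the patch is applied before the model is instantiated, the Liger kernels take effect automatically. A minimal usage sketch (the checkpoint is simply the one linked above; the dtype and loading options are illustrative):

```python
import torch
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration

from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl

# Patch the HF modeling code before constructing the model
# (defaults: RMSNorm, LayerNorm, SwiGLU, and fused linear cross entropy)
apply_liger_kernel_to_qwen2_vl()

model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct", torch_dtype=torch.bfloat16
)
```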

## Testing Done
- Hardware Type: 4090
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence

---------

Co-authored-by: Shao Tang <tangshao28@gmail.com>
---
 .gitignore                                    |   5 +-
 README.md                                     |   3 +-
 src/liger_kernel/transformers/__init__.py     |   1 +
 .../transformers/model/qwen2_vl.py            | 172 ++++++++++
 src/liger_kernel/transformers/monkey_patch.py |  50 +++
 test/convergence/test_mini_models.py          |   3 +-
 .../test_mini_models_multimodal.py            | 310 ++++++++++++++++++
 .../convergence/test_mini_models_no_logits.py | 103 +++++-
 test/transformers/test_monkey_patch.py        |   1 +
 test/utils.py                                 |  23 ++
 10 files changed, 667 insertions(+), 4 deletions(-)
 create mode 100644 src/liger_kernel/transformers/model/qwen2_vl.py
 create mode 100644 test/convergence/test_mini_models_multimodal.py

diff --git a/.gitignore b/.gitignore
index e643ae280..c84380ea4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -12,5 +12,8 @@ site/
 build/
 dist/
 
+# Lockfiles
+uv.lock
+
 # Benchmark images
-benchmark/visualizations
+benchmark/visualizations
\ No newline at end of file
diff --git a/README.md b/README.md
index 6d11f99c8..765dbec33 100644
--- a/README.md
+++ b/README.md
@@ -91,7 +91,7 @@ With one line of code, Liger Kernel can increase throughput by more than 20% and
 - **Exact:** Computation is exact—no approximations! Both forward and backward passes are implemented with rigorous unit tests and undergo convergence testing against training runs without Liger Kernel to ensure accuracy.
 - **Lightweight:** Liger Kernel has minimal dependencies, requiring only Torch and Triton—no extra libraries needed! Say goodbye to dependency headaches!
 - **Multi-GPU supported:** Compatible with multi-GPU setups (PyTorch FSDP, DeepSpeed, DDP, etc.).
-- **Trainer Framework Integration**: [Axolotl](https://github.com/axolotl-ai-cloud/axolotl), [LLaMa-Factory](https://github.com/hiyouga/LLaMA-Factory), [SFTTrainer](https://github.com/huggingface/trl/releases/tag/v0.10.1), [Hugging Face Trainer](https://github.com/huggingface/transformers/pull/32860) 
+- **Trainer Framework Integration**: [Axolotl](https://github.com/axolotl-ai-cloud/axolotl), [LLaMa-Factory](https://github.com/hiyouga/LLaMA-Factory), [SFTTrainer](https://github.com/huggingface/trl/releases/tag/v0.10.1), [Hugging Face Trainer](https://github.com/huggingface/transformers/pull/32860)
 
 ## Target Audiences
 
@@ -227,6 +227,7 @@ loss.backward()
 | Gemma1      | `liger_kernel.transformers.apply_liger_kernel_to_gemma`    | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy         |
 | Gemma2      | `liger_kernel.transformers.apply_liger_kernel_to_gemma2`   | RoPE, RMSNorm, GeGLU, CrossEntropyLoss         |
 | Qwen2       | `liger_kernel.transformers.apply_liger_kernel_to_qwen2`    | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy        |
+| Qwen2-VL       | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl`    | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy        |
 | Phi3        | `liger_kernel.transformers.apply_liger_kernel_to_phi3`     | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy         |
 
 
diff --git a/src/liger_kernel/transformers/__init__.py b/src/liger_kernel/transformers/__init__.py
index 5c44a154b..147948b18 100644
--- a/src/liger_kernel/transformers/__init__.py
+++ b/src/liger_kernel/transformers/__init__.py
@@ -15,6 +15,7 @@
     apply_liger_kernel_to_mixtral,
     apply_liger_kernel_to_phi3,
     apply_liger_kernel_to_qwen2,
+    apply_liger_kernel_to_qwen2_vl,
 )
 from liger_kernel.transformers.rms_norm import LigerRMSNorm  # noqa: F401
 from liger_kernel.transformers.rope import liger_rotary_pos_emb  # noqa: F401
diff --git a/src/liger_kernel/transformers/model/qwen2_vl.py b/src/liger_kernel/transformers/model/qwen2_vl.py
new file mode 100644
index 000000000..cfb7a905b
--- /dev/null
+++ b/src/liger_kernel/transformers/model/qwen2_vl.py
@@ -0,0 +1,172 @@
+from typing import List, Optional, Tuple, Union
+
+import torch
+from torch.nn import CrossEntropyLoss
+from transformers.models.qwen2_vl.modeling_qwen2_vl import (
+    _CONFIG_FOR_DOC,
+    QWEN2_VL_INPUTS_DOCSTRING,
+    Qwen2VLCausalLMOutputWithPast,
+)
+from transformers.utils import (
+    add_start_docstrings_to_model_forward,
+    replace_return_docstrings,
+)
+
+from liger_kernel.transformers.fused_linear_cross_entropy import (
+    LigerFusedLinearCrossEntropyLoss,
+)
+
+
+@add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING)
+@replace_return_docstrings(
+    output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
+)
+def lce_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    pixel_values: Optional[torch.Tensor] = None,
+    pixel_values_videos: Optional[torch.FloatTensor] = None,
+    image_grid_thw: Optional[torch.LongTensor] = None,
+    video_grid_thw: Optional[torch.LongTensor] = None,
+    rope_deltas: Optional[torch.LongTensor] = None,
+) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
+    r"""
+    Copy-paste of Qwen2VL's forward, but replaces torch cross entropy with Liger fused linear cross entropy
+
+    Args:
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+    Returns:
+
+    Example:
+
+    ```python
+    >>> from PIL import Image
+    >>> import requests
+    >>> from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
+
+    >>> model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
+    >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
+
+    >>> messages = [
+        {
+            "role": "user",
+            "content": [
+                {"type": "image"},
+                {"type": "text", "text": "What is shown in this image?"},
+            ],
+        },
+    ]
+    >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
+    >>> image = Image.open(requests.get(url, stream=True).raw)
+
+    >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    >>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos])
+
+    >>> # Generate
+    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
+    ```"""
+
+    output_attentions = (
+        output_attentions
+        if output_attentions is not None
+        else self.config.output_attentions
+    )
+    output_hidden_states = (
+        output_hidden_states
+        if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+    return_dict = (
+        return_dict if return_dict is not None else self.config.use_return_dict
+    )
+
+    if inputs_embeds is None:
+        inputs_embeds = self.model.embed_tokens(input_ids)
+        if pixel_values is not None:
+            pixel_values = pixel_values.type(self.visual.get_dtype())
+            image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).to(
+                inputs_embeds.device
+            )
+            image_mask = input_ids == self.config.image_token_id
+            if self.training:
+                inputs_embeds = inputs_embeds.clone()
+            inputs_embeds[image_mask] = image_embeds
+        if pixel_values_videos is not None:
+            pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
+            video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw).to(
+                inputs_embeds.device
+            )
+            video_mask = input_ids == self.config.video_token_id
+            inputs_embeds[video_mask] = video_embeds
+        if attention_mask is not None:
+            attention_mask = attention_mask.to(inputs_embeds.device)
+
+    outputs = self.model(
+        input_ids=None,
+        position_ids=position_ids,
+        attention_mask=attention_mask,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+    )
+
+    hidden_states = outputs[0]
+
+    loss = None
+    logits = None
+
+    if self.training and (labels is not None):
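+        # Training path: compute the loss directly from the hidden states with the fused
+        # linear + cross entropy kernel, without materializing the full
+        # (batch, seq_len, vocab_size) logits tensor.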
+        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+
+        # Flatten tokens
+        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
+        shift_labels = shift_labels.view(-1)
+
+        lce = LigerFusedLinearCrossEntropyLoss()
+        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
+    else:
+        logits = self.lm_head(hidden_states)
+        logits = logits.float()
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+            shift_logits = shift_logits.view(-1, self.config.vocab_size)
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            loss = loss_fct(shift_logits, shift_labels)
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        return (loss,) + output if loss is not None else output
+
+    return Qwen2VLCausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+        rope_deltas=rope_deltas,
+    )
diff --git a/src/liger_kernel/transformers/monkey_patch.py b/src/liger_kernel/transformers/monkey_patch.py
index 6ecb670ef..1cca9753c 100644
--- a/src/liger_kernel/transformers/monkey_patch.py
+++ b/src/liger_kernel/transformers/monkey_patch.py
@@ -4,6 +4,7 @@
 
 from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
 from liger_kernel.transformers.geglu import LigerGEGLUMLP
+from liger_kernel.transformers.layer_norm import LigerLayerNorm
 from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward
 from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
 from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward
@@ -245,6 +246,54 @@ def apply_liger_kernel_to_qwen2(
         modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP
 
 
+def apply_liger_kernel_to_qwen2_vl(
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    layer_norm: bool = True,
+    swiglu: bool = True,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Qwen2-VL models.
+    NOTE: Qwen2-VL is not available in transformers<=4.44.2
+
+    Args:
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+    """
+    assert not (
+        cross_entropy and fused_linear_cross_entropy
+    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
+
+    from transformers.models.qwen2_vl import modeling_qwen2_vl
+
+    from liger_kernel.transformers.model.qwen2_vl import (
+        lce_forward as qwen2_vl_lce_forward,
+    )
+
+    # TODO: Support Qwen2-VL's multimodal RoPE implementation
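+    # (Qwen2-VL uses multimodal rotary position embeddings (M-RoPE), which the current Liger
+    # RoPE kernel does not implement, so no `rope` option is exposed here.)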
+
+    if rms_norm:
+        # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L439
+        modeling_qwen2_vl.Qwen2RMSNorm = partial(
+            LigerRMSNorm, init_fn="ones", casting_mode="gemma"
+        )
+    if layer_norm:
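+        # LayerNorm is used by Qwen2-VL's vision tower (vision blocks and patch merger)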
+        modeling_qwen2_vl.LayerNorm = LigerLayerNorm
+    if cross_entropy:
+        modeling_qwen2_vl.CrossEntropyLoss = LigerCrossEntropyLoss
+    if fused_linear_cross_entropy:
+        modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = qwen2_vl_lce_forward
+    if swiglu:
+        modeling_qwen2_vl.Qwen2MLP = LigerSwiGLUMLP
+
+
 def apply_liger_kernel_to_phi3(
     rope: bool = True,
     cross_entropy: bool = False,
@@ -291,6 +340,7 @@ def apply_liger_kernel_to_phi3(
     "mistral": apply_liger_kernel_to_mistral,
     "mixtral": apply_liger_kernel_to_mixtral,
     "qwen2": apply_liger_kernel_to_qwen2,
+    "qwen2_vl": apply_liger_kernel_to_qwen2_vl,
     "phi3": apply_liger_kernel_to_phi3,
 }
 
diff --git a/test/convergence/test_mini_models.py b/test/convergence/test_mini_models.py
index 95c832e15..a2b6f59ef 100644
--- a/test/convergence/test_mini_models.py
+++ b/test/convergence/test_mini_models.py
@@ -331,6 +331,7 @@ def run_mini_model(
 
     for i in range(num_steps):
         batch = next(loader_iter).to(model.device)
+        optimizer.zero_grad()
         output = model(**batch)
         output.loss.backward()
         optimizer.step()
@@ -343,7 +344,7 @@ def run_mini_model(
 @pytest.mark.parametrize(
     "model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logits_atol, logits_rtol, param_atol, param_rtol",
     [
-        # Gemma 1.1 and 2 has more tolerance because currently, the kernel is not a perfect match (casts are not done the same way)
+        # Gemma 1 has more tolerance because currently, the kernel is not a perfect match (casts are not done the same way)
         ("mini_gemma1", 32, 1e-4, torch.float32, 1e-8, 6e-4, 5e-3, 1e-5, 5e-3, 1e-5),
         pytest.param(
             "mini_gemma1",
diff --git a/test/convergence/test_mini_models_multimodal.py b/test/convergence/test_mini_models_multimodal.py
new file mode 100644
index 000000000..9d77d43e6
--- /dev/null
+++ b/test/convergence/test_mini_models_multimodal.py
@@ -0,0 +1,310 @@
+import functools
+import os
+from test.utils import (
+    UNTOKENIZED_DATASET_PATH,
+    MiniModelConfig,
+    assert_verbose_allclose,
+    multimodal_collate_fn,
+    set_seed,
+    supports_bfloat16,
+)
+
+import pytest
+import torch
+from datasets import load_dataset
+from torch.utils.data import DataLoader
+from transformers.models.auto.processing_auto import AutoProcessor
+
+from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl
+
+try:
+    # Qwen2-VL is only available in transformers>4.44.2
+    from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig
+    from transformers.models.qwen2_vl.modeling_qwen2_vl import (
+        Qwen2VLForConditionalGeneration,
+    )
+
+    QWEN2_VL_AVAILABLE = True
+except ImportError:
+    QWEN2_VL_AVAILABLE = False
+
+torch.use_deterministic_algorithms(True)
+
+#  Only setting torch.use_deterministic_algorithms(True) throws the following error:
+#  RuntimeError: Deterministic behavior was enabled with either `torch.use_deterministic_algorithms(True)` or `at::Context::setDeterministicAlgorithms(true)`,
+#  but this operation is not deterministic because it uses CuBLAS and you have CUDA >= 10.2. To enable deterministic behavior in this case, you must set an
+#  environment variable before running your PyTorch application: CUBLAS_WORKSPACE_CONFIG=:4096:8 or CUBLAS_WORKSPACE_CONFIG=:16:8. For more information,
+#  go to https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility
+
+os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
+
+TEST_IMAGE_DIM = 64
+
+MINI_MODEL_SETUPS = {}
+
+if QWEN2_VL_AVAILABLE:
+    MINI_MODEL_SETUPS["mini_qwen2_vl"] = MiniModelConfig(
+        liger_kernel_patch_func=functools.partial(
+            apply_liger_kernel_to_qwen2_vl, fused_linear_cross_entropy=False
+        ),
+        model_class=Qwen2VLForConditionalGeneration,
+        mini_model_config=Qwen2VLConfig(
+            attention_dropout=0.0,
+            # Token IDs and vocab size must match those in the tokenizer/processor
+            # https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/blob/main/config.json
+            bos_token_id=151643,
+            eos_token_id=151645,
+            vision_start_token_id=151652,
+            vision_end_token_id=151653,
+            vision_token_id=151654,
+            image_token_id=151655,
+            hidden_act="silu",
+            hidden_size=1024,  # 8192
+            initializer_range=0.02,
+            intermediate_size=1024,  # 29568
+            max_position_embeddings=32768,
+            max_window_layers=4,  # 80
+            num_attention_heads=8,  # 64
+            num_hidden_layers=4,  # 80
+            num_key_value_heads=2,  # 8
+            rms_norm_eps=1e-6,  # 1e-5
+            rope_theta=1000000.0,
+            rope_scaling=dict(
+                type="mrope",
+                mrope_section=[16, 24, 24],  # (temporal, height, width)
+            ),
+            sliding_window=4096,
+            tie_word_embeddings=True,
+            use_cache=False,  # True
+            vocab_size=152064,
+            use_sliding_window=False,
+            vision_config={
+                "depth": 4,  # 32
+                "embed_dim": 128,  # 1280
+                "mlp_ratio": 1,
+                "num_heads": 8,  # 16
+                "in_chans": 3,
+                "hidden_size": 1024,  # 1536
+            },
+            attn_implementation="sdpa",
+        ),
+    )
+
+
+def create_processor(model_name):
+    if model_name == "mini_qwen2_vl":
+        return AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
+    else:
+        raise ValueError(f"Processor not available for model {model_name}")
+
+
+def create_multimodal_dataset(model_name: str):
+    processor = create_processor(model_name)
+
+    def generate_procedural_image(example, index):
+        """Generate an image with a single row of white pixels at the index specified"""
+        image = torch.zeros(3, TEST_IMAGE_DIM, TEST_IMAGE_DIM)
+        image[:, index % TEST_IMAGE_DIM, :] = 255
+        example["image"] = image
+        return example
+
+    def apply_chat_template(example):
+        """
+        Under the hood, this inserts the correct image placeholder token into the text.
+        This conversation format is more or less what HF's multimodal LLMs use. The fact that it
+        is formatted as for instruction fine-tuning is not, in and of itself, important here.
+        """
+        conversation = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image"},
+                    {"type": "text", "text": "Describe this image."},
+                ],
+            },
+            {
+                "role": "assistant",
+                "content": [{"type": "text", "text": example["text"]}],
+            },
+        ]
+        example["text"] = processor.apply_chat_template(conversation, tokenize=False)
+        return example
+
+    def preprocess_function(examples):
+        """Tokenize text, preprocess images, and generate other relevant inputs for the model."""
+        return processor(
+            text=examples["text"],
+            images=examples["image"],
+            padding="max_length",
+            truncation=True,
+            max_length=1024,  # longer than for text-only b/c images require quite a few tokens
+        )
+
+    train_dataset = (
+        load_dataset(
+            "text", data_files={"train": UNTOKENIZED_DATASET_PATH}, split="train"
+        )
+        .to_iterable_dataset()  # only map examples as-needed and on-demand
+        .map(generate_procedural_image, with_indices=True)
+        .map(apply_chat_template)
+        .map(preprocess_function, remove_columns=["text", "image"])
+    )
+    return train_dataset
+
+
+def create_model(model_name):
+    """
+    Create a mini version of the model.
+    The commented values are the original full-size values.
+    """
+    model_config = MINI_MODEL_SETUPS[model_name].mini_model_config
+    model_class = MINI_MODEL_SETUPS[model_name].model_class
+    return model_class(model_config)
+
+
+def run_mini_model_multimodal(
+    model_name="mini_qwen2_vl",
+    num_steps=100,
+    dtype=torch.bfloat16,
+    lr=1e-5,
+    with_liger=False,
+):
+    # If set_seed were moved to the beginning of test_mini_model_multimodal, the two runs would be initialized with different weights.
+    # This is due to the RNG (Random Number Generator). The formula of RNG progression is x_(n+1) = (a * x_n + c) % m
+    # Every time the RNG is used, e.g. to randomly initialize weights, it progresses to the next state.
+    # Therefore, we have to reset the RNG before creating the model to ensure weight initialization starts from the same RNG state.
+
+    set_seed(42)
+
+    if with_liger is True:
+        kwargs = {
+            "rms_norm": True,
+            "cross_entropy": True,
+        }
+        model_supports_rope = "qwen2_vl" not in model_name
+        if model_supports_rope:
+            kwargs["rope"] = True
+
+        model_supports_layer_norm = "qwen2_vl" in model_name
+        if model_supports_layer_norm:
+            kwargs["layer_norm"] = True
+
+        if "gemma" in model_name:
+            kwargs["geglu"] = True
+        else:
+            kwargs["swiglu"] = True
+        MINI_MODEL_SETUPS[model_name].liger_kernel_patch_func(**kwargs)
+
+    model = create_model(model_name).to(dtype).to("cuda")
+    model.gradient_checkpointing_enable()
+
+    train_dataset = create_multimodal_dataset(model_name)
+    loader = DataLoader(
+        train_dataset, batch_size=2, shuffle=False, collate_fn=multimodal_collate_fn
+    )
+    loader_iter = iter(loader)
+    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
+
+    loss_list = []
+
+    for i in range(num_steps):
+        batch = next(loader_iter).to(model.device)
+        optimizer.zero_grad()
+        output = model(**batch)
+        output.loss.backward()
+        optimizer.step()
+
+        print(f"Step {i}, Loss: {output.loss.item()}")
+        loss_list.append(output.loss.item())
+
+    return {"loss": loss_list, "logits": output.logits, "model": model}
+
+
+@pytest.mark.parametrize(
+    "model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logits_atol, logits_rtol, param_atol, param_rtol",
+    [
+        pytest.param(
+            "mini_qwen2_vl",
+            32,
+            1e-4,
+            torch.float32,
+            1e-8,
+            1e-5,
+            5e-3,
+            1e-5,
+            5e-3,
+            1e-5,
+            marks=pytest.mark.skipif(
+                not QWEN2_VL_AVAILABLE,
+                reason="Qwen2-VL not available in this version of transformers",
+            ),
+        ),
+        pytest.param(
+            "mini_qwen2_vl",
+            32,
+            1e-4,
+            torch.bfloat16,
+            1e-8,
+            1e-5,
+            1e-2,
+            1e-5,
+            1e-2,
+            1e-5,
+            marks=[
+                pytest.mark.skipif(
+                    not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
+                ),
+                pytest.mark.skipif(
+                    not QWEN2_VL_AVAILABLE,
+                    reason="Qwen2-VL not available in this version of transformers",
+                ),
+            ],
+        ),
+    ],
+)
+def test_mini_model_multimodal(
+    model_name,
+    num_steps,
+    lr,
+    dtype,
+    loss_atol,
+    loss_rtol,
+    logits_atol,
+    logits_rtol,
+    param_atol,
+    param_rtol,
+):
+    # Non-liger models should be initialized and tested first to avoid the module being overridden
+    expected_output = run_mini_model_multimodal(
+        model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr
+    )
+
+    actual_output = run_mini_model_multimodal(
+        model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr, with_liger=True
+    )
+
+    # Compare the loss of every step
+    assert_verbose_allclose(
+        torch.tensor([expected_output["loss"]]),
+        torch.tensor([actual_output["loss"]]),
+        atol=loss_atol,
+        rtol=loss_rtol,
+    )
+
+    # Compare the logits from the last step
+    assert_verbose_allclose(
+        expected_output["logits"],
+        actual_output["logits"],
+        atol=logits_atol,
+        rtol=logits_rtol,
+    )
+
+    # Compare the params from the last step
+    # Iterate over the model's parameters and compare them
+    for expected_param, actual_param in zip(
+        expected_output["model"].named_parameters(),
+        actual_output["model"].named_parameters(),
+    ):
+        assert_verbose_allclose(
+            expected_param[1], actual_param[1], atol=param_atol, rtol=param_rtol
+        )
diff --git a/test/convergence/test_mini_models_no_logits.py b/test/convergence/test_mini_models_no_logits.py
index 67ce443c6..540468849 100644
--- a/test/convergence/test_mini_models_no_logits.py
+++ b/test/convergence/test_mini_models_no_logits.py
@@ -27,8 +27,20 @@
     apply_liger_kernel_to_mixtral,
     apply_liger_kernel_to_phi3,
     apply_liger_kernel_to_qwen2,
+    apply_liger_kernel_to_qwen2_vl,
 )
 
+try:
+    # Qwen2-VL is only available in transformers>4.44.2
+    from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig
+    from transformers.models.qwen2_vl.modeling_qwen2_vl import (
+        Qwen2VLForConditionalGeneration,
+    )
+
+    QWEN2_VL_AVAILABLE = True
+except ImportError:
+    QWEN2_VL_AVAILABLE = False
+
 MINI_MODEL_SETUPS = {
     "mini_llama3": MiniModelConfig(
         liger_kernel_patch_func=apply_liger_kernel_to_llama,
@@ -249,6 +261,50 @@
     ),
 }
 
+if QWEN2_VL_AVAILABLE:
+    MINI_MODEL_SETUPS["mini_qwen2_vl"] = MiniModelConfig(
+        liger_kernel_patch_func=apply_liger_kernel_to_qwen2_vl,
+        model_class=Qwen2VLForConditionalGeneration,
+        mini_model_config=Qwen2VLConfig(
+            attention_dropout=0.0,
+            bos_token_id=1,  # 151643
+            eos_token_id=2,  # 151645
+            hidden_act="silu",
+            hidden_size=1536,  # 8192
+            initializer_range=0.02,
+            intermediate_size=4864,  # 29568
+            max_position_embeddings=32768,
+            max_window_layers=4,  # 80
+            num_attention_heads=12,  # 64
+            num_hidden_layers=4,  # 80
+            num_key_value_heads=2,  # 8
+            rms_norm_eps=1e-6,  # 1e-5
+            rope_theta=1000000.0,
+            rope_scaling=dict(
+                type="mrope",
+                mrope_section=[16, 24, 24],  # (temporal, height, width)
+            ),
+            sliding_window=4096,
+            tie_word_embeddings=False,
+            use_cache=True,
+            vocab_size=32000,  # 152064
+            use_sliding_window=False,
+            vision_config={
+                "depth": 4,  # 32
+                "embed_dim": 1280,
+                "mlp_ratio": 4,
+                "num_heads": 16,
+                "in_chans": 3,
+                "hidden_size": 128,  # 1536
+                "patch_size": 14,
+                "spatial_merge_size": 2,
+                "spatial_patch_size": 14,
+                "temporal_patch_size": 2,
+            },
+            attn_implementation="sdpa",
+        ),
+    )
+
 
 def create_model(model_name="mini_llama3"):
     """
@@ -276,9 +332,16 @@ def run_mini_model(
 
     if with_liger is True:
         kwargs = {
-            "rope": True,
             "rms_norm": True,
         }
+        model_supports_rope = "qwen2_vl" not in model_name
+        if model_supports_rope:
+            kwargs["rope"] = True
+
+        model_supports_layer_norm = "qwen2_vl" in model_name
+        if model_supports_layer_norm:
+            kwargs["layer_norm"] = True
+
         if "gemma" in model_name:
             kwargs["geglu"] = True
         else:
@@ -305,6 +368,7 @@ def run_mini_model(
 
     for i in range(num_steps):
         batch = next(loader_iter).to(model.device)
+        optimizer.zero_grad()
         output = model(**batch)
         output.loss.backward()
         optimizer.step()
@@ -349,6 +413,43 @@ def run_mini_model(
                 not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
             ),
         ),
+        pytest.param(
+            "mini_qwen2_vl",
+            32,
+            1e-4,
+            torch.float32,
+            1e-8,
+            1e-5,
+            5e-3,
+            1e-5,
+            5e-3,
+            1e-5,
+            marks=pytest.mark.skipif(
+                not QWEN2_VL_AVAILABLE,
+                reason="Qwen2-VL not available in this version of transformers",
+            ),
+        ),
+        pytest.param(
+            "mini_qwen2_vl",
+            32,
+            1e-4,
+            torch.bfloat16,
+            1e-8,
+            1e-5,
+            1e-2,
+            1e-5,
+            1e-2,
+            1e-5,
+            marks=[
+                pytest.mark.skipif(
+                    not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
+                ),
+                pytest.mark.skipif(
+                    not QWEN2_VL_AVAILABLE,
+                    reason="Qwen2-VL not available in this version of transformers",
+                ),
+            ],
+        ),
         ("mini_phi3", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
         pytest.param(
             "mini_phi3",
diff --git a/test/transformers/test_monkey_patch.py b/test/transformers/test_monkey_patch.py
index 0db1d258a..4116287fd 100644
--- a/test/transformers/test_monkey_patch.py
+++ b/test/transformers/test_monkey_patch.py
@@ -22,6 +22,7 @@ def test_import_from_root():
             apply_liger_kernel_to_mixtral,
             apply_liger_kernel_to_phi3,
             apply_liger_kernel_to_qwen2,
+            apply_liger_kernel_to_qwen2_vl,
         )
     except Exception:
         pytest.fail("Import kernel patch from root fails")
diff --git a/test/utils.py b/test/utils.py
index cb66742e2..5ff28cdac 100644
--- a/test/utils.py
+++ b/test/utils.py
@@ -88,6 +88,10 @@ def assert_verbose_allclose(tensor1, tensor2, rtol=1e-05, atol=1e-08, max_print=
     os.path.dirname(os.path.abspath(__file__)), "resources/tiny_shakespeare_tokenized"
 )
 
+UNTOKENIZED_DATASET_PATH = os.path.join(
+    os.path.dirname(os.path.abspath(__file__)), "resources/tiny_shakespeare.txt"
+)
+
 
 @dataclass
 class MiniModelConfig:
@@ -114,6 +118,25 @@ def simple_collate_fn(data: List[Dict[str, Any]]):
     )
 
 
+def multimodal_collate_fn(data: List[Dict[str, Any]]):
+    """A collate function to use for DataLoader for multimodal models"""
+    batch = {}
+    keys = set(data[0].keys())
+
+    input_ids = torch.cat([torch.tensor(item["input_ids"]) for item in data])
+    keys.remove("input_ids")
+    batch["input_ids"] = input_ids
+
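+    # Labels are a straight copy of input_ids; no masking is applied for this convergence test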
+    labels = input_ids.clone()
+    batch["labels"] = labels
+
+    # Collate all other keys, e.g. pixel_values, attention_mask, image_grid_thw, etc
+    for key in keys:
+        batch[key] = torch.cat([item[key] for item in data])
+
+    return BatchEncoding(batch)
+
+
 def supports_bfloat16():
     if not torch.cuda.is_available():
         return False