From b5d8cbf90d338ea2eda4e2e1863dcf0722599197 Mon Sep 17 00:00:00 2001 From: Tyler Romero <tyler.alexander.romero@gmail.com> Date: Sun, 8 Sep 2024 14:14:45 -0700 Subject: [PATCH] Monkeypatch for Qwen2-VL (#175) ## Summary Monkeypatch for the recently-published [Qwen2-VL](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct). HF `transformers` modeling code: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py Feature Request: https://github.com/linkedin/Liger-Kernel/issues/165 ## Details Qwen2-VL in `transformers` is available on `transformers` main but is yet to be published in a release. ## Testing Done - Hardware Type: 4090 - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence --------- Co-authored-by: Shao Tang <tangshao28@gmail.com> --- .gitignore | 5 +- README.md | 3 +- src/liger_kernel/transformers/__init__.py | 1 + .../transformers/model/qwen2_vl.py | 172 ++++++++++ src/liger_kernel/transformers/monkey_patch.py | 50 +++ test/convergence/test_mini_models.py | 3 +- .../test_mini_models_multimodal.py | 310 ++++++++++++++++++ .../convergence/test_mini_models_no_logits.py | 103 +++++- test/transformers/test_monkey_patch.py | 1 + test/utils.py | 23 ++ 10 files changed, 667 insertions(+), 4 deletions(-) create mode 100644 src/liger_kernel/transformers/model/qwen2_vl.py create mode 100644 test/convergence/test_mini_models_multimodal.py diff --git a/.gitignore b/.gitignore index e643ae280..c84380ea4 100644 --- a/.gitignore +++ b/.gitignore @@ -12,5 +12,8 @@ site/ build/ dist/ +# Lockfiles +uv.lock + # Benchmark images -benchmark/visualizations +benchmark/visualizations \ No newline at end of file diff --git a/README.md b/README.md index 6d11f99c8..765dbec33 100644 --- a/README.md +++ b/README.md @@ -91,7 +91,7 @@ With one line of code, Liger Kernel can increase throughput by more than 20% and - **Exact:** Computation is exact—no approximations! Both forward and backward passes are implemented with rigorous unit tests and undergo convergence testing against training runs without Liger Kernel to ensure accuracy. - **Lightweight:** Liger Kernel has minimal dependencies, requiring only Torch and Triton—no extra libraries needed! Say goodbye to dependency headaches! - **Multi-GPU supported:** Compatible with multi-GPU setups (PyTorch FSDP, DeepSpeed, DDP, etc.). -- **Trainer Framework Integration**: [Axolotl](https://github.com/axolotl-ai-cloud/axolotl), [LLaMa-Factory](https://github.com/hiyouga/LLaMA-Factory), [SFTTrainer](https://github.com/huggingface/trl/releases/tag/v0.10.1), [Hugging Face Trainer](https://github.com/huggingface/transformers/pull/32860) +- **Trainer Framework Integration**: [Axolotl](https://github.com/axolotl-ai-cloud/axolotl), [LLaMa-Factory](https://github.com/hiyouga/LLaMA-Factory), [SFTTrainer](https://github.com/huggingface/trl/releases/tag/v0.10.1), [Hugging Face Trainer](https://github.com/huggingface/transformers/pull/32860) ## Target Audiences @@ -227,6 +227,7 @@ loss.backward() | Gemma1 | `liger_kernel.transformers.apply_liger_kernel_to_gemma` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy | | Gemma2 | `liger_kernel.transformers.apply_liger_kernel_to_gemma2` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss | | Qwen2 | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy | +| Qwen2-VL | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy | | Phi3 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy | diff --git a/src/liger_kernel/transformers/__init__.py b/src/liger_kernel/transformers/__init__.py index 5c44a154b..147948b18 100644 --- a/src/liger_kernel/transformers/__init__.py +++ b/src/liger_kernel/transformers/__init__.py @@ -15,6 +15,7 @@ apply_liger_kernel_to_mixtral, apply_liger_kernel_to_phi3, apply_liger_kernel_to_qwen2, + apply_liger_kernel_to_qwen2_vl, ) from liger_kernel.transformers.rms_norm import LigerRMSNorm # noqa: F401 from liger_kernel.transformers.rope import liger_rotary_pos_emb # noqa: F401 diff --git a/src/liger_kernel/transformers/model/qwen2_vl.py b/src/liger_kernel/transformers/model/qwen2_vl.py new file mode 100644 index 000000000..cfb7a905b --- /dev/null +++ b/src/liger_kernel/transformers/model/qwen2_vl.py @@ -0,0 +1,172 @@ +from typing import List, Optional, Tuple, Union + +import torch +from torch.nn import CrossEntropyLoss +from transformers.models.qwen2_vl.modeling_qwen2_vl import ( + _CONFIG_FOR_DOC, + QWEN2_VL_INPUTS_DOCSTRING, + Qwen2VLCausalLMOutputWithPast, +) +from transformers.utils import ( + add_start_docstrings_to_model_forward, + replace_return_docstrings, +) + +from liger_kernel.transformers.fused_linear_cross_entropy import ( + LigerFusedLinearCrossEntropyLoss, +) + + +@add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING) +@replace_return_docstrings( + output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC +) +def lce_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, +) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]: + r""" + Copy paste Qwen2VL's forward but replace torch cross entropy with liger fused linear cross entropy + + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Qwen2VLForConditionalGeneration + + >>> model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct") + >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct") + + >>> messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "What is shown in this image?"}, + ], + }, + ] + >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + >>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos]) + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..." + ```""" + + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if inputs_embeds is None: + inputs_embeds = self.model.embed_tokens(input_ids) + if pixel_values is not None: + pixel_values = pixel_values.type(self.visual.get_dtype()) + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).to( + inputs_embeds.device + ) + image_mask = input_ids == self.config.image_token_id + if self.training: + inputs_embeds = inputs_embeds.clone() + inputs_embeds[image_mask] = image_embeds + if pixel_values_videos is not None: + pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype()) + video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw).to( + inputs_embeds.device + ) + video_mask = input_ids == self.config.video_token_id + inputs_embeds[video_mask] = video_embeds + if attention_mask is not None: + attention_mask = attention_mask.to(inputs_embeds.device) + + outputs = self.model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + + loss = None + logits = None + + if self.training and (labels is not None): + shift_hidden_states = hidden_states[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + + # Flatten tokens + shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size) + shift_labels = shift_labels.view(-1) + + lce = LigerFusedLinearCrossEntropyLoss() + loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels) + else: + logits = self.lm_head(hidden_states) + logits = logits.float() + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return Qwen2VLCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=rope_deltas, + ) diff --git a/src/liger_kernel/transformers/monkey_patch.py b/src/liger_kernel/transformers/monkey_patch.py index 6ecb670ef..1cca9753c 100644 --- a/src/liger_kernel/transformers/monkey_patch.py +++ b/src/liger_kernel/transformers/monkey_patch.py @@ -4,6 +4,7 @@ from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss from liger_kernel.transformers.geglu import LigerGEGLUMLP +from liger_kernel.transformers.layer_norm import LigerLayerNorm from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward @@ -245,6 +246,54 @@ def apply_liger_kernel_to_qwen2( modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP +def apply_liger_kernel_to_qwen2_vl( + cross_entropy: bool = False, + fused_linear_cross_entropy: bool = True, + rms_norm: bool = True, + layer_norm: bool = True, + swiglu: bool = True, +) -> None: + """ + Apply Liger kernels to replace original implementation in HuggingFace Qwen2-VL models. + NOTE: Qwen2-VL is not available in transformers<=4.44.2 + + Args: + cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False. + fused_linear_cross_entropy (bool): + Whether to apply Liger's fused linear cross entropy loss. Default is True. + `cross_entropy` and `fused_linear_cross_entropy` cannot both be True. + If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient. + rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True. + layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True. + swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True. + """ + assert not ( + cross_entropy and fused_linear_cross_entropy + ), "cross_entropy and fused_linear_cross_entropy cannot both be True." + + from transformers.models.qwen2_vl import modeling_qwen2_vl + + from liger_kernel.transformers.model.qwen2_vl import ( + lce_forward as qwen2_vl_lce_forward, + ) + + # TODO: Support Qwen2-VL's multimodal RoPE implementation + + if rms_norm: + # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L439 + modeling_qwen2_vl.Qwen2RMSNorm = partial( + LigerRMSNorm, init_fn="ones", casting_mode="gemma" + ) + if layer_norm: + modeling_qwen2_vl.LayerNorm = LigerLayerNorm + if cross_entropy: + modeling_qwen2_vl.CrossEntropyLoss = LigerCrossEntropyLoss + if fused_linear_cross_entropy: + modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = qwen2_vl_lce_forward + if swiglu: + modeling_qwen2_vl.Qwen2MLP = LigerSwiGLUMLP + + def apply_liger_kernel_to_phi3( rope: bool = True, cross_entropy: bool = False, @@ -291,6 +340,7 @@ def apply_liger_kernel_to_phi3( "mistral": apply_liger_kernel_to_mistral, "mixtral": apply_liger_kernel_to_mixtral, "qwen2": apply_liger_kernel_to_qwen2, + "qwen2_vl": apply_liger_kernel_to_qwen2_vl, "phi3": apply_liger_kernel_to_phi3, } diff --git a/test/convergence/test_mini_models.py b/test/convergence/test_mini_models.py index 95c832e15..a2b6f59ef 100644 --- a/test/convergence/test_mini_models.py +++ b/test/convergence/test_mini_models.py @@ -331,6 +331,7 @@ def run_mini_model( for i in range(num_steps): batch = next(loader_iter).to(model.device) + optimizer.zero_grad() output = model(**batch) output.loss.backward() optimizer.step() @@ -343,7 +344,7 @@ def run_mini_model( @pytest.mark.parametrize( "model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logits_atol, logits_rtol, param_atol, param_rtol", [ - # Gemma 1.1 and 2 has more tolerance because currently, the kernel is not a perfect match (casts are not done the same way) + # Gemma 1 has more tolerance because currently, the kernel is not a perfect match (casts are not done the same way) ("mini_gemma1", 32, 1e-4, torch.float32, 1e-8, 6e-4, 5e-3, 1e-5, 5e-3, 1e-5), pytest.param( "mini_gemma1", diff --git a/test/convergence/test_mini_models_multimodal.py b/test/convergence/test_mini_models_multimodal.py new file mode 100644 index 000000000..9d77d43e6 --- /dev/null +++ b/test/convergence/test_mini_models_multimodal.py @@ -0,0 +1,310 @@ +import functools +import os +from test.utils import ( + UNTOKENIZED_DATASET_PATH, + MiniModelConfig, + assert_verbose_allclose, + multimodal_collate_fn, + set_seed, + supports_bfloat16, +) + +import pytest +import torch +from datasets import load_dataset +from torch.utils.data import DataLoader +from transformers.models.auto.processing_auto import AutoProcessor + +from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl + +try: + # Qwen2-VL is only available in transformers>4.44.2 + from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig + from transformers.models.qwen2_vl.modeling_qwen2_vl import ( + Qwen2VLForConditionalGeneration, + ) + + QWEN2_VL_AVAILABLE = True +except ImportError: + QWEN2_VL_AVAILABLE = False + +torch.use_deterministic_algorithms(True) + +# Only setting torch.use_deterministic_algorithms(True) throws the following error: +# RuntimeError: Deterministic behavior was enabled with either `torch.use_deterministic_algorithms(True)` or `at::Context::setDeterministicAlgorithms(true)`, +# but this operation is not deterministic because it uses CuBLAS and you have CUDA >= 10.2. To enable deterministic behavior in this case, you must set an +# environment variable before running your PyTorch application: CUBLAS_WORKSPACE_CONFIG=:4096:8 or CUBLAS_WORKSPACE_CONFIG=:16:8. For more information, +# go to https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility + +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" + +TEST_IMAGE_DIM = 64 + +MINI_MODEL_SETUPS = {} + +if QWEN2_VL_AVAILABLE: + MINI_MODEL_SETUPS["mini_qwen2_vl"] = MiniModelConfig( + liger_kernel_patch_func=functools.partial( + apply_liger_kernel_to_qwen2_vl, fused_linear_cross_entropy=False + ), + model_class=Qwen2VLForConditionalGeneration, + mini_model_config=Qwen2VLConfig( + attention_dropout=0.0, + # Token Ids and vocab size must match those in the tokenizer/processor + # https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/blob/main/config.json + bos_token_id=151643, + eos_token_id=151645, + vision_start_token_id=151652, + vision_end_token_id=151653, + vision_token_id=151654, + image_token_id=151655, + hidden_act="silu", + hidden_size=1024, # 8192 + initializer_range=0.02, + intermediate_size=1024, # 29568 + max_position_embeddings=32768, + max_window_layers=4, # 80 + num_attention_heads=8, # 64 + num_hidden_layers=4, # 80 + num_key_value_heads=2, # 8 + rms_norm_eps=1e-6, # 1e-5 + rope_theta=1000000.0, + rope_scaling=dict( + type="mrope", + mrope_section=[16, 24, 24], # (temporal, height, width) + ), + sliding_window=4096, + tie_word_embeddings=True, + use_cache=False, # True + vocab_size=152064, + use_sliding_window=False, + vision_config={ + "depth": 4, # 32 + "embed_dim": 128, # 1280 + "mlp_ratio": 1, + "num_heads": 8, # 16 + "in_chans": 3, + "hidden_size": 1024, # 1536 + }, + attn_implementation="sdpa", + ), + ) + + +def create_processor(model_name): + if model_name == "mini_qwen2_vl": + return AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct") + else: + raise ValueError(f"Processor not available for model {model_name}") + + +def create_multimodal_dataset(model_name: str): + processor = create_processor(model_name) + + def generate_procedural_image(example, index): + """Generate an image with a single row of white pixels at the index specified""" + image = torch.zeros(3, TEST_IMAGE_DIM, TEST_IMAGE_DIM) + image[:, index % TEST_IMAGE_DIM, :] = 255 + example["image"] = image + return example + + def apply_chat_template(example): + """ + Under the hood, this inserts the correct image placeholder token into the text. + More or less this conversation format is used by HF's mllms. The fact that it is + formatting as for IFT is not in-and-of-itself important here. + """ + conversation = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "Describe this image."}, + ], + }, + { + "role": "assistant", + "content": [{"type": "text", "text": example["text"]}], + }, + ] + example["text"] = processor.apply_chat_template(conversation, tokenize=False) + return example + + def preprocess_function(examples): + """Tokenize text, preprocess images, and generate other relevant inputs for the model.""" + return processor( + text=examples["text"], + images=examples["image"], + padding="max_length", + truncation=True, + max_length=1024, # longer than for text-only b/c images require quite a few tokens + ) + + train_dataset = ( + load_dataset( + "text", data_files={"train": UNTOKENIZED_DATASET_PATH}, split="train" + ) + .to_iterable_dataset() # only map examples as-needed and on-demand + .map(generate_procedural_image, with_indices=True) + .map(apply_chat_template) + .map(preprocess_function, remove_columns=["text", "image"]) + ) + return train_dataset + + +def create_model(model_name): + """ + Create a mini version model + The commented values are the original values + """ + model_config = MINI_MODEL_SETUPS[model_name].mini_model_config + model_class = MINI_MODEL_SETUPS[model_name].model_class + return model_class(model_config) + + +def run_mini_model_multimodal( + model_name="mini_qwen2_vl", + num_steps=100, + dtype=torch.bfloat16, + lr=1e-5, + with_liger=False, +): + # If we move it to the beginning of test_mini_model, the two runs are initialized with different weights. + # This is due to RNG (Random Number Generator). The formula of RNG progression is x_(n+1) = (a * x_n + c) % m + # Everytime RNG is used, like randomly initialzing weight, the RNG progresses to the next state. + # Therefore, we have to reset RNG before we create the model to ensure the weight initialization started from the same RNG state. + + set_seed(42) + + if with_liger is True: + kwargs = { + "rms_norm": True, + "cross_entropy": True, + } + model_supports_rope = "qwen2_vl" not in model_name + if model_supports_rope: + kwargs["rope"] = True + + model_supports_layer_norm = "qwen2_vl" in model_name + if model_supports_layer_norm: + kwargs["layer_norm"] = True + + if "gemma" in model_name: + kwargs["geglu"] = True + else: + kwargs["swiglu"] = True + MINI_MODEL_SETUPS[model_name].liger_kernel_patch_func(**kwargs) + + model = create_model(model_name).to(dtype).to("cuda") + model.gradient_checkpointing_enable() + + train_dataset = create_multimodal_dataset(model_name) + loader = DataLoader( + train_dataset, batch_size=2, shuffle=False, collate_fn=multimodal_collate_fn + ) + loader_iter = iter(loader) + optimizer = torch.optim.AdamW(model.parameters(), lr=lr) + + loss_list = [] + + for i in range(num_steps): + batch = next(loader_iter).to(model.device) + optimizer.zero_grad() + output = model(**batch) + output.loss.backward() + optimizer.step() + + print(f"Step {i}, Loss: {output.loss.item()}") + loss_list.append(output.loss.item()) + + return {"loss": loss_list, "logits": output.logits, "model": model} + + +@pytest.mark.parametrize( + "model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logits_atol, logits_rtol, param_atol, param_rtol", + [ + pytest.param( + "mini_qwen2_vl", + 32, + 1e-4, + torch.float32, + 1e-8, + 1e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=pytest.mark.skipif( + not QWEN2_VL_AVAILABLE, + reason="Qwen2-VL not available in this version of transformers", + ), + ), + pytest.param( + "mini_qwen2_vl", + 32, + 1e-4, + torch.bfloat16, + 1e-8, + 1e-5, + 1e-2, + 1e-5, + 1e-2, + 1e-5, + marks=[ + pytest.mark.skipif( + not supports_bfloat16(), reason="bfloat16 not supported on this GPU" + ), + pytest.mark.skipif( + not QWEN2_VL_AVAILABLE, + reason="Qwen2-VL not available in this version of transformers", + ), + ], + ), + ], +) +def test_mini_model_multimodal( + model_name, + num_steps, + lr, + dtype, + loss_atol, + loss_rtol, + logits_atol, + logits_rtol, + param_atol, + param_rtol, +): + # Non-liger models should be initialized and tested first to avoid the module being overridden + expected_output = run_mini_model_multimodal( + model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr + ) + + actual_output = run_mini_model_multimodal( + model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr, with_liger=True + ) + + # Compare the loss of every step + assert_verbose_allclose( + torch.tensor([expected_output["loss"]]), + torch.tensor([actual_output["loss"]]), + atol=loss_atol, + rtol=loss_rtol, + ) + + # Compare the logits from the last step + assert_verbose_allclose( + expected_output["logits"], + actual_output["logits"], + atol=logits_atol, + rtol=logits_rtol, + ) + + # Compare the params from the last step + # Iterate over the model's parameters and compare them + for expected_param, actual_param in zip( + expected_output["model"].named_parameters(), + actual_output["model"].named_parameters(), + ): + assert_verbose_allclose( + expected_param[1], actual_param[1], atol=param_atol, rtol=param_rtol + ) diff --git a/test/convergence/test_mini_models_no_logits.py b/test/convergence/test_mini_models_no_logits.py index 67ce443c6..540468849 100644 --- a/test/convergence/test_mini_models_no_logits.py +++ b/test/convergence/test_mini_models_no_logits.py @@ -27,8 +27,20 @@ apply_liger_kernel_to_mixtral, apply_liger_kernel_to_phi3, apply_liger_kernel_to_qwen2, + apply_liger_kernel_to_qwen2_vl, ) +try: + # Qwen2-VL is only available in transformers>4.44.2 + from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig + from transformers.models.qwen2_vl.modeling_qwen2_vl import ( + Qwen2VLForConditionalGeneration, + ) + + QWEN2_VL_AVAILABLE = True +except ImportError: + QWEN2_VL_AVAILABLE = False + MINI_MODEL_SETUPS = { "mini_llama3": MiniModelConfig( liger_kernel_patch_func=apply_liger_kernel_to_llama, @@ -249,6 +261,50 @@ ), } +if QWEN2_VL_AVAILABLE: + MINI_MODEL_SETUPS["mini_qwen2_vl"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_qwen2_vl, + model_class=Qwen2VLForConditionalGeneration, + mini_model_config=Qwen2VLConfig( + attention_dropout=0.0, + bos_token_id=1, # 151643 + eos_token_id=2, # 151645 + hidden_act="silu", + hidden_size=1536, # 8192 + initializer_range=0.02, + intermediate_size=4864, # 29568 + max_position_embeddings=32768, + max_window_layers=4, # 80 + num_attention_heads=12, # 64 + num_hidden_layers=4, # 80 + num_key_value_heads=2, # 8 + rms_norm_eps=1e-6, # 1e-5 + rope_theta=1000000.0, + rope_scaling=dict( + type="mrope", + mrope_section=[16, 24, 24], # (temporal, height, width) + ), + sliding_window=4096, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, # 152064 + use_sliding_window=False, + vision_config={ + "depth": 4, # 32 + "embed_dim": 1280, + "mlp_ratio": 4, + "num_heads": 16, + "in_chans": 3, + "hidden_size": 128, # 1536 + "patch_size": 14, + "spatial_merge_size": 2, + "spatial_patch_size": 14, + "temporal_patch_size": 2, + }, + attn_implementation="sdpa", + ), + ) + def create_model(model_name="mini_llama3"): """ @@ -276,9 +332,16 @@ def run_mini_model( if with_liger is True: kwargs = { - "rope": True, "rms_norm": True, } + model_supports_rope = "qwen2_vl" not in model_name + if model_supports_rope: + kwargs["rope"] = True + + model_supports_layer_norm = "qwen2_vl" in model_name + if model_supports_layer_norm: + kwargs["layer_norm"] = True + if "gemma" in model_name: kwargs["geglu"] = True else: @@ -305,6 +368,7 @@ def run_mini_model( for i in range(num_steps): batch = next(loader_iter).to(model.device) + optimizer.zero_grad() output = model(**batch) output.loss.backward() optimizer.step() @@ -349,6 +413,43 @@ def run_mini_model( not supports_bfloat16(), reason="bfloat16 not supported on this GPU" ), ), + pytest.param( + "mini_qwen2_vl", + 32, + 1e-4, + torch.float32, + 1e-8, + 1e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=pytest.mark.skipif( + not QWEN2_VL_AVAILABLE, + reason="Qwen2-VL not available in this version of transformers", + ), + ), + pytest.param( + "mini_qwen2_vl", + 32, + 1e-4, + torch.bfloat16, + 1e-8, + 1e-5, + 1e-2, + 1e-5, + 1e-2, + 1e-5, + marks=[ + pytest.mark.skipif( + not supports_bfloat16(), reason="bfloat16 not supported on this GPU" + ), + pytest.mark.skipif( + not QWEN2_VL_AVAILABLE, + reason="Qwen2-VL not available in this version of transformers", + ), + ], + ), ("mini_phi3", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), pytest.param( "mini_phi3", diff --git a/test/transformers/test_monkey_patch.py b/test/transformers/test_monkey_patch.py index 0db1d258a..4116287fd 100644 --- a/test/transformers/test_monkey_patch.py +++ b/test/transformers/test_monkey_patch.py @@ -22,6 +22,7 @@ def test_import_from_root(): apply_liger_kernel_to_mixtral, apply_liger_kernel_to_phi3, apply_liger_kernel_to_qwen2, + apply_liger_kernel_to_qwen2_vl, ) except Exception: pytest.fail("Import kernel patch from root fails") diff --git a/test/utils.py b/test/utils.py index cb66742e2..5ff28cdac 100644 --- a/test/utils.py +++ b/test/utils.py @@ -88,6 +88,10 @@ def assert_verbose_allclose(tensor1, tensor2, rtol=1e-05, atol=1e-08, max_print= os.path.dirname(os.path.abspath(__file__)), "resources/tiny_shakespeare_tokenized" ) +UNTOKENIZED_DATASET_PATH = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "resources/tiny_shakespeare.txt" +) + @dataclass class MiniModelConfig: @@ -114,6 +118,25 @@ def simple_collate_fn(data: List[Dict[str, Any]]): ) +def multimodal_collate_fn(data: List[Dict[str, Any]]): + """A collate function to use for DataLoader for multimodal models""" + batch = {} + keys = set(data[0].keys()) + + input_ids = torch.cat([torch.tensor(item["input_ids"]) for item in data]) + keys.remove("input_ids") + batch["input_ids"] = input_ids + + labels = input_ids.clone() + batch["labels"] = labels + + # Collate all other keys, e.g. pixel_values, attention_mask, image_grid_thw, etc + for key in keys: + batch[key] = torch.cat([item[key] for item in data]) + + return BatchEncoding(batch) + + def supports_bfloat16(): if not torch.cuda.is_available(): return False