Monkeypatch for Qwen2-VL (linkedin#175)
## Summary
Monkeypatch for the recently-published
[Qwen2-VL](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct).
HF `transformers` modeling code:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py

Feature Request: linkedin#165

## Details
Qwen2-VL support is available on `transformers` main but has not yet been published in a release.
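
As a quick usage sketch (the patch API is added in this PR; the model id is from the summary above):

```python
from transformers import Qwen2VLForConditionalGeneration

from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl

# Patch the HF modeling code before the model is instantiated,
# so the model picks up the Liger kernels at construction time.
apply_liger_kernel_to_qwen2_vl()

model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
```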

## Testing Done
- Hardware Type: 4090
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence

---------

Co-authored-by: Shao Tang <tangshao28@gmail.com>
tyler-romero and lancerts authored Sep 8, 2024
1 parent 9250546 commit b5d8cbf
Showing 10 changed files with 667 additions and 4 deletions.
5 changes: 4 additions & 1 deletion .gitignore
@@ -12,5 +12,8 @@ site/
build/
dist/

# Lockfiles
uv.lock

# Benchmark images
benchmark/visualizations
3 changes: 2 additions & 1 deletion README.md
@@ -91,7 +91,7 @@ With one line of code, Liger Kernel can increase throughput by more than 20% and
- **Exact:** Computation is exact—no approximations! Both forward and backward passes are implemented with rigorous unit tests and undergo convergence testing against training runs without Liger Kernel to ensure accuracy.
- **Lightweight:** Liger Kernel has minimal dependencies, requiring only Torch and Triton—no extra libraries needed! Say goodbye to dependency headaches!
- **Multi-GPU supported:** Compatible with multi-GPU setups (PyTorch FSDP, DeepSpeed, DDP, etc.).
- **Trainer Framework Integration**: [Axolotl](https://github.com/axolotl-ai-cloud/axolotl), [LLaMa-Factory](https://github.com/hiyouga/LLaMA-Factory), [SFTTrainer](https://github.com/huggingface/trl/releases/tag/v0.10.1), [Hugging Face Trainer](https://github.com/huggingface/transformers/pull/32860)

## Target Audiences

@@ -227,6 +227,7 @@ loss.backward()
| Gemma1 | `liger_kernel.transformers.apply_liger_kernel_to_gemma` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Gemma2 | `liger_kernel.transformers.apply_liger_kernel_to_gemma2` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss |
| Qwen2 | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Qwen2-VL | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Phi3 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
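
Each patch API also accepts keyword flags to toggle individual kernels; a sketch using the Qwen2-VL signature added in this PR:

```python
from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl

# Turn the default fused linear cross entropy off and use the plain
# Liger cross entropy instead (the two flags are mutually exclusive).
apply_liger_kernel_to_qwen2_vl(
    fused_linear_cross_entropy=False,
    cross_entropy=True,
    rms_norm=True,
    layer_norm=True,
    swiglu=True,
)
```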


1 change: 1 addition & 0 deletions src/liger_kernel/transformers/__init__.py
@@ -15,6 +15,7 @@
apply_liger_kernel_to_mixtral,
apply_liger_kernel_to_phi3,
apply_liger_kernel_to_qwen2,
apply_liger_kernel_to_qwen2_vl,
)
from liger_kernel.transformers.rms_norm import LigerRMSNorm # noqa: F401
from liger_kernel.transformers.rope import liger_rotary_pos_emb # noqa: F401
172 changes: 172 additions & 0 deletions src/liger_kernel/transformers/model/qwen2_vl.py
@@ -0,0 +1,172 @@
from typing import List, Optional, Tuple, Union

import torch
from torch.nn import CrossEntropyLoss
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
_CONFIG_FOR_DOC,
QWEN2_VL_INPUTS_DOCSTRING,
Qwen2VLCausalLMOutputWithPast,
)
from transformers.utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)

from liger_kernel.transformers.fused_linear_cross_entropy import (
LigerFusedLinearCrossEntropyLoss,
)


@add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def lce_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
pixel_values: Optional[torch.Tensor] = None,
pixel_values_videos: Optional[torch.FloatTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
rope_deltas: Optional[torch.LongTensor] = None,
) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
r"""
Copy-paste of Qwen2VL's forward, replacing torch cross entropy with Liger fused linear cross entropy.
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
>>> model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
>>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
>>> messages = [
...     {
...         "role": "user",
...         "content": [
...             {"type": "image"},
...             {"type": "text", "text": "What is shown in this image?"},
...         ],
...     },
... ]
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
>>> inputs = processor(text=[text], images=[image])
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
```"""

output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)

if inputs_embeds is None:
inputs_embeds = self.model.embed_tokens(input_ids)
if pixel_values is not None:
pixel_values = pixel_values.type(self.visual.get_dtype())
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).to(
inputs_embeds.device
)
image_mask = input_ids == self.config.image_token_id
if self.training:
inputs_embeds = inputs_embeds.clone()
inputs_embeds[image_mask] = image_embeds
if pixel_values_videos is not None:
pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw).to(
inputs_embeds.device
)
video_mask = input_ids == self.config.video_token_id
inputs_embeds[video_mask] = video_embeds
if attention_mask is not None:
attention_mask = attention_mask.to(inputs_embeds.device)

outputs = self.model(
input_ids=None,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)

hidden_states = outputs[0]

loss = None
logits = None

if self.training and (labels is not None):
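        # Fast path for training: hand the flattened hidden states directly to the
        # fused linear + cross entropy kernel, so the full (N, vocab_size) logits
        # tensor is never materialized.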
shift_hidden_states = hidden_states[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()

# Flatten tokens
shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
shift_labels = shift_labels.view(-1)

lce = LigerFusedLinearCrossEntropyLoss()
loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
else:
logits = self.lm_head(hidden_states)
logits = logits.float()
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)

if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output

return Qwen2VLCausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
rope_deltas=rope_deltas,
)
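
For intuition, a minimal standalone sketch of the fused loss call used above (shapes and sizes are illustrative assumptions; the Triton kernels need a GPU):

```python
import torch

from liger_kernel.transformers.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyLoss,
)

vocab_size, hidden_size, num_tokens = 32000, 1024, 8
# lm_head weight has shape (vocab_size, hidden_size); hidden states are flattened to (N, hidden_size)
lm_head_weight = torch.randn(vocab_size, hidden_size, device="cuda", requires_grad=True)
hidden_states = torch.randn(num_tokens, hidden_size, device="cuda", requires_grad=True)
labels = torch.randint(0, vocab_size, (num_tokens,), device="cuda")

lce = LigerFusedLinearCrossEntropyLoss()
loss = lce(lm_head_weight, hidden_states, labels)  # logits computed blockwise, never stored in full
loss.backward()
```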
50 changes: 50 additions & 0 deletions src/liger_kernel/transformers/monkey_patch.py
@@ -4,6 +4,7 @@

from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.geglu import LigerGEGLUMLP
from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward
from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward
@@ -245,6 +246,54 @@ def apply_liger_kernel_to_qwen2(
modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP


def apply_liger_kernel_to_qwen2_vl(
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = True,
rms_norm: bool = True,
layer_norm: bool = True,
swiglu: bool = True,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Qwen2-VL models.
NOTE: Qwen2-VL is not available in transformers<=4.44.2
Args:
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
fused_linear_cross_entropy (bool):
Whether to apply Liger's fused linear cross entropy loss. Default is True.
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
If `fused_linear_cross_entropy` is True, the logits are not materialized, which is more memory-efficient.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
"""
assert not (
cross_entropy and fused_linear_cross_entropy
), "cross_entropy and fused_linear_cross_entropy cannot both be True."

from transformers.models.qwen2_vl import modeling_qwen2_vl

from liger_kernel.transformers.model.qwen2_vl import (
lce_forward as qwen2_vl_lce_forward,
)

# TODO: Support Qwen2-VL's multimodal RoPE implementation

if rms_norm:
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L439
modeling_qwen2_vl.Qwen2RMSNorm = partial(
LigerRMSNorm, init_fn="ones", casting_mode="gemma"
)
if layer_norm:
modeling_qwen2_vl.LayerNorm = LigerLayerNorm
if cross_entropy:
modeling_qwen2_vl.CrossEntropyLoss = LigerCrossEntropyLoss
if fused_linear_cross_entropy:
modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = qwen2_vl_lce_forward
if swiglu:
modeling_qwen2_vl.Qwen2MLP = LigerSwiGLUMLP
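
A quick way to sanity-check that the patch took effect (a sketch; `LigerSwiGLUMLP`'s import path is assumed from this module's imports):

```python
from transformers.models.qwen2_vl import modeling_qwen2_vl

from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP

apply_liger_kernel_to_qwen2_vl()
assert modeling_qwen2_vl.Qwen2MLP is LigerSwiGLUMLP  # the MLP class was swapped in place
```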


def apply_liger_kernel_to_phi3(
rope: bool = True,
cross_entropy: bool = False,
@@ -291,6 +340,7 @@ def apply_liger_kernel_to_phi3(
"mistral": apply_liger_kernel_to_mistral,
"mixtral": apply_liger_kernel_to_mixtral,
"qwen2": apply_liger_kernel_to_qwen2,
"qwen2_vl": apply_liger_kernel_to_qwen2_vl,
"phi3": apply_liger_kernel_to_phi3,
}
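
A hedged sketch of dispatching through this mapping by a model config's `model_type` (assuming the dict shown above is exported as `MODEL_TYPE_TO_APPLY_LIGER_FN`):

```python
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN

def apply_liger_kernel_for(model_type: str, **kwargs) -> None:
    # Look up the per-architecture patch function and forward any kernel flags to it.
    apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(model_type)
    if apply_fn is None:
        raise ValueError(f"No Liger kernel patch registered for model_type={model_type!r}")
    apply_fn(**kwargs)

apply_liger_kernel_for("qwen2_vl")
```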

3 changes: 2 additions & 1 deletion test/convergence/test_mini_models.py
@@ -331,6 +331,7 @@ def run_mini_model(

for i in range(num_steps):
batch = next(loader_iter).to(model.device)
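        # Reset accumulated gradients so each step's update reflects only its own batch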
optimizer.zero_grad()
output = model(**batch)
output.loss.backward()
optimizer.step()
@@ -343,7 +344,7 @@
@pytest.mark.parametrize(
"model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logits_atol, logits_rtol, param_atol, param_rtol",
[
# Gemma 1.1 and 2 has more tolerance because currently, the kernel is not a perfect match (casts are not done the same way)
# Gemma 1 has more tolerance because currently, the kernel is not a perfect match (casts are not done the same way)
("mini_gemma1", 32, 1e-4, torch.float32, 1e-8, 6e-4, 5e-3, 1e-5, 5e-3, 1e-5),
pytest.param(
"mini_gemma1",