Monkeypatch for Qwen2-VL #175

Merged: 26 commits from tyler/monkeypatch-qwen2vl into main on Sep 8, 2024
Changes from 13 commits

Commits (26)
b2ec858
Monkeypatch for Qwen2-VL
tyler-romero Aug 30, 2024
308b2a8
Checkstyle
tyler-romero Aug 30, 2024
95b4b1d
Revert automodel
tyler-romero Aug 30, 2024
16da754
Merge branch 'main' into tyler/monkeypatch-qwen2vl
lancerts Aug 30, 2024
956eea0
wip
tyler-romero Sep 2, 2024
43d2fe0
Passing tests...
tyler-romero Sep 3, 2024
baae5fd
Fix tests w/ sdpa
tyler-romero Sep 3, 2024
37ee685
Shield imports, checkstyle
tyler-romero Sep 3, 2024
f21fe69
poke tests
tyler-romero Sep 3, 2024
03bc036
Revert setup.py
tyler-romero Sep 3, 2024
c4179eb
Add marks to pytest for lower transformers versions
tyler-romero Sep 3, 2024
1155d1b
Checkstyle
tyler-romero Sep 3, 2024
8c5182b
Revert comment
tyler-romero Sep 3, 2024
5125cf1
Working multimodal convergence test
tyler-romero Sep 5, 2024
904704b
Checkstyle
tyler-romero Sep 5, 2024
c9583d2
cleanup
tyler-romero Sep 5, 2024
dca3451
Merge branch 'main' into tyler/monkeypatch-qwen2vl
tyler-romero Sep 5, 2024
175abe8
Add uv.lock to gitignore
tyler-romero Sep 5, 2024
f9422a9
Clean up multimodal collation
tyler-romero Sep 5, 2024
3bd22bc
Update readme
tyler-romero Sep 6, 2024
f590993
correction
tyler-romero Sep 6, 2024
1bf506f
Merge branch 'main' into tyler/monkeypatch-qwen2vl
lancerts Sep 6, 2024
7b30646
Merge branch 'main' into tyler/monkeypatch-qwen2vl
lancerts Sep 7, 2024
929f30c
Merge branch 'main' into tyler/monkeypatch-qwen2vl
lancerts Sep 7, 2024
cf90fdd
Merge branch 'main' into tyler/monkeypatch-qwen2vl
tyler-romero Sep 7, 2024
292fef0
Merge branch 'main' into tyler/monkeypatch-qwen2vl
tyler-romero Sep 8, 2024
1 change: 1 addition & 0 deletions src/liger_kernel/transformers/__init__.py
@@ -9,4 +9,5 @@
    apply_liger_kernel_to_mixtral,
    apply_liger_kernel_to_phi3,
    apply_liger_kernel_to_qwen2,
    apply_liger_kernel_to_qwen2_vl,
Collaborator: will this break users if transformers < 4.44?

tyler-romero (Author), Sep 8, 2024: This import won't break, but if someone calls this function with a lower version of transformers, they will get an ImportError because it imports the Qwen2-VL module internally.

)
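
As a concrete illustration of the version caveat discussed above, here is a minimal sketch of how a caller might guard the patch behind a transformers version check. The 4.45.0 threshold and the use of `packaging` are assumptions for illustration, not part of this PR.

```python
# Hypothetical usage sketch: only apply the Qwen2-VL patch when the installed
# transformers version actually ships the Qwen2-VL model (assumed >= 4.45.0 here).
import transformers
from packaging import version

from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl  # the import itself is safe

if version.parse(transformers.__version__) >= version.parse("4.45.0"):
    apply_liger_kernel_to_qwen2_vl()
else:
    print("Qwen2-VL not available in this transformers version; skipping Liger patch.")
```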
172 changes: 172 additions & 0 deletions src/liger_kernel/transformers/model/qwen2_vl.py
@@ -0,0 +1,172 @@
from typing import List, Optional, Tuple, Union

import torch
from torch.nn import CrossEntropyLoss
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
Collaborator: ditto. Will it break users with a lower version?

tyler-romero (Author): Only if they explicitly import this module - and I would argue it should throw an ImportError in that case.

    _CONFIG_FOR_DOC,
    QWEN2_VL_INPUTS_DOCSTRING,
    Qwen2VLCausalLMOutputWithPast,
)
from transformers.utils import (
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)

from liger_kernel.transformers.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyLoss,
)


@add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING)
@replace_return_docstrings(
    output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def lce_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    pixel_values: Optional[torch.Tensor] = None,
    pixel_values_videos: Optional[torch.FloatTensor] = None,
    image_grid_thw: Optional[torch.LongTensor] = None,
    video_grid_thw: Optional[torch.LongTensor] = None,
    rope_deltas: Optional[torch.LongTensor] = None,
) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
    r"""
    Copy paste Qwen2VL's forward but replace torch cross entropy with liger fused linear cross entropy
    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
    Returns:
    Example:
    ```python
    >>> from PIL import Image
    >>> import requests
    >>> from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
    >>> model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
    >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
    >>> messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": "What is shown in this image?"},
            ],
        },
    ]
    >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)
    >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    >>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos])
    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
    ```"""

    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    if inputs_embeds is None:
        inputs_embeds = self.model.embed_tokens(input_ids)
        if pixel_values is not None:
            pixel_values = pixel_values.type(self.visual.get_dtype())
            image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).to(
                inputs_embeds.device
            )
            image_mask = input_ids == self.config.image_token_id
            if self.training:
                inputs_embeds = inputs_embeds.clone()
            inputs_embeds[image_mask] = image_embeds
        if pixel_values_videos is not None:
            pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
            video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw).to(
                inputs_embeds.device
            )
            video_mask = input_ids == self.config.video_token_id
            inputs_embeds[video_mask] = video_embeds
        if attention_mask is not None:
            attention_mask = attention_mask.to(inputs_embeds.device)

    outputs = self.model(
        input_ids=None,
        position_ids=position_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    hidden_states = outputs[0]

    loss = None
    logits = None

    if self.training and (labels is not None):
        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        # Flatten tokens
        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
        shift_labels = shift_labels.view(-1)

        lce = LigerFusedLinearCrossEntropyLoss()
        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
    else:
        logits = self.lm_head(hidden_states)
        logits = logits.float()
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return Qwen2VLCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        rope_deltas=rope_deltas,
    )
50 changes: 50 additions & 0 deletions src/liger_kernel/transformers/monkey_patch.py
@@ -4,6 +4,7 @@

from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.geglu import LigerGEGLUMLP
from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward
from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward
@@ -233,6 +234,54 @@ def apply_liger_kernel_to_qwen2(
        modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP


def apply_liger_kernel_to_qwen2_vl(
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    layer_norm: bool = True,
    swiglu: bool = True,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace Qwen2-VL models.
    NOTE: Qwen2-VL is not available in transformers<=4.44.2

    Args:
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits are not materialized, which is more memory efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
    """
    assert not (
        cross_entropy and fused_linear_cross_entropy
    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

    from transformers.models.qwen2_vl import modeling_qwen2_vl

    from liger_kernel.transformers.model.qwen2_vl import (
        lce_forward as qwen2_vl_lce_forward,
    )

    # TODO: Support Qwen2-VL's multimodal RoPE implementation

    if rms_norm:
        # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L439
        modeling_qwen2_vl.Qwen2RMSNorm = partial(
            LigerRMSNorm, init_fn="ones", casting_mode="gemma"
        )
    if layer_norm:
tyler-romero (Author): LayerNorm is only used in the ViT part of the model.

Q for reviewers: We don't currently have image inputs as part of the convergence test suite. Worth implementing? Should I just not support LayerNorm for now so we don't need to worry about it? Perhaps default to False?

ryankert01 (Contributor), Sep 3, 2024: I think the question is whether we will support more multi-modal models. If yes, I think it's worthwhile to implement a convergence test for it.

Collaborator: yes please do

tyler-romero (Author): Done. It took some doing because different VLLMs have different image tokens, different numbers of image tokens, and different processor outputs and model inputs. I think I arrived at something that is pretty extendable to other VLLMs, but there are some differences from the other convergence tests; for example, it generates images and tokenizes the test dataset on the fly.

        modeling_qwen2_vl.LayerNorm = LigerLayerNorm
    if cross_entropy:
        modeling_qwen2_vl.CrossEntropyLoss = LigerCrossEntropyLoss
    if fused_linear_cross_entropy:
        modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = qwen2_vl_lce_forward
tyler-romero (Author): Not AutoModelForCausalLMs like the others.

    if swiglu:
        modeling_qwen2_vl.Qwen2MLP = LigerSwiGLUMLP
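
For reference, a minimal usage sketch of the new patch function. The checkpoint name is taken from the docstring example earlier in this PR; applying the patch before instantiating the model is an assumption about intended usage (module-level classes are patched, so the model has to be constructed afterwards for the replacements to take effect).

```python
# Minimal sketch, assuming transformers > 4.44.2: patch Qwen2-VL's module classes,
# then load the model so the patched forward / norms / MLP are picked up.
from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl
from transformers import Qwen2VLForConditionalGeneration

apply_liger_kernel_to_qwen2_vl(
    fused_linear_cross_entropy=True,  # avoids materializing logits during training
    layer_norm=True,  # LayerNorm is only used in the ViT part of the model (per the discussion above)
)
model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
```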


def apply_liger_kernel_to_phi3(
    rope: bool = True,
    cross_entropy: bool = False,

@@ -279,6 +328,7 @@ def apply_liger_kernel_to_phi3(
    "mistral": apply_liger_kernel_to_mistral,
    "mixtral": apply_liger_kernel_to_mixtral,
    "qwen2": apply_liger_kernel_to_qwen2,
    "qwen2_vl": apply_liger_kernel_to_qwen2_vl,
    "phi3": apply_liger_kernel_to_phi3,
}

106 changes: 104 additions & 2 deletions test/convergence/test_mini_models.py
@@ -29,8 +29,20 @@
    apply_liger_kernel_to_mixtral,
    apply_liger_kernel_to_phi3,
    apply_liger_kernel_to_qwen2,
    apply_liger_kernel_to_qwen2_vl,
)

try:
    # Qwen2-VL is only available in transformers>4.44.2
tyler-romero (Author): Tests pass with transformers @ HEAD on a 4090.

    from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig
    from transformers.models.qwen2_vl.modeling_qwen2_vl import (
        Qwen2VLForConditionalGeneration,
    )

    QWEN2_VL_AVAILABLE = True
except ImportError:
    QWEN2_VL_AVAILABLE = False

torch.use_deterministic_algorithms(True)

# Only setting torch.use_deterministic_algorithms(True) throws the following error:
@@ -281,6 +293,52 @@
    ),
}

if QWEN2_VL_AVAILABLE:
    MINI_MODEL_SETUPS["mini_qwen2_vl"] = MiniModelConfig(
        liger_kernel_patch_func=functools.partial(
            apply_liger_kernel_to_qwen2_vl, fused_linear_cross_entropy=False
        ),
        model_class=Qwen2VLForConditionalGeneration,
        mini_model_config=Qwen2VLConfig(
            attention_dropout=0.0,
            bos_token_id=1,  # 151643
            eos_token_id=2,  # 151645
            hidden_act="silu",
            hidden_size=1536,  # 8192
            initializer_range=0.02,
            intermediate_size=4864,  # 29568
            max_position_embeddings=32768,
            max_window_layers=4,  # 80
            num_attention_heads=12,  # 64
            num_hidden_layers=4,  # 80
            num_key_value_heads=2,  # 8
            rms_norm_eps=1e-6,  # 1e-5
            rope_theta=1000000.0,
            rope_scaling=dict(
                type="mrope",
                mrope_section=[16, 24, 24],  # (temporal, height, width)
            ),
            sliding_window=4096,
            tie_word_embeddings=False,
            use_cache=True,
            vocab_size=32000,  # 152064
            use_sliding_window=False,
            vision_config={
                "depth": 4,  # 32
                "embed_dim": 1280,
                "mlp_ratio": 4,
                "num_heads": 16,
                "in_chans": 3,
                "hidden_size": 128,  # 1536
                "patch_size": 14,
                "spatial_merge_size": 2,
                "spatial_patch_size": 14,
                "temporal_patch_size": 2,
            },
            attn_implementation="sdpa",
        ),
    )
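
The review discussion above mentions that the multimodal convergence test generates images and tokenizes the test dataset on the fly. As a rough, hedged illustration of that idea (not the actual test code in this PR; the processor checkpoint, image size, and prompt are assumptions), building a Qwen2-VL batch from a synthetic image could look like this:

```python
# Hypothetical sketch: create a random PIL image and run it through the Qwen2-VL
# processor to obtain multimodal model inputs (input_ids, pixel_values, image_grid_thw).
import numpy as np
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
image = Image.fromarray(np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8))

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "Describe this image."},
        ],
    },
]
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
batch = processor(text=[text], images=[image], return_tensors="pt")
# batch now holds input_ids, attention_mask, pixel_values, and image_grid_thw,
# i.e. the kind of inputs a multimodal convergence test would feed the mini model.
```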


def create_model(model_name="mini_llama3"):
"""
@@ -308,10 +366,17 @@ def run_mini_model(

    if with_liger is True:
        kwargs = {
            "rope": True,
            "rms_norm": True,
            "cross_entropy": True,
        }
        model_supports_rope = "qwen2_vl" not in model_name
        if model_supports_rope:
            kwargs["rope"] = True

        model_supports_layer_norm = "qwen2_vl" in model_name
        if model_supports_layer_norm:
            kwargs["layer_norm"] = True

        if "gemma" in model_name:
            kwargs["geglu"] = True
        else:
@@ -343,7 +408,7 @@ def run_mini_model(
@pytest.mark.parametrize(
    "model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logits_atol, logits_rtol, param_atol, param_rtol",
    [
        # Gemma 1.1 and 2 has more tolerance because currently, the kernel is not a perfect match (casts are not done the same way)
        # Gemma 1 has more tolerance because currently, the kernel is not a perfect match (casts are not done the same way)
        ("mini_gemma1", 32, 1e-4, torch.float32, 1e-8, 6e-4, 5e-3, 1e-5, 5e-3, 1e-5),
        pytest.param(
            "mini_gemma1",

@@ -444,6 +509,43 @@ def run_mini_model(
                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
            ),
        ),
        pytest.param(
            "mini_qwen2_vl",
            32,
            1e-4,
            torch.float32,
            1e-8,
            1e-5,
            5e-3,
            1e-5,
            5e-3,
            1e-5,
            marks=pytest.mark.skipif(
                not QWEN2_VL_AVAILABLE,
                reason="Qwen2-VL not available in this version of transformers",
            ),
        ),
        pytest.param(
            "mini_qwen2_vl",
            32,
            1e-4,
            torch.bfloat16,
            1e-8,
            1e-5,
            1e-2,
            1e-5,
            1e-2,
            1e-5,
            marks=[
                pytest.mark.skipif(
                    not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
                ),
                pytest.mark.skipif(
                    not QWEN2_VL_AVAILABLE,
                    reason="Qwen2-VL not available in this version of transformers",
                ),
            ],
        ),
        ("mini_phi3", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
        pytest.param(
            "mini_phi3",