Monkeypatch for Qwen2-VL #175
Changes from 13 commits
New file: Liger lce_forward for Qwen2-VL (module liger_kernel.transformers.model.qwen2_vl)
@@ -0,0 +1,172 @@
from typing import List, Optional, Tuple, Union

import torch
from torch.nn import CrossEntropyLoss
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
Review comment: ditto. will it break users with lower version?
Reply: Only if they explicitly import this module - and I would argue it should throw an ImportError in that case.
    _CONFIG_FOR_DOC,
    QWEN2_VL_INPUTS_DOCSTRING,
    Qwen2VLCausalLMOutputWithPast,
)
from transformers.utils import (
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)

from liger_kernel.transformers.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyLoss,
)


@add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING)
@replace_return_docstrings(
    output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def lce_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    pixel_values: Optional[torch.Tensor] = None,
    pixel_values_videos: Optional[torch.FloatTensor] = None,
    image_grid_thw: Optional[torch.LongTensor] = None,
    video_grid_thw: Optional[torch.LongTensor] = None,
    rope_deltas: Optional[torch.LongTensor] = None,
) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
    r"""
    Copy paste Qwen2VL's forward but replace torch cross entropy with liger fused linear cross entropy

    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    Returns:

    Example:

    ```python
    >>> from PIL import Image
    >>> import requests
    >>> from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

    >>> model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
    >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")

    >>> messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": "What is shown in this image?"},
            ],
        },
    ]
    >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    >>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos])

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
    ```"""
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    if inputs_embeds is None:
        inputs_embeds = self.model.embed_tokens(input_ids)
        if pixel_values is not None:
            pixel_values = pixel_values.type(self.visual.get_dtype())
            image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).to(
                inputs_embeds.device
            )
            image_mask = input_ids == self.config.image_token_id
            if self.training:
                inputs_embeds = inputs_embeds.clone()
            inputs_embeds[image_mask] = image_embeds
        if pixel_values_videos is not None:
            pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
            video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw).to(
                inputs_embeds.device
            )
            video_mask = input_ids == self.config.video_token_id
            inputs_embeds[video_mask] = video_embeds
        if attention_mask is not None:
            attention_mask = attention_mask.to(inputs_embeds.device)

    outputs = self.model(
        input_ids=None,
        position_ids=position_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    hidden_states = outputs[0]

    loss = None
    logits = None

    if self.training and (labels is not None):
        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        # Flatten tokens
        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
        shift_labels = shift_labels.view(-1)

        lce = LigerFusedLinearCrossEntropyLoss()
        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
    else:
        logits = self.lm_head(hidden_states)
        logits = logits.float()
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return Qwen2VLCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        rope_deltas=rope_deltas,
    )
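Editorial note (not part of the diff): the training branch above never materializes the (num_tokens, vocab_size) logits tensor; the fused kernel consumes the lm_head weight and the flattened hidden states directly. A minimal sketch of the two loss paths, assuming a CUDA device and the reduced shapes used by the mini-model config later in this PR (hidden_size=1536, vocab_size=32000):

```python
import torch
from torch.nn import CrossEntropyLoss

from liger_kernel.transformers.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyLoss,
)

# Flattened (num_tokens, hidden_size) activations and (num_tokens,) targets.
hidden = torch.randn(8, 1536, device="cuda", requires_grad=True)
labels = torch.randint(0, 32000, (8,), device="cuda")
lm_head_weight = torch.randn(32000, 1536, device="cuda", requires_grad=True)

# Baseline path (the `else` branch above): projects to the vocabulary first,
# which materializes a (num_tokens, vocab_size) logits tensor.
baseline_loss = CrossEntropyLoss()(hidden @ lm_head_weight.T, labels)

# Fused path (the training branch above): weight, inputs, and targets go
# straight into the Triton kernel; the full logits tensor is never stored.
fused_loss = LigerFusedLinearCrossEntropyLoss()(lm_head_weight, hidden, labels)
```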
Changes to the Liger monkeypatch module (apply_liger_kernel_to_* entry points):
@@ -4,6 +4,7 @@
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.geglu import LigerGEGLUMLP
from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward
from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward

@@ -233,6 +234,54 @@ def apply_liger_kernel_to_qwen2(
        modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP


def apply_liger_kernel_to_qwen2_vl(
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    layer_norm: bool = True,
    swiglu: bool = True,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace Qwen2-VL models.
    NOTE: Qwen2-VL is not available in transformers<=4.44.2

    Args:
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
    """
    assert not (
        cross_entropy and fused_linear_cross_entropy
    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

    from transformers.models.qwen2_vl import modeling_qwen2_vl

    from liger_kernel.transformers.model.qwen2_vl import (
        lce_forward as qwen2_vl_lce_forward,
    )

    # TODO: Support Qwen2-VL's multimodal RoPE implementation

    if rms_norm:
        # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L439
        modeling_qwen2_vl.Qwen2RMSNorm = partial(
            LigerRMSNorm, init_fn="ones", casting_mode="gemma"
        )
    if layer_norm:
Review comment: LayerNorm is only used in the ViT part of the model. Q for reviewers: we don't currently have image inputs as part of the convergence test suite. Worth implementing? Should I just not support LayerNorm for now so we don't need to worry about it? Perhaps default to False?
Reply: I think the question is whether we will support more multi-modal models. If yes, I think it's worthwhile to implement a convergence test for it.
Reply: yes please do
Reply: Done - it took some doing because different VLMs have different image tokens, different numbers of image tokens, different processor outputs and model inputs, etc. I think I arrived at something that is pretty extendable to other VLMs, but there are some differences from the other convergence tests - for example, it generates images and tokenizes the test dataset on the fly.
        modeling_qwen2_vl.LayerNorm = LigerLayerNorm
    if cross_entropy:
        modeling_qwen2_vl.CrossEntropyLoss = LigerCrossEntropyLoss
    if fused_linear_cross_entropy:
        modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = qwen2_vl_lce_forward
Review comment: Not AutoModelForCausalLM like the others.
    if swiglu:
        modeling_qwen2_vl.Qwen2MLP = LigerSwiGLUMLP


def apply_liger_kernel_to_phi3(
    rope: bool = True,
    cross_entropy: bool = False,

@@ -279,6 +328,7 @@ def apply_liger_kernel_to_phi3(
"mistral": apply_liger_kernel_to_mistral, | ||
"mixtral": apply_liger_kernel_to_mixtral, | ||
"qwen2": apply_liger_kernel_to_qwen2, | ||
"qwen2_vl": apply_liger_kernel_to_qwen2_vl, | ||
"phi3": apply_liger_kernel_to_phi3, | ||
} | ||
|
||
|
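Editorial note: a hypothetical usage sketch of the new entry point. The import path for apply_liger_kernel_to_qwen2_vl is assumed here (it is not shown in this diff); the patch must run before the model executes its forward pass:

```python
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration

# Assumed export location, matching the other apply_liger_kernel_to_* helpers.
from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl

# Swap in Liger RMSNorm/LayerNorm/SwiGLU and replace
# Qwen2VLForConditionalGeneration.forward with the fused-loss lce_forward.
apply_liger_kernel_to_qwen2_vl(fused_linear_cross_entropy=True, cross_entropy=False)

model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
# Training steps that pass `labels` now compute the loss with
# LigerFusedLinearCrossEntropyLoss instead of materializing logits.
```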
Changes to the convergence test suite (mini model tests):
@@ -29,8 +29,20 @@
    apply_liger_kernel_to_mixtral,
    apply_liger_kernel_to_phi3,
    apply_liger_kernel_to_qwen2,
    apply_liger_kernel_to_qwen2_vl,
)

try:
    # Qwen2-VL is only available in transformers>4.44.2
Review comment: Tests pass with transformers @ HEAD on a 4090.
    from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig
    from transformers.models.qwen2_vl.modeling_qwen2_vl import (
        Qwen2VLForConditionalGeneration,
    )

    QWEN2_VL_AVAILABLE = True
except ImportError:
    QWEN2_VL_AVAILABLE = False

torch.use_deterministic_algorithms(True)

# Only setting torch.use_deterministic_algorithms(True) throws the following error:
@@ -281,6 +293,52 @@
    ),
}

if QWEN2_VL_AVAILABLE:
    MINI_MODEL_SETUPS["mini_qwen2_vl"] = MiniModelConfig(
        liger_kernel_patch_func=functools.partial(
            apply_liger_kernel_to_qwen2_vl, fused_linear_cross_entropy=False
        ),
        model_class=Qwen2VLForConditionalGeneration,
        mini_model_config=Qwen2VLConfig(
            attention_dropout=0.0,
            bos_token_id=1,  # 151643
            eos_token_id=2,  # 151645
            hidden_act="silu",
            hidden_size=1536,  # 8192
            initializer_range=0.02,
            intermediate_size=4864,  # 29568
            max_position_embeddings=32768,
            max_window_layers=4,  # 80
            num_attention_heads=12,  # 64
            num_hidden_layers=4,  # 80
            num_key_value_heads=2,  # 8
            rms_norm_eps=1e-6,  # 1e-5
            rope_theta=1000000.0,
            rope_scaling=dict(
                type="mrope",
                mrope_section=[16, 24, 24],  # (temporal, height, width)
            ),
            sliding_window=4096,
            tie_word_embeddings=False,
            use_cache=True,
            vocab_size=32000,  # 152064
            use_sliding_window=False,
            vision_config={
                "depth": 4,  # 32
                "embed_dim": 1280,
                "mlp_ratio": 4,
                "num_heads": 16,
                "in_chans": 3,
                "hidden_size": 128,  # 1536
                "patch_size": 14,
                "spatial_merge_size": 2,
                "spatial_patch_size": 14,
                "temporal_patch_size": 2,
            },
            attn_implementation="sdpa",
        ),
    )


def create_model(model_name="mini_llama3"):
    """

@@ -308,10 +366,17 @@ def run_mini_model(
    if with_liger is True:
        kwargs = {
            "rms_norm": True,
            "cross_entropy": True,
        }
        model_supports_rope = "qwen2_vl" not in model_name
        if model_supports_rope:
            kwargs["rope"] = True

        model_supports_layer_norm = "qwen2_vl" in model_name
        if model_supports_layer_norm:
            kwargs["layer_norm"] = True

        if "gemma" in model_name:
            kwargs["geglu"] = True
        else:
@@ -343,7 +408,7 @@ def run_mini_model(
@pytest.mark.parametrize(
    "model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logits_atol, logits_rtol, param_atol, param_rtol",
    [
        # Gemma 1 has more tolerance because currently, the kernel is not a perfect match (casts are not done the same way)
        ("mini_gemma1", 32, 1e-4, torch.float32, 1e-8, 6e-4, 5e-3, 1e-5, 5e-3, 1e-5),
        pytest.param(
            "mini_gemma1",
@@ -444,6 +509,43 @@ def run_mini_model(
                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
            ),
        ),
        pytest.param(
            "mini_qwen2_vl",
            32,
            1e-4,
            torch.float32,
            1e-8,
            1e-5,
            5e-3,
            1e-5,
            5e-3,
            1e-5,
            marks=pytest.mark.skipif(
                not QWEN2_VL_AVAILABLE,
                reason="Qwen2-VL not available in this version of transformers",
            ),
        ),
        pytest.param(
            "mini_qwen2_vl",
            32,
            1e-4,
            torch.bfloat16,
            1e-8,
            1e-5,
            1e-2,
            1e-5,
            1e-2,
            1e-5,
            marks=[
                pytest.mark.skipif(
                    not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
                ),
                pytest.mark.skipif(
                    not QWEN2_VL_AVAILABLE,
                    reason="Qwen2-VL not available in this version of transformers",
                ),
            ],
        ),
        ("mini_phi3", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
        pytest.param(
            "mini_phi3",
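Editorial note: the review thread above mentions that the multimodal convergence test generates images and tokenizes its dataset on the fly. The PR's actual test helpers are not shown in this diff; the following is only a hypothetical sketch of how such a batch could be built, using the public Qwen2-VL processor API:

```python
import numpy as np
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")

# Generate a random RGB image instead of downloading a dataset
# (224 is a multiple of patch_size * spatial_merge_size = 28).
image = Image.fromarray(np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8))

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "Describe the image."},
        ],
    },
]
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

# The processor expands the image placeholder into the model-specific image tokens
# and returns input_ids, attention_mask, pixel_values, and image_grid_thw.
batch = processor(text=[text], images=[image], return_tensors="pt")
batch["labels"] = batch["input_ids"].clone()  # simple LM labels for a training step
```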
Review comment: will this break users if transformers < 4.44?
Reply: This import won't break, but if someone tries to call this function with a lower version of transformers, then they would get an ImportError as it tries to import the Qwen2-VL module internally.
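Editorial note: the behavior described in this reply can be made explicit on the caller side with the same guard the test file uses; a minimal sketch, with the liger_kernel import path assumed:

```python
try:
    # Only present in transformers > 4.44.2.
    from transformers.models.qwen2_vl.modeling_qwen2_vl import (
        Qwen2VLForConditionalGeneration,
    )

    QWEN2_VL_AVAILABLE = True
except ImportError:
    QWEN2_VL_AVAILABLE = False

if QWEN2_VL_AVAILABLE:
    # Safe to call: the patch function's internal import of
    # transformers.models.qwen2_vl will succeed on this version.
    from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl  # assumed path

    apply_liger_kernel_to_qwen2_vl()
```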