diff --git a/src/liger_kernel/transformers/monkey_patch.py b/src/liger_kernel/transformers/monkey_patch.py
index d1b211238..be286903e 100644
--- a/src/liger_kernel/transformers/monkey_patch.py
+++ b/src/liger_kernel/transformers/monkey_patch.py
@@ -121,6 +121,7 @@ def apply_liger_kernel_to_mllama(
     rope: bool = True,
     cross_entropy: bool = False,
     fused_linear_cross_entropy: bool = True,
+    layer_norm: bool = True,
     rms_norm: bool = True,
     swiglu: bool = True,
     model: PreTrainedModel = None,
@@ -151,12 +152,15 @@ def apply_liger_kernel_to_mllama(
         MllamaForCausalLM,
         MllamaForConditionalGeneration,
         MllamaTextModel,
+        MllamaVisionModel,
     )
 
     from liger_kernel.transformers.model.mllama import lce_forward as mllama_lce_forward
 
     if rope:
         modeling_mllama.apply_rotary_pos_emb = liger_rotary_pos_emb
+    if layer_norm:
+        modeling_mllama.nn.LayerNorm = LigerLayerNorm
     if rms_norm:
         modeling_mllama.MllamaTextRMSNorm = LigerRMSNorm
     if swiglu:
@@ -174,11 +178,14 @@ def apply_liger_kernel_to_mllama(
 
         if isinstance(model, MllamaForConditionalGeneration):
             language_model: MllamaForCausalLM = model.language_model
+            vision_model: MllamaVisionModel = model.vision_model
             text_model: MllamaTextModel = language_model.model
         elif isinstance(model, MllamaForCausalLM):
             text_model = model.model
+            vision_model = None
         elif isinstance(model, MllamaTextModel):
             text_model = model
+            vision_model = None
         else:
             raise ValueError(f"Unsupported Mllama model type: {type(model)}")
 
@@ -194,6 +201,20 @@ def apply_liger_kernel_to_mllama(
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
 
+        if vision_model:
+            _patch_layer_norm_module(vision_model.layernorm_pre)
+            _patch_layer_norm_module(vision_model.layernorm_post)
+
+            for layer in vision_model.transformer.layers:
+                if layer_norm:
+                    _patch_layer_norm_module(layer.input_layernorm)
+                    _patch_layer_norm_module(layer.post_attention_layernorm)
+
+            for layer in vision_model.global_transformer.layers:
+                if layer_norm:
+                    _patch_layer_norm_module(layer.input_layernorm)
+                    _patch_layer_norm_module(layer.post_attention_layernorm)
+
 
 def apply_liger_kernel_to_mistral(
     rope: bool = True,
@@ -767,7 +788,6 @@ def _apply_liger_kernel_to_instance(model: PreTrainedModel, **kwargs) -> None:
         for key, value in kwargs.items()
         if key in apply_fn_signature.parameters
     }
-
     logger.info(
         f"Applying Liger kernels to model instance with model type: {model_type} with kwargs: {applicable_kwargs}"
     )
diff --git a/test/convergence/test_mini_models_multimodal.py b/test/convergence/test_mini_models_multimodal.py
index 5bbc22294..c835df05d 100644
--- a/test/convergence/test_mini_models_multimodal.py
+++ b/test/convergence/test_mini_models_multimodal.py
@@ -316,15 +316,12 @@ def run_mini_model_multimodal(
     kwargs = {
         "rms_norm": True,
         "cross_entropy": True,
+        "layer_norm": True,
     }
     model_supports_rope = "qwen2_vl" not in model_name
     if model_supports_rope:
         kwargs["rope"] = True
 
-    model_supports_layer_norm = "qwen2_vl" in model_name
-    if model_supports_layer_norm:
-        kwargs["layer_norm"] = True
-
     if "gemma" in model_name:
         kwargs["geglu"] = True
     else:
diff --git a/test/transformers/test_monkey_patch.py b/test/transformers/test_monkey_patch.py
index 355cc9096..7ce1aacb7 100644
--- a/test/transformers/test_monkey_patch.py
+++ b/test/transformers/test_monkey_patch.py
@@ -302,6 +302,27 @@ def test_apply_liger_kernel_to_instance_for_mllama_for_conditional_generation():
                 layer.post_attention_layernorm.forward
             ) != inspect.getsource(LigerRMSNorm.forward)
 
+        assert inspect.getsource(
+            dummy_model_instance.vision_model.layernorm_pre.forward
+        ) != inspect.getsource(LigerLayerNorm.forward)
+        assert inspect.getsource(
+            dummy_model_instance.vision_model.layernorm_post.forward
+        ) != inspect.getsource(LigerLayerNorm.forward)
+        for layer in dummy_model_instance.vision_model.transformer.layers:
+            assert inspect.getsource(
+                layer.input_layernorm.forward
+            ) != inspect.getsource(LigerLayerNorm.forward)
+            assert inspect.getsource(
+                layer.post_attention_layernorm.forward
+            ) != inspect.getsource(LigerLayerNorm.forward)
+        for layer in dummy_model_instance.vision_model.global_transformer.layers:
+            assert inspect.getsource(
+                layer.input_layernorm.forward
+            ) != inspect.getsource(LigerLayerNorm.forward)
+            assert inspect.getsource(
+                layer.post_attention_layernorm.forward
+            ) != inspect.getsource(LigerLayerNorm.forward)
+
         # Test applying kernels to the model instance
         _apply_liger_kernel_to_instance(model=dummy_model_instance)
 
@@ -320,6 +341,27 @@ def test_apply_liger_kernel_to_instance_for_mllama_for_conditional_generation():
                 layer.post_attention_layernorm.forward
             ) == inspect.getsource(LigerRMSNorm.forward)
 
+        assert inspect.getsource(
+            dummy_model_instance.vision_model.layernorm_pre.forward
+        ) == inspect.getsource(LigerLayerNorm.forward)
+        assert inspect.getsource(
+            dummy_model_instance.vision_model.layernorm_post.forward
+        ) == inspect.getsource(LigerLayerNorm.forward)
+        for layer in dummy_model_instance.vision_model.transformer.layers:
+            assert inspect.getsource(
+                layer.input_layernorm.forward
+            ) == inspect.getsource(LigerLayerNorm.forward)
+            assert inspect.getsource(
+                layer.post_attention_layernorm.forward
+            ) == inspect.getsource(LigerLayerNorm.forward)
+        for layer in dummy_model_instance.vision_model.global_transformer.layers:
+            assert inspect.getsource(
+                layer.input_layernorm.forward
+            ) == inspect.getsource(LigerLayerNorm.forward)
+            assert inspect.getsource(
+                layer.post_attention_layernorm.forward
+            ) == inspect.getsource(LigerLayerNorm.forward)
+
 
 def test_apply_liger_kernel_to_instance_for_mllama_for_causal_lm():
     # Ensure any monkey patching is cleaned up for subsequent tests
diff --git a/test/utils.py b/test/utils.py
index 5341cf563..ac9a13190 100644
--- a/test/utils.py
+++ b/test/utils.py
@@ -222,8 +222,10 @@ def revert_liger_kernel_to_mllama():
     Revert all Liger kernel patches applied to MLlama.
     """
 
+    import torch.nn as nn
     from transformers.models.mllama import modeling_mllama
 
+    importlib.reload(nn)
    importlib.reload(modeling_mllama)
     print("Liger kernel patches have been reverted.")
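
Note: the sketch below is not part of the patch; it is a minimal usage illustration of the new layer_norm flag, assuming the public liger_kernel.transformers.apply_liger_kernel_to_mllama entry point and the Hugging Face MllamaForConditionalGeneration class. The checkpoint name is a placeholder.

    import torch
    from transformers import MllamaForConditionalGeneration

    from liger_kernel.transformers import apply_liger_kernel_to_mllama

    # Instantiate the model first, then patch the instance. With layer_norm=True
    # (the new default introduced above), the vision tower's layernorm_pre /
    # layernorm_post and the layer norms inside both the local and global vision
    # transformer blocks are swapped for LigerLayerNorm, alongside the existing
    # RoPE / RMSNorm / SwiGLU patches on the text model.
    model = MllamaForConditionalGeneration.from_pretrained(
        "meta-llama/Llama-3.2-11B-Vision-Instruct",  # placeholder checkpoint
        torch_dtype=torch.bfloat16,
    )
    apply_liger_kernel_to_mllama(layer_norm=True, model=model)

Calling apply_liger_kernel_to_mllama(layer_norm=True) with no model instead patches modeling_mllama.nn.LayerNorm at import time, so models instantiated afterwards pick up LigerLayerNorm automatically.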