From dbf119764bd859579805c5f601f26a2e95174bfc Mon Sep 17 00:00:00 2001 From: Steven Shimizu Date: Fri, 27 Sep 2024 19:55:04 +0000 Subject: [PATCH 01/11] Fixed patching for Llama model --- src/liger_kernel/transformers/monkey_patch.py | 25 +++++++++++-------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/src/liger_kernel/transformers/monkey_patch.py b/src/liger_kernel/transformers/monkey_patch.py index b60e328fd..8996508de 100644 --- a/src/liger_kernel/transformers/monkey_patch.py +++ b/src/liger_kernel/transformers/monkey_patch.py @@ -81,22 +81,25 @@ def apply_liger_kernel_to_llama( # Direct LlamaModel base_model = model - torch_dtype = config.torch_dtype if rms_norm: - base_model.norm = LigerRMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ).to(torch_dtype) + base_model.offset = 0.0 + base_model.casting_mode = "llama" + base_model.norm.forward = LigerRMSNorm.forward + base_model.norm.extra_repr = LigerRMSNorm.extra_repr for decoder_layer in base_model.layers: if swiglu: - decoder_layer.mlp = LigerSwiGLUMLP(config).to(torch_dtype) + decoder_layer.mlp.forward = LigerSwiGLUMLP.forward if rms_norm: - decoder_layer.input_layernorm = LigerRMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ).to(torch_dtype) - decoder_layer.post_attention_layernorm = LigerRMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ).to(torch_dtype) + decoder_layer.input_layernorm.offset = 0.0 + decoder_layer.input_layernorm.casting_mode = "llama" + decoder_layer.input_layernorm.forward = LigerRMSNorm.forward + decoder_layer.input_layernorm.extra_repr = LigerRMSNorm.extra_repr + + decoder_layer.post_attention_layernorm.offset = 0.0 + decoder_layer.post_attention_layernorm.casting_mode = "llama" + decoder_layer.post_attention_layernorm.forward = LigerRMSNorm.forward + decoder_layer.post_attention_layernorm.extra_repr = LigerRMSNorm.extra_repr def apply_liger_kernel_to_mistral( From 211702db4430b784bd6acd17b8ed6196fb0a7460 Mon Sep 17 00:00:00 2001 From: Steven Shimizu Date: Fri, 27 Sep 2024 22:59:28 +0000 Subject: [PATCH 02/11] Patched except for qwen2_vl --- src/liger_kernel/transformers/monkey_patch.py | 160 +++++------------- 1 file changed, 45 insertions(+), 115 deletions(-) diff --git a/src/liger_kernel/transformers/monkey_patch.py b/src/liger_kernel/transformers/monkey_patch.py index 8996508de..9203b38d7 100644 --- a/src/liger_kernel/transformers/monkey_patch.py +++ b/src/liger_kernel/transformers/monkey_patch.py @@ -24,6 +24,12 @@ logger = logging.getLogger(__name__) +def _patch_rms_norm_layer(norm_layer, offset=0.0, casting_mode="llama"): + norm_layer.offset = offset + norm_layer.casting_mode = casting_mode + norm_layer.forward = LigerRMSNorm.forward + norm_layer.extra_repr = LigerRMSNorm.extra_repr + def apply_liger_kernel_to_llama( rope: bool = True, @@ -69,7 +75,6 @@ def apply_liger_kernel_to_llama( if model is not None: # The model instance already exists, so we need to additionally patch the # instance variables that reference already-instantiated modules (e.g. 
LlamaRMSNorm or LlamaMLP) - config: PretrainedConfig = model.config if hasattr(model, "model"): # The case for LlamaForCausalLM or LlamaForSequenceClassification, for example @@ -82,24 +87,14 @@ def apply_liger_kernel_to_llama( base_model = model if rms_norm: - base_model.offset = 0.0 - base_model.casting_mode = "llama" - base_model.norm.forward = LigerRMSNorm.forward - base_model.norm.extra_repr = LigerRMSNorm.extra_repr + _patch_rms_norm_layer(base_model.norm) for decoder_layer in base_model.layers: if swiglu: decoder_layer.mlp.forward = LigerSwiGLUMLP.forward if rms_norm: - decoder_layer.input_layernorm.offset = 0.0 - decoder_layer.input_layernorm.casting_mode = "llama" - decoder_layer.input_layernorm.forward = LigerRMSNorm.forward - decoder_layer.input_layernorm.extra_repr = LigerRMSNorm.extra_repr - - decoder_layer.post_attention_layernorm.offset = 0.0 - decoder_layer.post_attention_layernorm.casting_mode = "llama" - decoder_layer.post_attention_layernorm.forward = LigerRMSNorm.forward - decoder_layer.post_attention_layernorm.extra_repr = LigerRMSNorm.extra_repr + _patch_rms_norm_layer(decoder_layer.input_layernorm) + _patch_rms_norm_layer(decoder_layer.post_attention_layernorm) def apply_liger_kernel_to_mistral( @@ -146,7 +141,6 @@ def apply_liger_kernel_to_mistral( if model is not None: # The model instance already exists, so we need to additionally patch the # instance variables that reference already-instantiated modules - config: PretrainedConfig = model.config if hasattr(model, "model"): # The case for MistralForCausalLM, MistralForTokenClassification for example @@ -155,22 +149,15 @@ def apply_liger_kernel_to_mistral( # Direct MistralModel base_model = model - torch_dtype = config.torch_dtype if rms_norm: - base_model.norm = LigerRMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ).to(torch_dtype) + _patch_rms_norm_layer(base_model.norm) for decoder_layer in base_model.layers: if swiglu: - decoder_layer.mlp = LigerSwiGLUMLP(config).to(torch_dtype) + decoder_layer.mlp.forward = LigerSwiGLUMLP.forward if rms_norm: - decoder_layer.input_layernorm = LigerRMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ).to(torch_dtype) - decoder_layer.post_attention_layernorm = LigerRMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ).to(torch_dtype) + _patch_rms_norm_layer(decoder_layer.input_layernorm) + _patch_rms_norm_layer(decoder_layer.post_attention_layernorm) def apply_liger_kernel_to_mixtral( @@ -217,7 +204,6 @@ def apply_liger_kernel_to_mixtral( if model is not None: # The model instance already exists, so we need to additionally patch the # instance variables that reference already-instantiated modules - config: PretrainedConfig = model.config if hasattr(model, "model"): # The case for MixtralForCausalLM, MixtralForTokenClassification for example @@ -226,29 +212,16 @@ def apply_liger_kernel_to_mixtral( # Direct MixtralModel base_model = model - torch_dtype = config.torch_dtype if rms_norm: - base_model.norm = LigerRMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ).to(torch_dtype) + _patch_rms_norm_layer(base_model.norm) for decoder_layer in base_model.layers: if swiglu: - block_sparse_moe = decoder_layer.block_sparse_moe - patched_experts = nn.ModuleList( - [ - LigerBlockSparseTop2MLP(config) - for _ in range(block_sparse_moe.num_experts) - ] - ) - decoder_layer.block_sparse_moe.experts = patched_experts.to(torch_dtype) + for expert in decoder_layer.block_sparse_moe.experts: + expert.forward = LigerBlockSparseTop2MLP.forward if rms_norm: - 
decoder_layer.input_layernorm = LigerRMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ).to(torch_dtype) - decoder_layer.post_attention_layernorm = LigerRMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ).to(torch_dtype) + _patch_rms_norm_layer(decoder_layer.input_layernorm) + _patch_rms_norm_layer(decoder_layer.post_attention_layernorm) def apply_liger_kernel_to_gemma( @@ -285,6 +258,7 @@ def apply_liger_kernel_to_gemma( LigerRMSNormForGemma = partial( LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma" ) + _patch_rms_norm_layer_for_gemma = partial(_patch_rms_norm_layer, casting_mode="gemma", offset=1.0) if rope: modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb @@ -300,7 +274,6 @@ def apply_liger_kernel_to_gemma( if model is not None: # The model instance already exists, so we need to additionally patch the # instance variables that reference already-instantiated modules - config: PretrainedConfig = model.config if hasattr(model, "model"): # The case for GemmaForCausalLM, GemmaForTokenClassification for example @@ -309,22 +282,15 @@ def apply_liger_kernel_to_gemma( # Direct GemmaModel base_model = model - torch_dtype = config.torch_dtype if rms_norm: - base_model.norm = LigerRMSNormForGemma( - config.hidden_size, eps=config.rms_norm_eps - ).to(torch_dtype) + _patch_rms_norm_layer_for_gemma(base_model.norm) for decoder_layer in base_model.layers: if geglu: - decoder_layer.mlp = LigerGEGLUMLP(config).to(torch_dtype) + decoder_layer.mlp.forward = LigerGEGLUMLP.forward if rms_norm: - decoder_layer.input_layernorm = LigerRMSNormForGemma( - config.hidden_size, eps=config.rms_norm_eps - ).to(torch_dtype) - decoder_layer.post_attention_layernorm = LigerRMSNormForGemma( - config.hidden_size, eps=config.rms_norm_eps - ).to(torch_dtype) + _patch_rms_norm_layer_for_gemma(decoder_layer.input_layernorm) + _patch_rms_norm_layer_for_gemma(decoder_layer.post_attention_layernorm) def apply_liger_kernel_to_gemma2( @@ -348,7 +314,9 @@ def apply_liger_kernel_to_gemma2( """ from transformers.models.gemma2 import modeling_gemma2 - LigerRMSNormForGemma2 = partial(LigerRMSNorm, offset=1.0, init_fn="zeros") + LigerRMSNormForGemma2 = partial(LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros") + _patch_rms_norm_layer_for_gemma2 = partial(_patch_rms_norm_layer, offset=1.0, casting_mode="gemma") + if rope: modeling_gemma2.apply_rotary_pos_emb = liger_rotary_pos_emb if rms_norm: @@ -362,7 +330,6 @@ def apply_liger_kernel_to_gemma2( if model is not None: # The model instance already exists, so we need to additionally patch the # instance variables that reference already-instantiated modules - config: PretrainedConfig = model.config if hasattr(model, "model"): # The case for Gemma2ForCausalLM, Gemma2ForTokenClassification for example @@ -371,28 +338,17 @@ def apply_liger_kernel_to_gemma2( # Direct Gemma2Model base_model = model - torch_dtype = config.torch_dtype if rms_norm: - base_model.norm = LigerRMSNormForGemma2( - config.hidden_size, eps=config.rms_norm_eps - ).to(torch_dtype) + _patch_rms_norm_layer_for_gemma2(base_model.norm) for decoder_layer in base_model.layers: if geglu: - decoder_layer.mlp = LigerGEGLUMLP(config).to(torch_dtype) + decoder_layer.mlp.forward = LigerGEGLUMLP.forward if rms_norm: - decoder_layer.input_layernorm = LigerRMSNormForGemma2( - config.hidden_size, eps=config.rms_norm_eps - ).to(torch_dtype) - decoder_layer.post_attention_layernorm = LigerRMSNormForGemma2( - config.hidden_size, eps=config.rms_norm_eps - ).to(torch_dtype) - 
decoder_layer.pre_feedforward_layernorm = LigerRMSNormForGemma2( - config.hidden_size, eps=config.rms_norm_eps - ).to(torch_dtype) - decoder_layer.post_feedforward_layernorm = LigerRMSNormForGemma2( - config.hidden_size, eps=config.rms_norm_eps - ).to(torch_dtype) + _patch_rms_norm_layer_for_gemma2(decoder_layer.input_layernorm) + _patch_rms_norm_layer_for_gemma2(decoder_layer.post_attention_layernorm) + _patch_rms_norm_layer_for_gemma2(decoder_layer.pre_feedforward_layernorm) + _patch_rms_norm_layer_for_gemma2(decoder_layer.post_feedforward_layernorm) def apply_liger_kernel_to_qwen2( @@ -438,7 +394,6 @@ def apply_liger_kernel_to_qwen2( if model is not None: # The model instance already exists, so we need to additionally patch the # instance variables that reference already-instantiated modules - config: PretrainedConfig = model.config if hasattr(model, "model"): # The case for Qwen2ForCausalLM, Qwen2ForTokenClassification for example @@ -447,22 +402,15 @@ def apply_liger_kernel_to_qwen2( # Direct Qwen2Model base_model = model - torch_dtype = config.torch_dtype if rms_norm: - base_model.norm = LigerRMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ).to(torch_dtype) + _patch_rms_norm_layer(base_model.norm) for decoder_layer in base_model.layers: if swiglu: - decoder_layer.mlp = LigerSwiGLUMLP(config).to(torch_dtype) + decoder_layer.mlp.forward = LigerSwiGLUMLP.forward if rms_norm: - decoder_layer.input_layernorm = LigerRMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ).to(torch_dtype) - decoder_layer.post_attention_layernorm = LigerRMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ).to(torch_dtype) + _patch_rms_norm_layer(decoder_layer.input_layernorm) + _patch_rms_norm_layer(decoder_layer.post_attention_layernorm) def apply_liger_kernel_to_qwen2_vl( @@ -501,10 +449,9 @@ def apply_liger_kernel_to_qwen2_vl( # TODO: Support Qwen2-VL's multimodal RoPE implementation - LigerRMSNormForQwen2VL = partial(LigerRMSNorm, init_fn="ones", casting_mode="gemma") if rms_norm: # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L439 - modeling_qwen2_vl.Qwen2RMSNorm = LigerRMSNormForQwen2VL + modeling_qwen2_vl.Qwen2RMSNorm = LigerRMSNorm if layer_norm: modeling_qwen2_vl.LayerNorm = LigerLayerNorm if cross_entropy: @@ -517,9 +464,6 @@ def apply_liger_kernel_to_qwen2_vl( if model is not None: # The model instance already exists, so we need to additionally patch the # instance variables that reference already-instantiated modules - config: PretrainedConfig = model.config - - torch_dtype = config.torch_dtype if hasattr(model, "model"): # The case for Qwen2VLForConditionalGeneration. 
@@ -540,19 +484,13 @@ def apply_liger_kernel_to_qwen2_vl( ) if rms_norm: - base_model.norm = LigerRMSNormForQwen2VL( - config.hidden_size, eps=config.rms_norm_eps - ).to(torch_dtype) + _patch_rms_norm_layer(base_model.norm) for decoder_layer in base_model.layers: if swiglu: - decoder_layer.mlp = LigerSwiGLUMLP(config).to(torch_dtype) + decoder_layer.mlp.forward = LigerSwiGLUMLP.forward if rms_norm: - decoder_layer.input_layernorm = LigerRMSNormForQwen2VL( - config.hidden_size, eps=config.rms_norm_eps - ).to(torch_dtype) - decoder_layer.post_attention_layernorm = LigerRMSNormForQwen2VL( - config.hidden_size, eps=config.rms_norm_eps - ).to(torch_dtype) + _patch_rms_norm_layer(decoder_layer.input_layernorm) + _patch_rms_norm_layer(decoder_layer.post_attention_layernorm) def apply_liger_kernel_to_phi3( @@ -598,7 +536,6 @@ def apply_liger_kernel_to_phi3( if model is not None: # The model instance already exists, so we need to additionally patch the # instance variables that reference already-instantiated modules - config: PretrainedConfig = model.config if hasattr(model, "model"): # The case for Phi3ForCausalLM, Phi3ForTokenClassification for example @@ -607,22 +544,15 @@ def apply_liger_kernel_to_phi3( # Direct Phi3Model base_model = model - torch_dtype = config.torch_dtype if rms_norm: - base_model.norm = LigerRMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ).to(torch_dtype) + _patch_rms_norm_layer(base_model.norm) for decoder_layer in base_model.layers: if swiglu: - decoder_layer.mlp = LigerPhi3SwiGLUMLP(config).to(torch_dtype) + decoder_layer.mlp.forward = LigerPhi3SwiGLUMLP.forward if rms_norm: - decoder_layer.input_layernorm = LigerRMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ).to(torch_dtype) - decoder_layer.post_attention_layernorm = LigerRMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ).to(torch_dtype) + _patch_rms_norm_layer(decoder_layer.input_layernorm) + _patch_rms_norm_layer(decoder_layer.post_attention_layernorm) # Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py From 0ac4b36800491fb3fcc1127ae946a8c3aa97cb67 Mon Sep 17 00:00:00 2001 From: Steven Shimizu Date: Sat, 28 Sep 2024 07:50:11 +0000 Subject: [PATCH 03/11] Fixed all models --- src/liger_kernel/transformers/monkey_patch.py | 81 ++++++++++--------- 1 file changed, 42 insertions(+), 39 deletions(-) diff --git a/src/liger_kernel/transformers/monkey_patch.py b/src/liger_kernel/transformers/monkey_patch.py index 9203b38d7..f9d3a7eed 100644 --- a/src/liger_kernel/transformers/monkey_patch.py +++ b/src/liger_kernel/transformers/monkey_patch.py @@ -24,11 +24,18 @@ logger = logging.getLogger(__name__) -def _patch_rms_norm_layer(norm_layer, offset=0.0, casting_mode="llama"): - norm_layer.offset = offset - norm_layer.casting_mode = casting_mode - norm_layer.forward = LigerRMSNorm.forward - norm_layer.extra_repr = LigerRMSNorm.extra_repr +def _patch_rms_norm_module(module, offset=0.0, casting_mode="llama"): + module.offset = offset + module.casting_mode = casting_mode + module.forward = LigerRMSNorm.forward + module.extra_repr = LigerRMSNorm.extra_repr + +def _patch_layer_norm_module(module, eps=1e-6): + module.eps = eps + module.variance_epsilon = eps + module.hidden_size = module.normalized_shape + module.forward = LigerLayerNorm.forward + module.extra_repr = LigerLayerNorm.extra_repr def apply_liger_kernel_to_llama( @@ -87,14 +94,14 @@ def apply_liger_kernel_to_llama( base_model = model if rms_norm: - _patch_rms_norm_layer(base_model.norm) + 
_patch_rms_norm_module(base_model.norm) for decoder_layer in base_model.layers: if swiglu: decoder_layer.mlp.forward = LigerSwiGLUMLP.forward if rms_norm: - _patch_rms_norm_layer(decoder_layer.input_layernorm) - _patch_rms_norm_layer(decoder_layer.post_attention_layernorm) + _patch_rms_norm_module(decoder_layer.input_layernorm) + _patch_rms_norm_module(decoder_layer.post_attention_layernorm) def apply_liger_kernel_to_mistral( @@ -150,14 +157,14 @@ def apply_liger_kernel_to_mistral( base_model = model if rms_norm: - _patch_rms_norm_layer(base_model.norm) + _patch_rms_norm_module(base_model.norm) for decoder_layer in base_model.layers: if swiglu: decoder_layer.mlp.forward = LigerSwiGLUMLP.forward if rms_norm: - _patch_rms_norm_layer(decoder_layer.input_layernorm) - _patch_rms_norm_layer(decoder_layer.post_attention_layernorm) + _patch_rms_norm_module(decoder_layer.input_layernorm) + _patch_rms_norm_module(decoder_layer.post_attention_layernorm) def apply_liger_kernel_to_mixtral( @@ -213,15 +220,15 @@ def apply_liger_kernel_to_mixtral( base_model = model if rms_norm: - _patch_rms_norm_layer(base_model.norm) + _patch_rms_norm_module(base_model.norm) for decoder_layer in base_model.layers: if swiglu: for expert in decoder_layer.block_sparse_moe.experts: expert.forward = LigerBlockSparseTop2MLP.forward if rms_norm: - _patch_rms_norm_layer(decoder_layer.input_layernorm) - _patch_rms_norm_layer(decoder_layer.post_attention_layernorm) + _patch_rms_norm_module(decoder_layer.input_layernorm) + _patch_rms_norm_module(decoder_layer.post_attention_layernorm) def apply_liger_kernel_to_gemma( @@ -258,7 +265,7 @@ def apply_liger_kernel_to_gemma( LigerRMSNormForGemma = partial( LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma" ) - _patch_rms_norm_layer_for_gemma = partial(_patch_rms_norm_layer, casting_mode="gemma", offset=1.0) + _patch_rms_norm_module_for_gemma = partial(_patch_rms_norm_module, casting_mode="gemma", offset=1.0) if rope: modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb @@ -283,14 +290,14 @@ def apply_liger_kernel_to_gemma( base_model = model if rms_norm: - _patch_rms_norm_layer_for_gemma(base_model.norm) + _patch_rms_norm_module_for_gemma(base_model.norm) for decoder_layer in base_model.layers: if geglu: decoder_layer.mlp.forward = LigerGEGLUMLP.forward if rms_norm: - _patch_rms_norm_layer_for_gemma(decoder_layer.input_layernorm) - _patch_rms_norm_layer_for_gemma(decoder_layer.post_attention_layernorm) + _patch_rms_norm_module_for_gemma(decoder_layer.input_layernorm) + _patch_rms_norm_module_for_gemma(decoder_layer.post_attention_layernorm) def apply_liger_kernel_to_gemma2( @@ -315,7 +322,7 @@ def apply_liger_kernel_to_gemma2( from transformers.models.gemma2 import modeling_gemma2 LigerRMSNormForGemma2 = partial(LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros") - _patch_rms_norm_layer_for_gemma2 = partial(_patch_rms_norm_layer, offset=1.0, casting_mode="gemma") + _patch_rms_norm_module_for_gemma2 = partial(_patch_rms_norm_module, offset=1.0, casting_mode="gemma") if rope: modeling_gemma2.apply_rotary_pos_emb = liger_rotary_pos_emb @@ -339,16 +346,16 @@ def apply_liger_kernel_to_gemma2( base_model = model if rms_norm: - _patch_rms_norm_layer_for_gemma2(base_model.norm) + _patch_rms_norm_module_for_gemma2(base_model.norm) for decoder_layer in base_model.layers: if geglu: decoder_layer.mlp.forward = LigerGEGLUMLP.forward if rms_norm: - _patch_rms_norm_layer_for_gemma2(decoder_layer.input_layernorm) - 
_patch_rms_norm_layer_for_gemma2(decoder_layer.post_attention_layernorm) - _patch_rms_norm_layer_for_gemma2(decoder_layer.pre_feedforward_layernorm) - _patch_rms_norm_layer_for_gemma2(decoder_layer.post_feedforward_layernorm) + _patch_rms_norm_module_for_gemma2(decoder_layer.input_layernorm) + _patch_rms_norm_module_for_gemma2(decoder_layer.post_attention_layernorm) + _patch_rms_norm_module_for_gemma2(decoder_layer.pre_feedforward_layernorm) + _patch_rms_norm_module_for_gemma2(decoder_layer.post_feedforward_layernorm) def apply_liger_kernel_to_qwen2( @@ -403,14 +410,14 @@ def apply_liger_kernel_to_qwen2( base_model = model if rms_norm: - _patch_rms_norm_layer(base_model.norm) + _patch_rms_norm_module(base_model.norm) for decoder_layer in base_model.layers: if swiglu: decoder_layer.mlp.forward = LigerSwiGLUMLP.forward if rms_norm: - _patch_rms_norm_layer(decoder_layer.input_layernorm) - _patch_rms_norm_layer(decoder_layer.post_attention_layernorm) + _patch_rms_norm_module(decoder_layer.input_layernorm) + _patch_rms_norm_module(decoder_layer.post_attention_layernorm) def apply_liger_kernel_to_qwen2_vl( @@ -476,21 +483,17 @@ def apply_liger_kernel_to_qwen2_vl( # Patch Qwen2VisionTransformerPretrainedModel for vision_block in model.visual.blocks: if layer_norm: - vision_block.norm1 = LigerLayerNorm(config.embed_dim, eps=1e-6).to( - torch_dtype - ) - vision_block.norm2 = LigerLayerNorm(config.embed_dim, eps=1e-6).to( - torch_dtype - ) + _patch_layer_norm_module(vision_block.norm1) + _patch_layer_norm_module(vision_block.norm2) if rms_norm: - _patch_rms_norm_layer(base_model.norm) + _patch_rms_norm_module(base_model.norm) for decoder_layer in base_model.layers: if swiglu: decoder_layer.mlp.forward = LigerSwiGLUMLP.forward if rms_norm: - _patch_rms_norm_layer(decoder_layer.input_layernorm) - _patch_rms_norm_layer(decoder_layer.post_attention_layernorm) + _patch_rms_norm_module(decoder_layer.input_layernorm) + _patch_rms_norm_module(decoder_layer.post_attention_layernorm) def apply_liger_kernel_to_phi3( @@ -545,14 +548,14 @@ def apply_liger_kernel_to_phi3( base_model = model if rms_norm: - _patch_rms_norm_layer(base_model.norm) + _patch_rms_norm_module(base_model.norm) for decoder_layer in base_model.layers: if swiglu: decoder_layer.mlp.forward = LigerPhi3SwiGLUMLP.forward if rms_norm: - _patch_rms_norm_layer(decoder_layer.input_layernorm) - _patch_rms_norm_layer(decoder_layer.post_attention_layernorm) + _patch_rms_norm_module(decoder_layer.input_layernorm) + _patch_rms_norm_module(decoder_layer.post_attention_layernorm) # Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py From 68a95b7f236f6d5017c0ac2c42cbabeaa3058c05 Mon Sep 17 00:00:00 2001 From: Steven Shimizu Date: Sat, 28 Sep 2024 08:18:27 +0000 Subject: [PATCH 04/11] Fixed monkey patch tests --- test/transformers/test_monkey_patch.py | 120 ++++++++++++------------- 1 file changed, 60 insertions(+), 60 deletions(-) diff --git a/test/transformers/test_monkey_patch.py b/test/transformers/test_monkey_patch.py index bdb6ee11e..88710cbd6 100644 --- a/test/transformers/test_monkey_patch.py +++ b/test/transformers/test_monkey_patch.py @@ -212,21 +212,21 @@ def test_apply_liger_kernel_to_instance_for_llama(): dummy_model_instance = AutoModelForCausalLM.from_config(config) # Check that model instance variables are not yet patched with Liger modules - assert not isinstance(dummy_model_instance.model.norm, LigerRMSNorm) + assert not dummy_model_instance.model.norm.forward == LigerRMSNorm.forward for layer in 
dummy_model_instance.model.layers: - assert not isinstance(layer.mlp, LigerSwiGLUMLP) - assert not isinstance(layer.input_layernorm, LigerRMSNorm) - assert not isinstance(layer.post_attention_layernorm, LigerRMSNorm) + assert not layer.mlp.forward == LigerSwiGLUMLP.forward + assert not layer.input_layernorm.forward == LigerRMSNorm.forward + assert not layer.post_attention_layernorm.forward == LigerRMSNorm.forward # Test applying kernels to the model instance _apply_liger_kernel_to_instance(model=dummy_model_instance) # Check that the model's instance variables were correctly patched with Liger modules - assert isinstance(dummy_model_instance.model.norm, LigerRMSNorm) + assert dummy_model_instance.model.norm.forward == LigerRMSNorm.forward for layer in dummy_model_instance.model.layers: - assert isinstance(layer.mlp, LigerSwiGLUMLP) - assert isinstance(layer.input_layernorm, LigerRMSNorm) - assert isinstance(layer.post_attention_layernorm, LigerRMSNorm) + assert layer.mlp.forward == LigerSwiGLUMLP.forward + assert layer.input_layernorm.forward == LigerRMSNorm.forward + assert layer.post_attention_layernorm.forward == LigerRMSNorm.forward def test_apply_liger_kernel_to_instance_for_mistral(): @@ -245,21 +245,21 @@ def test_apply_liger_kernel_to_instance_for_mistral(): dummy_model_instance = AutoModelForCausalLM.from_config(config) # Check that model instance variables are not yet patched with Liger modules - assert not isinstance(dummy_model_instance.model.norm, LigerRMSNorm) + assert not dummy_model_instance.model.norm.forward == LigerRMSNorm.forward for layer in dummy_model_instance.model.layers: - assert not isinstance(layer.mlp, LigerSwiGLUMLP) - assert not isinstance(layer.input_layernorm, LigerRMSNorm) - assert not isinstance(layer.post_attention_layernorm, LigerRMSNorm) + assert not layer.mlp.forward == LigerSwiGLUMLP.forward + assert not layer.input_layernorm.forward == LigerRMSNorm.forward + assert not layer.post_attention_layernorm.forward == LigerRMSNorm.forward # Test applying kernels to the model instance _apply_liger_kernel_to_instance(model=dummy_model_instance) # Check that the model's instance variables were correctly patched with Liger modules - assert isinstance(dummy_model_instance.model.norm, LigerRMSNorm) + assert dummy_model_instance.model.norm.forward == LigerRMSNorm.forward for layer in dummy_model_instance.model.layers: - assert isinstance(layer.mlp, LigerSwiGLUMLP) - assert isinstance(layer.input_layernorm, LigerRMSNorm) - assert isinstance(layer.post_attention_layernorm, LigerRMSNorm) + assert layer.mlp.forward == LigerSwiGLUMLP.forward + assert layer.input_layernorm.forward == LigerRMSNorm.forward + assert layer.post_attention_layernorm.forward == LigerRMSNorm.forward def test_apply_liger_kernel_to_instance_for_mixtral(): @@ -280,23 +280,23 @@ def test_apply_liger_kernel_to_instance_for_mixtral(): dummy_model_instance = AutoModelForCausalLM.from_config(config) # Check that model instance variables are not yet patched with Liger modules - assert not isinstance(dummy_model_instance.model.norm, LigerRMSNorm) + assert not dummy_model_instance.model.norm.forward == LigerRMSNorm.forward for layer in dummy_model_instance.model.layers: for expert in layer.block_sparse_moe.experts: - assert not isinstance(expert, LigerBlockSparseTop2MLP) - assert not isinstance(layer.input_layernorm, LigerRMSNorm) - assert not isinstance(layer.post_attention_layernorm, LigerRMSNorm) + assert not expert.forward == LigerBlockSparseTop2MLP.forward + assert not layer.input_layernorm.forward 
== LigerRMSNorm.forward + assert not layer.post_attention_layernorm.forward == LigerRMSNorm.forward # Test applying kernels to the model instance _apply_liger_kernel_to_instance(model=dummy_model_instance) # Check that the model's instance variables were correctly patched with Liger modules - assert isinstance(dummy_model_instance.model.norm, LigerRMSNorm) + assert dummy_model_instance.model.norm.forward == LigerRMSNorm.forward for layer in dummy_model_instance.model.layers: for expert in layer.block_sparse_moe.experts: - assert isinstance(expert, LigerBlockSparseTop2MLP) - assert isinstance(layer.input_layernorm, LigerRMSNorm) - assert isinstance(layer.post_attention_layernorm, LigerRMSNorm) + assert expert.forward == LigerBlockSparseTop2MLP.forward + assert layer.input_layernorm.forward == LigerRMSNorm.forward + assert layer.post_attention_layernorm.forward == LigerRMSNorm.forward def test_apply_liger_kernel_to_instance_for_gemma(): @@ -315,21 +315,21 @@ def test_apply_liger_kernel_to_instance_for_gemma(): dummy_model_instance = AutoModelForCausalLM.from_config(config) # Check that model instance variables are not yet patched with Liger modules - assert not isinstance(dummy_model_instance.model.norm, LigerRMSNorm) + assert not dummy_model_instance.model.norm.forward == LigerRMSNorm.forward for layer in dummy_model_instance.model.layers: - assert not isinstance(layer.mlp, LigerGEGLUMLP) - assert not isinstance(layer.input_layernorm, LigerRMSNorm) - assert not isinstance(layer.post_attention_layernorm, LigerRMSNorm) + assert not layer.mlp.forward == LigerGEGLUMLP.forward + assert not layer.input_layernorm.forward == LigerRMSNorm.forward + assert not layer.post_attention_layernorm.forward == LigerRMSNorm.forward # Test applying kernels to the model instance _apply_liger_kernel_to_instance(model=dummy_model_instance) # Check that the model's instance variables were correctly patched with Liger modules - assert isinstance(dummy_model_instance.model.norm, LigerRMSNorm) + assert dummy_model_instance.model.norm.forward == LigerRMSNorm.forward for layer in dummy_model_instance.model.layers: - assert isinstance(layer.mlp, LigerGEGLUMLP) - assert isinstance(layer.input_layernorm, LigerRMSNorm) - assert isinstance(layer.post_attention_layernorm, LigerRMSNorm) + assert layer.mlp.forward == LigerGEGLUMLP.forward + assert layer.input_layernorm.forward == LigerRMSNorm.forward + assert layer.post_attention_layernorm.forward == LigerRMSNorm.forward def test_apply_liger_kernel_to_instance_for_gemma2(): @@ -348,25 +348,25 @@ def test_apply_liger_kernel_to_instance_for_gemma2(): dummy_model_instance = AutoModelForCausalLM.from_config(config) # Check that model instance variables are not yet patched with Liger modules - assert not isinstance(dummy_model_instance.model.norm, LigerRMSNorm) + assert not dummy_model_instance.model.norm.forward == LigerRMSNorm.forward for layer in dummy_model_instance.model.layers: - assert not isinstance(layer.mlp, LigerGEGLUMLP) - assert not isinstance(layer.input_layernorm, LigerRMSNorm) - assert not isinstance(layer.post_attention_layernorm, LigerRMSNorm) - assert not isinstance(layer.pre_feedforward_layernorm, LigerRMSNorm) - assert not isinstance(layer.post_feedforward_layernorm, LigerRMSNorm) + assert not layer.mlp.forward == LigerGEGLUMLP.forward + assert not layer.input_layernorm.forward == LigerRMSNorm.forward + assert not layer.post_attention_layernorm.forward == LigerRMSNorm.forward + assert not layer.pre_feedforward_layernorm.forward == LigerRMSNorm.forward + assert 
not layer.post_feedforward_layernorm.forward == LigerRMSNorm.forward # Test applying kernels to the model instance _apply_liger_kernel_to_instance(model=dummy_model_instance) # Check that the model's instance variables were correctly patched with Liger modules - assert isinstance(dummy_model_instance.model.norm, LigerRMSNorm) + assert dummy_model_instance.model.norm.forward == LigerRMSNorm.forward for layer in dummy_model_instance.model.layers: - assert isinstance(layer.mlp, LigerGEGLUMLP) - assert isinstance(layer.input_layernorm, LigerRMSNorm) - assert isinstance(layer.post_attention_layernorm, LigerRMSNorm) - assert isinstance(layer.pre_feedforward_layernorm, LigerRMSNorm) - assert isinstance(layer.post_feedforward_layernorm, LigerRMSNorm) + assert layer.mlp.forward == LigerGEGLUMLP.forward + assert layer.input_layernorm.forward == LigerRMSNorm.forward + assert layer.post_attention_layernorm.forward == LigerRMSNorm.forward + assert layer.pre_feedforward_layernorm.forward == LigerRMSNorm.forward + assert layer.post_feedforward_layernorm.forward == LigerRMSNorm.forward def test_apply_liger_kernel_to_instance_for_qwen2(): @@ -385,21 +385,21 @@ def test_apply_liger_kernel_to_instance_for_qwen2(): dummy_model_instance = AutoModelForCausalLM.from_config(config) # Check that model instance variables are not yet patched with Liger modules - assert not isinstance(dummy_model_instance.model.norm, LigerRMSNorm) + assert not dummy_model_instance.model.norm.forward == LigerRMSNorm.forward for layer in dummy_model_instance.model.layers: - assert not isinstance(layer.mlp, LigerSwiGLUMLP) - assert not isinstance(layer.input_layernorm, LigerRMSNorm) - assert not isinstance(layer.post_attention_layernorm, LigerRMSNorm) + assert not layer.mlp.forward == LigerSwiGLUMLP.forward + assert not layer.input_layernorm.forward == LigerRMSNorm.forward + assert not layer.post_attention_layernorm.forward == LigerRMSNorm.forward # Test applying kernels to the model instance _apply_liger_kernel_to_instance(model=dummy_model_instance) # Check that the model's instance variables were correctly patched with Liger modules - assert isinstance(dummy_model_instance.model.norm, LigerRMSNorm) + assert dummy_model_instance.model.norm.forward == LigerRMSNorm.forward for layer in dummy_model_instance.model.layers: - assert isinstance(layer.mlp, LigerSwiGLUMLP) - assert isinstance(layer.input_layernorm, LigerRMSNorm) - assert isinstance(layer.post_attention_layernorm, LigerRMSNorm) + assert layer.mlp.forward == LigerSwiGLUMLP.forward + assert layer.input_layernorm.forward == LigerRMSNorm.forward + assert layer.post_attention_layernorm.forward == LigerRMSNorm.forward def test_apply_liger_kernel_to_instance_for_phi3(): @@ -418,18 +418,18 @@ def test_apply_liger_kernel_to_instance_for_phi3(): dummy_model_instance = AutoModelForCausalLM.from_config(config) # Check that model instance variables are not yet patched with Liger modules - assert not isinstance(dummy_model_instance.model.norm, LigerRMSNorm) + assert not dummy_model_instance.model.norm.forward == LigerRMSNorm.forward for layer in dummy_model_instance.model.layers: - assert not isinstance(layer.mlp, LigerPhi3SwiGLUMLP) - assert not isinstance(layer.input_layernorm, LigerRMSNorm) - assert not isinstance(layer.post_attention_layernorm, LigerRMSNorm) + assert not layer.mlp.forward == LigerPhi3SwiGLUMLP.forward + assert not layer.input_layernorm.forward == LigerRMSNorm.forward + assert not layer.post_attention_layernorm.forward == LigerRMSNorm.forward # Test applying kernels to 
the model instance _apply_liger_kernel_to_instance(model=dummy_model_instance) # Check that the model's instance variables were correctly patched with Liger modules - assert isinstance(dummy_model_instance.model.norm, LigerRMSNorm) + assert dummy_model_instance.model.norm.forward == LigerRMSNorm.forward for layer in dummy_model_instance.model.layers: - assert isinstance(layer.mlp, LigerPhi3SwiGLUMLP) - assert isinstance(layer.input_layernorm, LigerRMSNorm) - assert isinstance(layer.post_attention_layernorm, LigerRMSNorm) + assert layer.mlp.forward == LigerPhi3SwiGLUMLP.forward + assert layer.input_layernorm.forward == LigerRMSNorm.forward + assert layer.post_attention_layernorm.forward == LigerRMSNorm.forward From c93cb2a24f2c71aecacd5673dafa75e3fa239dcf Mon Sep 17 00:00:00 2001 From: Steven Shimizu Date: Sat, 28 Sep 2024 16:10:27 +0000 Subject: [PATCH 05/11] Tested pre/post init in convergence tests --- test/convergence/test_mini_models.py | 60 +++++++++++++++++++++++----- 1 file changed, 51 insertions(+), 9 deletions(-) diff --git a/test/convergence/test_mini_models.py b/test/convergence/test_mini_models.py index f648a88c2..75d45f326 100644 --- a/test/convergence/test_mini_models.py +++ b/test/convergence/test_mini_models.py @@ -313,10 +313,11 @@ def run_mini_model( dtype=torch.bfloat16, lr=1e-5, with_liger=False, + post_init_patching=False, ): # If we move it to the beginning of test_mini_model, the two runs are initialized with different weights. # This is due to RNG (Random Number Generator). The formula of RNG progression is x_(n+1) = (a * x_n + c) % m - # Everytime RNG is used, like randomly initialzing weight, the RNG progresses to the next state. + # Everytime RNG is used, like randomly initializing weight, the RNG progresses to the next state. # Therefore, we have to reset RNG before we create the model to ensure the weight initialization started from the same RNG state. 
set_seed(42) @@ -331,11 +332,18 @@ def run_mini_model( kwargs["geglu"] = True else: kwargs["swiglu"] = True + + # Make sure any patches have been reverted before tests + MINI_MODEL_SETUPS[model_name].liger_kernel_patch_revert_func() + + if post_init_patching: + model = create_model(model_name).to(dtype).to("cuda") + kwargs["model"] = model MINI_MODEL_SETUPS[model_name].liger_kernel_patch_func(**kwargs) else: - MINI_MODEL_SETUPS[model_name].liger_kernel_patch_revert_func() + MINI_MODEL_SETUPS[model_name].liger_kernel_patch_func(**kwargs) + model = create_model(model_name).to(dtype).to("cuda") - model = create_model(model_name).to(dtype).to("cuda") train_dataset = load_from_disk(DEFAULT_DATASET_PATH) loader = DataLoader( @@ -483,6 +491,7 @@ def run_mini_model( ) def test_mini_model( model_name, + post_init_patching, num_steps, lr, dtype, @@ -496,17 +505,49 @@ def test_mini_model( # Non-liger models should be initialized and tested first to avoid the module being overridden expected_output = run_mini_model( - model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr + model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr, with_liger=False + ) + + actual_output_pre = run_mini_model( + model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr, with_liger=True, post_init_patching=False, + ) + + actual_output_post = run_mini_model( + model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr, with_liger=True, post_init_patching=True, + ) + + ### Pre-init patching + # Compare the loss of every step + assert_verbose_allclose( + torch.tensor([expected_output["loss"]]), + torch.tensor([actual_output_pre["loss"]]), + atol=loss_atol, + rtol=loss_rtol, ) - actual_output = run_mini_model( - model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr, with_liger=True + # Compare the logits from the last step + assert_verbose_allclose( + expected_output["logits"], + actual_output_pre["logits"], + atol=logits_atol, + rtol=logits_rtol, ) + # Compare the params from the last step + # Iterate over the model's parameters and compare them + for expected_param, actual_param in zip( + expected_output["model"].named_parameters(), + actual_output_post["model"].named_parameters(), + ): + assert_verbose_allclose( + expected_param[1], actual_param[1], atol=param_atol, rtol=param_rtol + ) + + ### Post-init patching # Compare the loss of every step assert_verbose_allclose( torch.tensor([expected_output["loss"]]), - torch.tensor([actual_output["loss"]]), + torch.tensor([actual_output_post["loss"]]), atol=loss_atol, rtol=loss_rtol, ) @@ -514,7 +555,7 @@ def test_mini_model( # Compare the logits from the last step assert_verbose_allclose( expected_output["logits"], - actual_output["logits"], + actual_output_post["logits"], atol=logits_atol, rtol=logits_rtol, ) @@ -523,8 +564,9 @@ def test_mini_model( # Iterate over the model's parameters and compare them for expected_param, actual_param in zip( expected_output["model"].named_parameters(), - actual_output["model"].named_parameters(), + actual_output_post["model"].named_parameters(), ): assert_verbose_allclose( expected_param[1], actual_param[1], atol=param_atol, rtol=param_rtol ) + From 0d5c19185b0449d44cc1dfa25cccf189492fe2e7 Mon Sep 17 00:00:00 2001 From: Steven Shimizu Date: Sat, 28 Sep 2024 22:27:07 +0000 Subject: [PATCH 06/11] Convergence tests --- src/liger_kernel/transformers/monkey_patch.py | 31 +++++++++++-------- test/convergence/test_mini_models.py | 19 ++++++------ 2 files changed, 28 insertions(+), 22 deletions(-) diff --git 
a/src/liger_kernel/transformers/monkey_patch.py b/src/liger_kernel/transformers/monkey_patch.py index f9d3a7eed..a721005d4 100644 --- a/src/liger_kernel/transformers/monkey_patch.py +++ b/src/liger_kernel/transformers/monkey_patch.py @@ -1,6 +1,7 @@ import inspect import logging from functools import partial +from typing import Callable from torch import nn from transformers import PretrainedConfig, PreTrainedModel @@ -24,18 +25,22 @@ logger = logging.getLogger(__name__) -def _patch_rms_norm_module(module, offset=0.0, casting_mode="llama"): +def _bind_method_to_module(module, method_name: str, new_method: Callable): + module.__dict__[method_name] = new_method.__get__(module, module.__class__) + +def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama"): module.offset = offset module.casting_mode = casting_mode - module.forward = LigerRMSNorm.forward - module.extra_repr = LigerRMSNorm.extra_repr + module.variance_epsilon = eps + _bind_method_to_module(module, "forward", LigerRMSNorm.forward) + _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr) def _patch_layer_norm_module(module, eps=1e-6): module.eps = eps module.variance_epsilon = eps module.hidden_size = module.normalized_shape - module.forward = LigerLayerNorm.forward - module.extra_repr = LigerLayerNorm.extra_repr + _bind_method_to_module(module, "forward", LigerLayerNorm.forward) + _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr) def apply_liger_kernel_to_llama( @@ -98,7 +103,7 @@ def apply_liger_kernel_to_llama( for decoder_layer in base_model.layers: if swiglu: - decoder_layer.mlp.forward = LigerSwiGLUMLP.forward + _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward) if rms_norm: _patch_rms_norm_module(decoder_layer.input_layernorm) _patch_rms_norm_module(decoder_layer.post_attention_layernorm) @@ -161,7 +166,7 @@ def apply_liger_kernel_to_mistral( for decoder_layer in base_model.layers: if swiglu: - decoder_layer.mlp.forward = LigerSwiGLUMLP.forward + _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward) if rms_norm: _patch_rms_norm_module(decoder_layer.input_layernorm) _patch_rms_norm_module(decoder_layer.post_attention_layernorm) @@ -225,7 +230,7 @@ def apply_liger_kernel_to_mixtral( for decoder_layer in base_model.layers: if swiglu: for expert in decoder_layer.block_sparse_moe.experts: - expert.forward = LigerBlockSparseTop2MLP.forward + _bind_method_to_module(expert, "forward", LigerBlockSparseTop2MLP.forward) if rms_norm: _patch_rms_norm_module(decoder_layer.input_layernorm) _patch_rms_norm_module(decoder_layer.post_attention_layernorm) @@ -294,7 +299,7 @@ def apply_liger_kernel_to_gemma( for decoder_layer in base_model.layers: if geglu: - decoder_layer.mlp.forward = LigerGEGLUMLP.forward + _bind_method_to_module(decoder_layer.mlp, "forward", LigerGEGLUMLP.forward) if rms_norm: _patch_rms_norm_module_for_gemma(decoder_layer.input_layernorm) _patch_rms_norm_module_for_gemma(decoder_layer.post_attention_layernorm) @@ -350,7 +355,7 @@ def apply_liger_kernel_to_gemma2( for decoder_layer in base_model.layers: if geglu: - decoder_layer.mlp.forward = LigerGEGLUMLP.forward + _bind_method_to_module(decoder_layer.mlp, "forward", LigerGEGLUMLP.forward) if rms_norm: _patch_rms_norm_module_for_gemma2(decoder_layer.input_layernorm) _patch_rms_norm_module_for_gemma2(decoder_layer.post_attention_layernorm) @@ -414,7 +419,7 @@ def apply_liger_kernel_to_qwen2( for decoder_layer in base_model.layers: if swiglu: - 
decoder_layer.mlp.forward = LigerSwiGLUMLP.forward + _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward) if rms_norm: _patch_rms_norm_module(decoder_layer.input_layernorm) _patch_rms_norm_module(decoder_layer.post_attention_layernorm) @@ -490,7 +495,7 @@ def apply_liger_kernel_to_qwen2_vl( _patch_rms_norm_module(base_model.norm) for decoder_layer in base_model.layers: if swiglu: - decoder_layer.mlp.forward = LigerSwiGLUMLP.forward + _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward) if rms_norm: _patch_rms_norm_module(decoder_layer.input_layernorm) _patch_rms_norm_module(decoder_layer.post_attention_layernorm) @@ -552,7 +557,7 @@ def apply_liger_kernel_to_phi3( for decoder_layer in base_model.layers: if swiglu: - decoder_layer.mlp.forward = LigerPhi3SwiGLUMLP.forward + _bind_method_to_module(decoder_layer.mlp, "forward", LigerPhi3SwiGLUMLP.forward) if rms_norm: _patch_rms_norm_module(decoder_layer.input_layernorm) _patch_rms_norm_module(decoder_layer.post_attention_layernorm) diff --git a/test/convergence/test_mini_models.py b/test/convergence/test_mini_models.py index 75d45f326..4aaa0cd6d 100644 --- a/test/convergence/test_mini_models.py +++ b/test/convergence/test_mini_models.py @@ -322,6 +322,9 @@ def run_mini_model( set_seed(42) + # Make sure any patches have been reverted before tests + MINI_MODEL_SETUPS[model_name].liger_kernel_patch_revert_func() + if with_liger is True: kwargs = { "rope": True, @@ -333,15 +336,14 @@ def run_mini_model( else: kwargs["swiglu"] = True - # Make sure any patches have been reverted before tests - MINI_MODEL_SETUPS[model_name].liger_kernel_patch_revert_func() - - if post_init_patching: - model = create_model(model_name).to(dtype).to("cuda") - kwargs["model"] = model - MINI_MODEL_SETUPS[model_name].liger_kernel_patch_func(**kwargs) + if post_init_patching: + model = create_model(model_name).to(dtype).to("cuda") + kwargs["model"] = model + MINI_MODEL_SETUPS[model_name].liger_kernel_patch_func(**kwargs) + else: + MINI_MODEL_SETUPS[model_name].liger_kernel_patch_func(**kwargs) + model = create_model(model_name).to(dtype).to("cuda") else: - MINI_MODEL_SETUPS[model_name].liger_kernel_patch_func(**kwargs) model = create_model(model_name).to(dtype).to("cuda") train_dataset = load_from_disk(DEFAULT_DATASET_PATH) @@ -491,7 +493,6 @@ def run_mini_model( ) def test_mini_model( model_name, - post_init_patching, num_steps, lr, dtype, From bbbd4dbf69f5c7bab043f1bbf80195c53405efe0 Mon Sep 17 00:00:00 2001 From: Steven Shimizu Date: Sat, 28 Sep 2024 22:32:48 +0000 Subject: [PATCH 07/11] Fixed checkstyle --- src/liger_kernel/transformers/monkey_patch.py | 62 ++++++++++++++----- test/convergence/test_mini_models.py | 25 +++++--- 2 files changed, 64 insertions(+), 23 deletions(-) diff --git a/src/liger_kernel/transformers/monkey_patch.py b/src/liger_kernel/transformers/monkey_patch.py index a721005d4..408eafa47 100644 --- a/src/liger_kernel/transformers/monkey_patch.py +++ b/src/liger_kernel/transformers/monkey_patch.py @@ -3,8 +3,7 @@ from functools import partial from typing import Callable -from torch import nn -from transformers import PretrainedConfig, PreTrainedModel +from transformers import PreTrainedModel from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss from liger_kernel.transformers.geglu import LigerGEGLUMLP @@ -25,9 +24,11 @@ logger = logging.getLogger(__name__) + def _bind_method_to_module(module, method_name: str, new_method: Callable): module.__dict__[method_name] = 
new_method.__get__(module, module.__class__) + def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama"): module.offset = offset module.casting_mode = casting_mode @@ -35,6 +36,7 @@ def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama"): _bind_method_to_module(module, "forward", LigerRMSNorm.forward) _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr) + def _patch_layer_norm_module(module, eps=1e-6): module.eps = eps module.variance_epsilon = eps @@ -103,7 +105,9 @@ def apply_liger_kernel_to_llama( for decoder_layer in base_model.layers: if swiglu: - _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward) + _bind_method_to_module( + decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward + ) if rms_norm: _patch_rms_norm_module(decoder_layer.input_layernorm) _patch_rms_norm_module(decoder_layer.post_attention_layernorm) @@ -166,7 +170,9 @@ def apply_liger_kernel_to_mistral( for decoder_layer in base_model.layers: if swiglu: - _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward) + _bind_method_to_module( + decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward + ) if rms_norm: _patch_rms_norm_module(decoder_layer.input_layernorm) _patch_rms_norm_module(decoder_layer.post_attention_layernorm) @@ -230,7 +236,9 @@ def apply_liger_kernel_to_mixtral( for decoder_layer in base_model.layers: if swiglu: for expert in decoder_layer.block_sparse_moe.experts: - _bind_method_to_module(expert, "forward", LigerBlockSparseTop2MLP.forward) + _bind_method_to_module( + expert, "forward", LigerBlockSparseTop2MLP.forward + ) if rms_norm: _patch_rms_norm_module(decoder_layer.input_layernorm) _patch_rms_norm_module(decoder_layer.post_attention_layernorm) @@ -270,7 +278,9 @@ def apply_liger_kernel_to_gemma( LigerRMSNormForGemma = partial( LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma" ) - _patch_rms_norm_module_for_gemma = partial(_patch_rms_norm_module, casting_mode="gemma", offset=1.0) + _patch_rms_norm_module_for_gemma = partial( + _patch_rms_norm_module, casting_mode="gemma", offset=1.0 + ) if rope: modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb @@ -299,7 +309,9 @@ def apply_liger_kernel_to_gemma( for decoder_layer in base_model.layers: if geglu: - _bind_method_to_module(decoder_layer.mlp, "forward", LigerGEGLUMLP.forward) + _bind_method_to_module( + decoder_layer.mlp, "forward", LigerGEGLUMLP.forward + ) if rms_norm: _patch_rms_norm_module_for_gemma(decoder_layer.input_layernorm) _patch_rms_norm_module_for_gemma(decoder_layer.post_attention_layernorm) @@ -326,8 +338,12 @@ def apply_liger_kernel_to_gemma2( """ from transformers.models.gemma2 import modeling_gemma2 - LigerRMSNormForGemma2 = partial(LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros") - _patch_rms_norm_module_for_gemma2 = partial(_patch_rms_norm_module, offset=1.0, casting_mode="gemma") + LigerRMSNormForGemma2 = partial( + LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros" + ) + _patch_rms_norm_module_for_gemma2 = partial( + _patch_rms_norm_module, offset=1.0, casting_mode="gemma" + ) if rope: modeling_gemma2.apply_rotary_pos_emb = liger_rotary_pos_emb @@ -355,12 +371,20 @@ def apply_liger_kernel_to_gemma2( for decoder_layer in base_model.layers: if geglu: - _bind_method_to_module(decoder_layer.mlp, "forward", LigerGEGLUMLP.forward) + _bind_method_to_module( + decoder_layer.mlp, "forward", LigerGEGLUMLP.forward + ) if rms_norm: _patch_rms_norm_module_for_gemma2(decoder_layer.input_layernorm) 
- _patch_rms_norm_module_for_gemma2(decoder_layer.post_attention_layernorm) - _patch_rms_norm_module_for_gemma2(decoder_layer.pre_feedforward_layernorm) - _patch_rms_norm_module_for_gemma2(decoder_layer.post_feedforward_layernorm) + _patch_rms_norm_module_for_gemma2( + decoder_layer.post_attention_layernorm + ) + _patch_rms_norm_module_for_gemma2( + decoder_layer.pre_feedforward_layernorm + ) + _patch_rms_norm_module_for_gemma2( + decoder_layer.post_feedforward_layernorm + ) def apply_liger_kernel_to_qwen2( @@ -419,7 +443,9 @@ def apply_liger_kernel_to_qwen2( for decoder_layer in base_model.layers: if swiglu: - _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward) + _bind_method_to_module( + decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward + ) if rms_norm: _patch_rms_norm_module(decoder_layer.input_layernorm) _patch_rms_norm_module(decoder_layer.post_attention_layernorm) @@ -495,7 +521,9 @@ def apply_liger_kernel_to_qwen2_vl( _patch_rms_norm_module(base_model.norm) for decoder_layer in base_model.layers: if swiglu: - _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward) + _bind_method_to_module( + decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward + ) if rms_norm: _patch_rms_norm_module(decoder_layer.input_layernorm) _patch_rms_norm_module(decoder_layer.post_attention_layernorm) @@ -557,7 +585,9 @@ def apply_liger_kernel_to_phi3( for decoder_layer in base_model.layers: if swiglu: - _bind_method_to_module(decoder_layer.mlp, "forward", LigerPhi3SwiGLUMLP.forward) + _bind_method_to_module( + decoder_layer.mlp, "forward", LigerPhi3SwiGLUMLP.forward + ) if rms_norm: _patch_rms_norm_module(decoder_layer.input_layernorm) _patch_rms_norm_module(decoder_layer.post_attention_layernorm) diff --git a/test/convergence/test_mini_models.py b/test/convergence/test_mini_models.py index 4aaa0cd6d..f32431f59 100644 --- a/test/convergence/test_mini_models.py +++ b/test/convergence/test_mini_models.py @@ -324,7 +324,7 @@ def run_mini_model( # Make sure any patches have been reverted before tests MINI_MODEL_SETUPS[model_name].liger_kernel_patch_revert_func() - + if with_liger is True: kwargs = { "rope": True, @@ -510,14 +510,25 @@ def test_mini_model( ) actual_output_pre = run_mini_model( - model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr, with_liger=True, post_init_patching=False, + model_name=model_name, + num_steps=num_steps, + dtype=dtype, + lr=lr, + with_liger=True, + post_init_patching=False, ) actual_output_post = run_mini_model( - model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr, with_liger=True, post_init_patching=True, + model_name=model_name, + num_steps=num_steps, + dtype=dtype, + lr=lr, + with_liger=True, + post_init_patching=True, ) - ### Pre-init patching + # Pre-init patching + # Compare the loss of every step assert_verbose_allclose( torch.tensor([expected_output["loss"]]), @@ -538,13 +549,14 @@ def test_mini_model( # Iterate over the model's parameters and compare them for expected_param, actual_param in zip( expected_output["model"].named_parameters(), - actual_output_post["model"].named_parameters(), + actual_output_pre["model"].named_parameters(), ): assert_verbose_allclose( expected_param[1], actual_param[1], atol=param_atol, rtol=param_rtol ) - ### Post-init patching + # Post-init patching + # Compare the loss of every step assert_verbose_allclose( torch.tensor([expected_output["loss"]]), @@ -570,4 +582,3 @@ def test_mini_model( assert_verbose_allclose( expected_param[1], actual_param[1], atol=param_atol, 
rtol=param_rtol ) - From 89667620fe7b8731bae5a54790d03de0341d8756 Mon Sep 17 00:00:00 2001 From: Steven Shimizu Date: Sat, 28 Sep 2024 22:39:05 +0000 Subject: [PATCH 08/11] Added comment --- src/liger_kernel/transformers/monkey_patch.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/liger_kernel/transformers/monkey_patch.py b/src/liger_kernel/transformers/monkey_patch.py index 408eafa47..8b8907b52 100644 --- a/src/liger_kernel/transformers/monkey_patch.py +++ b/src/liger_kernel/transformers/monkey_patch.py @@ -26,6 +26,7 @@ def _bind_method_to_module(module, method_name: str, new_method: Callable): + # Binds a new method to a module instance so that self is passed as the first argument module.__dict__[method_name] = new_method.__get__(module, module.__class__) From 2b6317eb72a11d9de50df3518e1ae11fa0882e06 Mon Sep 17 00:00:00 2001 From: Steven Shimizu Date: Mon, 30 Sep 2024 19:01:47 +0000 Subject: [PATCH 09/11] Fixed unit tests --- test/transformers/test_monkey_patch.py | 120 ++++++++++++------------- 1 file changed, 60 insertions(+), 60 deletions(-) diff --git a/test/transformers/test_monkey_patch.py b/test/transformers/test_monkey_patch.py index 88710cbd6..54455ae5a 100644 --- a/test/transformers/test_monkey_patch.py +++ b/test/transformers/test_monkey_patch.py @@ -212,21 +212,21 @@ def test_apply_liger_kernel_to_instance_for_llama(): dummy_model_instance = AutoModelForCausalLM.from_config(config) # Check that model instance variables are not yet patched with Liger modules - assert not dummy_model_instance.model.norm.forward == LigerRMSNorm.forward + assert inspect.getsource(dummy_model_instance.model.norm.forward) != inspect.getsource(LigerRMSNorm.forward) for layer in dummy_model_instance.model.layers: - assert not layer.mlp.forward == LigerSwiGLUMLP.forward - assert not layer.input_layernorm.forward == LigerRMSNorm.forward - assert not layer.post_attention_layernorm.forward == LigerRMSNorm.forward + assert inspect.getsource(layer.mlp.forward) != inspect.getsource(LigerSwiGLUMLP.forward) + assert inspect.getsource(layer.input_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.post_attention_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) # Test applying kernels to the model instance _apply_liger_kernel_to_instance(model=dummy_model_instance) # Check that the model's instance variables were correctly patched with Liger modules - assert dummy_model_instance.model.norm.forward == LigerRMSNorm.forward + assert inspect.getsource(dummy_model_instance.model.norm.forward) == inspect.getsource(LigerRMSNorm.forward) for layer in dummy_model_instance.model.layers: - assert layer.mlp.forward == LigerSwiGLUMLP.forward - assert layer.input_layernorm.forward == LigerRMSNorm.forward - assert layer.post_attention_layernorm.forward == LigerRMSNorm.forward + assert inspect.getsource(layer.mlp.forward) == inspect.getsource(LigerSwiGLUMLP.forward) + assert inspect.getsource(layer.input_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.post_attention_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) def test_apply_liger_kernel_to_instance_for_mistral(): @@ -245,21 +245,21 @@ def test_apply_liger_kernel_to_instance_for_mistral(): dummy_model_instance = AutoModelForCausalLM.from_config(config) # Check that model instance variables are not yet patched with Liger modules - assert not dummy_model_instance.model.norm.forward == LigerRMSNorm.forward + assert 
inspect.getsource(dummy_model_instance.model.norm.forward) != inspect.getsource(LigerRMSNorm.forward) for layer in dummy_model_instance.model.layers: - assert not layer.mlp.forward == LigerSwiGLUMLP.forward - assert not layer.input_layernorm.forward == LigerRMSNorm.forward - assert not layer.post_attention_layernorm.forward == LigerRMSNorm.forward + assert inspect.getsource(layer.mlp.forward) != inspect.getsource(LigerSwiGLUMLP.forward) + assert inspect.getsource(layer.input_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.post_attention_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) # Test applying kernels to the model instance _apply_liger_kernel_to_instance(model=dummy_model_instance) # Check that the model's instance variables were correctly patched with Liger modules - assert dummy_model_instance.model.norm.forward == LigerRMSNorm.forward + assert inspect.getsource(dummy_model_instance.model.norm.forward) == inspect.getsource(LigerRMSNorm.forward) for layer in dummy_model_instance.model.layers: - assert layer.mlp.forward == LigerSwiGLUMLP.forward - assert layer.input_layernorm.forward == LigerRMSNorm.forward - assert layer.post_attention_layernorm.forward == LigerRMSNorm.forward + assert inspect.getsource(layer.mlp.forward) == inspect.getsource(LigerSwiGLUMLP.forward) + assert inspect.getsource(layer.input_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.post_attention_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) def test_apply_liger_kernel_to_instance_for_mixtral(): @@ -280,23 +280,23 @@ def test_apply_liger_kernel_to_instance_for_mixtral(): dummy_model_instance = AutoModelForCausalLM.from_config(config) # Check that model instance variables are not yet patched with Liger modules - assert not dummy_model_instance.model.norm.forward == LigerRMSNorm.forward + assert inspect.getsource(dummy_model_instance.model.norm.forward) != inspect.getsource(LigerRMSNorm.forward) for layer in dummy_model_instance.model.layers: for expert in layer.block_sparse_moe.experts: - assert not expert.forward == LigerBlockSparseTop2MLP.forward - assert not layer.input_layernorm.forward == LigerRMSNorm.forward - assert not layer.post_attention_layernorm.forward == LigerRMSNorm.forward + assert inspect.getsource(expert.forward) != inspect.getsource(LigerBlockSparseTop2MLP.forward) + assert inspect.getsource(layer.input_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.post_attention_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) # Test applying kernels to the model instance _apply_liger_kernel_to_instance(model=dummy_model_instance) # Check that the model's instance variables were correctly patched with Liger modules - assert dummy_model_instance.model.norm.forward == LigerRMSNorm.forward + assert inspect.getsource(dummy_model_instance.model.norm.forward) == inspect.getsource(LigerRMSNorm.forward) for layer in dummy_model_instance.model.layers: for expert in layer.block_sparse_moe.experts: - assert expert.forward == LigerBlockSparseTop2MLP.forward - assert layer.input_layernorm.forward == LigerRMSNorm.forward - assert layer.post_attention_layernorm.forward == LigerRMSNorm.forward + assert inspect.getsource(expert.forward) == inspect.getsource(LigerBlockSparseTop2MLP.forward) + assert inspect.getsource(layer.input_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) + assert 
inspect.getsource(layer.post_attention_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) def test_apply_liger_kernel_to_instance_for_gemma(): @@ -315,21 +315,21 @@ def test_apply_liger_kernel_to_instance_for_gemma(): dummy_model_instance = AutoModelForCausalLM.from_config(config) # Check that model instance variables are not yet patched with Liger modules - assert not dummy_model_instance.model.norm.forward == LigerRMSNorm.forward + assert inspect.getsource(dummy_model_instance.model.norm.forward) != inspect.getsource(LigerRMSNorm.forward) for layer in dummy_model_instance.model.layers: - assert not layer.mlp.forward == LigerGEGLUMLP.forward - assert not layer.input_layernorm.forward == LigerRMSNorm.forward - assert not layer.post_attention_layernorm.forward == LigerRMSNorm.forward + assert inspect.getsource(layer.mlp.forward) != inspect.getsource(LigerGEGLUMLP.forward) + assert inspect.getsource(layer.input_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.post_attention_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) # Test applying kernels to the model instance _apply_liger_kernel_to_instance(model=dummy_model_instance) # Check that the model's instance variables were correctly patched with Liger modules - assert dummy_model_instance.model.norm.forward == LigerRMSNorm.forward + assert inspect.getsource(dummy_model_instance.model.norm.forward) == inspect.getsource(LigerRMSNorm.forward) for layer in dummy_model_instance.model.layers: - assert layer.mlp.forward == LigerGEGLUMLP.forward - assert layer.input_layernorm.forward == LigerRMSNorm.forward - assert layer.post_attention_layernorm.forward == LigerRMSNorm.forward + assert inspect.getsource(layer.mlp.forward) == inspect.getsource(LigerGEGLUMLP.forward) + assert inspect.getsource(layer.input_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.post_attention_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) def test_apply_liger_kernel_to_instance_for_gemma2(): @@ -348,25 +348,25 @@ def test_apply_liger_kernel_to_instance_for_gemma2(): dummy_model_instance = AutoModelForCausalLM.from_config(config) # Check that model instance variables are not yet patched with Liger modules - assert not dummy_model_instance.model.norm.forward == LigerRMSNorm.forward + assert inspect.getsource(dummy_model_instance.model.norm.forward) != inspect.getsource(LigerRMSNorm.forward) for layer in dummy_model_instance.model.layers: - assert not layer.mlp.forward == LigerGEGLUMLP.forward - assert not layer.input_layernorm.forward == LigerRMSNorm.forward - assert not layer.post_attention_layernorm.forward == LigerRMSNorm.forward - assert not layer.pre_feedforward_layernorm.forward == LigerRMSNorm.forward - assert not layer.post_feedforward_layernorm.forward == LigerRMSNorm.forward + assert inspect.getsource(layer.mlp.forward) != inspect.getsource(LigerGEGLUMLP.forward) + assert inspect.getsource(layer.input_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.post_attention_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.pre_feedforward_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.post_feedforward_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) # Test applying kernels to the model instance _apply_liger_kernel_to_instance(model=dummy_model_instance) # Check that the model's instance variables 
were correctly patched with Liger modules - assert dummy_model_instance.model.norm.forward == LigerRMSNorm.forward + assert inspect.getsource(dummy_model_instance.model.norm.forward) == inspect.getsource(LigerRMSNorm.forward) for layer in dummy_model_instance.model.layers: - assert layer.mlp.forward == LigerGEGLUMLP.forward - assert layer.input_layernorm.forward == LigerRMSNorm.forward - assert layer.post_attention_layernorm.forward == LigerRMSNorm.forward - assert layer.pre_feedforward_layernorm.forward == LigerRMSNorm.forward - assert layer.post_feedforward_layernorm.forward == LigerRMSNorm.forward + assert inspect.getsource(layer.mlp.forward) == inspect.getsource(LigerGEGLUMLP.forward) + assert inspect.getsource(layer.input_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.post_attention_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.pre_feedforward_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.post_feedforward_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) def test_apply_liger_kernel_to_instance_for_qwen2(): @@ -385,21 +385,21 @@ def test_apply_liger_kernel_to_instance_for_qwen2(): dummy_model_instance = AutoModelForCausalLM.from_config(config) # Check that model instance variables are not yet patched with Liger modules - assert not dummy_model_instance.model.norm.forward == LigerRMSNorm.forward + assert inspect.getsource(dummy_model_instance.model.norm.forward) != inspect.getsource(LigerRMSNorm.forward) for layer in dummy_model_instance.model.layers: - assert not layer.mlp.forward == LigerSwiGLUMLP.forward - assert not layer.input_layernorm.forward == LigerRMSNorm.forward - assert not layer.post_attention_layernorm.forward == LigerRMSNorm.forward + assert inspect.getsource(layer.mlp.forward) != inspect.getsource(LigerSwiGLUMLP.forward) + assert inspect.getsource(layer.input_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.post_attention_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) # Test applying kernels to the model instance _apply_liger_kernel_to_instance(model=dummy_model_instance) # Check that the model's instance variables were correctly patched with Liger modules - assert dummy_model_instance.model.norm.forward == LigerRMSNorm.forward + assert inspect.getsource(dummy_model_instance.model.norm.forward) == inspect.getsource(LigerRMSNorm.forward) for layer in dummy_model_instance.model.layers: - assert layer.mlp.forward == LigerSwiGLUMLP.forward - assert layer.input_layernorm.forward == LigerRMSNorm.forward - assert layer.post_attention_layernorm.forward == LigerRMSNorm.forward + assert inspect.getsource(layer.mlp.forward) == inspect.getsource(LigerSwiGLUMLP.forward) + assert inspect.getsource(layer.input_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.post_attention_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) def test_apply_liger_kernel_to_instance_for_phi3(): @@ -418,18 +418,18 @@ def test_apply_liger_kernel_to_instance_for_phi3(): dummy_model_instance = AutoModelForCausalLM.from_config(config) # Check that model instance variables are not yet patched with Liger modules - assert not dummy_model_instance.model.norm.forward == LigerRMSNorm.forward + assert inspect.getsource(dummy_model_instance.model.norm.forward) != inspect.getsource(LigerRMSNorm.forward) for layer in 
dummy_model_instance.model.layers: - assert not layer.mlp.forward == LigerPhi3SwiGLUMLP.forward - assert not layer.input_layernorm.forward == LigerRMSNorm.forward - assert not layer.post_attention_layernorm.forward == LigerRMSNorm.forward + assert inspect.getsource(layer.mlp.forward) != inspect.getsource(LigerPhi3SwiGLUMLP.forward) + assert inspect.getsource(layer.input_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.post_attention_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) # Test applying kernels to the model instance _apply_liger_kernel_to_instance(model=dummy_model_instance) # Check that the model's instance variables were correctly patched with Liger modules - assert dummy_model_instance.model.norm.forward == LigerRMSNorm.forward + assert inspect.getsource(dummy_model_instance.model.norm.forward) == inspect.getsource(LigerRMSNorm.forward) for layer in dummy_model_instance.model.layers: - assert layer.mlp.forward == LigerPhi3SwiGLUMLP.forward - assert layer.input_layernorm.forward == LigerRMSNorm.forward - assert layer.post_attention_layernorm.forward == LigerRMSNorm.forward + assert inspect.getsource(layer.mlp.forward) == inspect.getsource(LigerPhi3SwiGLUMLP.forward) + assert inspect.getsource(layer.input_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.post_attention_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) \ No newline at end of file From 2c29e229b9b77a21560ba8844f1cfe7aa7a50929 Mon Sep 17 00:00:00 2001 From: Steven Shimizu Date: Mon, 30 Sep 2024 19:48:06 +0000 Subject: [PATCH 10/11] Fixed checkstyle --- test/transformers/test_monkey_patch.py | 240 ++++++++++++++++++------- 1 file changed, 180 insertions(+), 60 deletions(-) diff --git a/test/transformers/test_monkey_patch.py b/test/transformers/test_monkey_patch.py index 54455ae5a..0b5928619 100644 --- a/test/transformers/test_monkey_patch.py +++ b/test/transformers/test_monkey_patch.py @@ -212,21 +212,37 @@ def test_apply_liger_kernel_to_instance_for_llama(): dummy_model_instance = AutoModelForCausalLM.from_config(config) # Check that model instance variables are not yet patched with Liger modules - assert inspect.getsource(dummy_model_instance.model.norm.forward) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource( + dummy_model_instance.model.norm.forward + ) != inspect.getsource(LigerRMSNorm.forward) for layer in dummy_model_instance.model.layers: - assert inspect.getsource(layer.mlp.forward) != inspect.getsource(LigerSwiGLUMLP.forward) - assert inspect.getsource(layer.input_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) - assert inspect.getsource(layer.post_attention_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.mlp.forward) != inspect.getsource( + LigerSwiGLUMLP.forward + ) + assert inspect.getsource( + layer.input_layernorm.forward + ) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource( + layer.post_attention_layernorm.forward + ) != inspect.getsource(LigerRMSNorm.forward) # Test applying kernels to the model instance _apply_liger_kernel_to_instance(model=dummy_model_instance) # Check that the model's instance variables were correctly patched with Liger modules - assert inspect.getsource(dummy_model_instance.model.norm.forward) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource( + dummy_model_instance.model.norm.forward + ) == inspect.getsource(LigerRMSNorm.forward) for layer in 
dummy_model_instance.model.layers: - assert inspect.getsource(layer.mlp.forward) == inspect.getsource(LigerSwiGLUMLP.forward) - assert inspect.getsource(layer.input_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) - assert inspect.getsource(layer.post_attention_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.mlp.forward) == inspect.getsource( + LigerSwiGLUMLP.forward + ) + assert inspect.getsource( + layer.input_layernorm.forward + ) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource( + layer.post_attention_layernorm.forward + ) == inspect.getsource(LigerRMSNorm.forward) def test_apply_liger_kernel_to_instance_for_mistral(): @@ -245,21 +261,37 @@ def test_apply_liger_kernel_to_instance_for_mistral(): dummy_model_instance = AutoModelForCausalLM.from_config(config) # Check that model instance variables are not yet patched with Liger modules - assert inspect.getsource(dummy_model_instance.model.norm.forward) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource( + dummy_model_instance.model.norm.forward + ) != inspect.getsource(LigerRMSNorm.forward) for layer in dummy_model_instance.model.layers: - assert inspect.getsource(layer.mlp.forward) != inspect.getsource(LigerSwiGLUMLP.forward) - assert inspect.getsource(layer.input_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) - assert inspect.getsource(layer.post_attention_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.mlp.forward) != inspect.getsource( + LigerSwiGLUMLP.forward + ) + assert inspect.getsource( + layer.input_layernorm.forward + ) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource( + layer.post_attention_layernorm.forward + ) != inspect.getsource(LigerRMSNorm.forward) # Test applying kernels to the model instance _apply_liger_kernel_to_instance(model=dummy_model_instance) # Check that the model's instance variables were correctly patched with Liger modules - assert inspect.getsource(dummy_model_instance.model.norm.forward) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource( + dummy_model_instance.model.norm.forward + ) == inspect.getsource(LigerRMSNorm.forward) for layer in dummy_model_instance.model.layers: - assert inspect.getsource(layer.mlp.forward) == inspect.getsource(LigerSwiGLUMLP.forward) - assert inspect.getsource(layer.input_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) - assert inspect.getsource(layer.post_attention_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.mlp.forward) == inspect.getsource( + LigerSwiGLUMLP.forward + ) + assert inspect.getsource( + layer.input_layernorm.forward + ) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource( + layer.post_attention_layernorm.forward + ) == inspect.getsource(LigerRMSNorm.forward) def test_apply_liger_kernel_to_instance_for_mixtral(): @@ -280,23 +312,39 @@ def test_apply_liger_kernel_to_instance_for_mixtral(): dummy_model_instance = AutoModelForCausalLM.from_config(config) # Check that model instance variables are not yet patched with Liger modules - assert inspect.getsource(dummy_model_instance.model.norm.forward) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource( + dummy_model_instance.model.norm.forward + ) != inspect.getsource(LigerRMSNorm.forward) for layer in dummy_model_instance.model.layers: for expert in layer.block_sparse_moe.experts: - assert 
inspect.getsource(expert.forward) != inspect.getsource(LigerBlockSparseTop2MLP.forward) - assert inspect.getsource(layer.input_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) - assert inspect.getsource(layer.post_attention_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(expert.forward) != inspect.getsource( + LigerBlockSparseTop2MLP.forward + ) + assert inspect.getsource( + layer.input_layernorm.forward + ) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource( + layer.post_attention_layernorm.forward + ) != inspect.getsource(LigerRMSNorm.forward) # Test applying kernels to the model instance _apply_liger_kernel_to_instance(model=dummy_model_instance) # Check that the model's instance variables were correctly patched with Liger modules - assert inspect.getsource(dummy_model_instance.model.norm.forward) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource( + dummy_model_instance.model.norm.forward + ) == inspect.getsource(LigerRMSNorm.forward) for layer in dummy_model_instance.model.layers: for expert in layer.block_sparse_moe.experts: - assert inspect.getsource(expert.forward) == inspect.getsource(LigerBlockSparseTop2MLP.forward) - assert inspect.getsource(layer.input_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) - assert inspect.getsource(layer.post_attention_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(expert.forward) == inspect.getsource( + LigerBlockSparseTop2MLP.forward + ) + assert inspect.getsource( + layer.input_layernorm.forward + ) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource( + layer.post_attention_layernorm.forward + ) == inspect.getsource(LigerRMSNorm.forward) def test_apply_liger_kernel_to_instance_for_gemma(): @@ -315,21 +363,37 @@ def test_apply_liger_kernel_to_instance_for_gemma(): dummy_model_instance = AutoModelForCausalLM.from_config(config) # Check that model instance variables are not yet patched with Liger modules - assert inspect.getsource(dummy_model_instance.model.norm.forward) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource( + dummy_model_instance.model.norm.forward + ) != inspect.getsource(LigerRMSNorm.forward) for layer in dummy_model_instance.model.layers: - assert inspect.getsource(layer.mlp.forward) != inspect.getsource(LigerGEGLUMLP.forward) - assert inspect.getsource(layer.input_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) - assert inspect.getsource(layer.post_attention_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.mlp.forward) != inspect.getsource( + LigerGEGLUMLP.forward + ) + assert inspect.getsource( + layer.input_layernorm.forward + ) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource( + layer.post_attention_layernorm.forward + ) != inspect.getsource(LigerRMSNorm.forward) # Test applying kernels to the model instance _apply_liger_kernel_to_instance(model=dummy_model_instance) # Check that the model's instance variables were correctly patched with Liger modules - assert inspect.getsource(dummy_model_instance.model.norm.forward) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource( + dummy_model_instance.model.norm.forward + ) == inspect.getsource(LigerRMSNorm.forward) for layer in dummy_model_instance.model.layers: - assert inspect.getsource(layer.mlp.forward) == inspect.getsource(LigerGEGLUMLP.forward) - assert inspect.getsource(layer.input_layernorm.forward) 
== inspect.getsource(LigerRMSNorm.forward) - assert inspect.getsource(layer.post_attention_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.mlp.forward) == inspect.getsource( + LigerGEGLUMLP.forward + ) + assert inspect.getsource( + layer.input_layernorm.forward + ) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource( + layer.post_attention_layernorm.forward + ) == inspect.getsource(LigerRMSNorm.forward) def test_apply_liger_kernel_to_instance_for_gemma2(): @@ -348,25 +412,49 @@ def test_apply_liger_kernel_to_instance_for_gemma2(): dummy_model_instance = AutoModelForCausalLM.from_config(config) # Check that model instance variables are not yet patched with Liger modules - assert inspect.getsource(dummy_model_instance.model.norm.forward) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource( + dummy_model_instance.model.norm.forward + ) != inspect.getsource(LigerRMSNorm.forward) for layer in dummy_model_instance.model.layers: - assert inspect.getsource(layer.mlp.forward) != inspect.getsource(LigerGEGLUMLP.forward) - assert inspect.getsource(layer.input_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) - assert inspect.getsource(layer.post_attention_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) - assert inspect.getsource(layer.pre_feedforward_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) - assert inspect.getsource(layer.post_feedforward_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.mlp.forward) != inspect.getsource( + LigerGEGLUMLP.forward + ) + assert inspect.getsource( + layer.input_layernorm.forward + ) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource( + layer.post_attention_layernorm.forward + ) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource( + layer.pre_feedforward_layernorm.forward + ) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource( + layer.post_feedforward_layernorm.forward + ) != inspect.getsource(LigerRMSNorm.forward) # Test applying kernels to the model instance _apply_liger_kernel_to_instance(model=dummy_model_instance) # Check that the model's instance variables were correctly patched with Liger modules - assert inspect.getsource(dummy_model_instance.model.norm.forward) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource( + dummy_model_instance.model.norm.forward + ) == inspect.getsource(LigerRMSNorm.forward) for layer in dummy_model_instance.model.layers: - assert inspect.getsource(layer.mlp.forward) == inspect.getsource(LigerGEGLUMLP.forward) - assert inspect.getsource(layer.input_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) - assert inspect.getsource(layer.post_attention_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) - assert inspect.getsource(layer.pre_feedforward_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) - assert inspect.getsource(layer.post_feedforward_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.mlp.forward) == inspect.getsource( + LigerGEGLUMLP.forward + ) + assert inspect.getsource( + layer.input_layernorm.forward + ) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource( + layer.post_attention_layernorm.forward + ) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource( + layer.pre_feedforward_layernorm.forward + ) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource( + 
layer.post_feedforward_layernorm.forward + ) == inspect.getsource(LigerRMSNorm.forward) def test_apply_liger_kernel_to_instance_for_qwen2(): @@ -385,21 +473,37 @@ def test_apply_liger_kernel_to_instance_for_qwen2(): dummy_model_instance = AutoModelForCausalLM.from_config(config) # Check that model instance variables are not yet patched with Liger modules - assert inspect.getsource(dummy_model_instance.model.norm.forward) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource( + dummy_model_instance.model.norm.forward + ) != inspect.getsource(LigerRMSNorm.forward) for layer in dummy_model_instance.model.layers: - assert inspect.getsource(layer.mlp.forward) != inspect.getsource(LigerSwiGLUMLP.forward) - assert inspect.getsource(layer.input_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) - assert inspect.getsource(layer.post_attention_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.mlp.forward) != inspect.getsource( + LigerSwiGLUMLP.forward + ) + assert inspect.getsource( + layer.input_layernorm.forward + ) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource( + layer.post_attention_layernorm.forward + ) != inspect.getsource(LigerRMSNorm.forward) # Test applying kernels to the model instance _apply_liger_kernel_to_instance(model=dummy_model_instance) # Check that the model's instance variables were correctly patched with Liger modules - assert inspect.getsource(dummy_model_instance.model.norm.forward) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource( + dummy_model_instance.model.norm.forward + ) == inspect.getsource(LigerRMSNorm.forward) for layer in dummy_model_instance.model.layers: - assert inspect.getsource(layer.mlp.forward) == inspect.getsource(LigerSwiGLUMLP.forward) - assert inspect.getsource(layer.input_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) - assert inspect.getsource(layer.post_attention_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.mlp.forward) == inspect.getsource( + LigerSwiGLUMLP.forward + ) + assert inspect.getsource( + layer.input_layernorm.forward + ) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource( + layer.post_attention_layernorm.forward + ) == inspect.getsource(LigerRMSNorm.forward) def test_apply_liger_kernel_to_instance_for_phi3(): @@ -418,18 +522,34 @@ def test_apply_liger_kernel_to_instance_for_phi3(): dummy_model_instance = AutoModelForCausalLM.from_config(config) # Check that model instance variables are not yet patched with Liger modules - assert inspect.getsource(dummy_model_instance.model.norm.forward) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource( + dummy_model_instance.model.norm.forward + ) != inspect.getsource(LigerRMSNorm.forward) for layer in dummy_model_instance.model.layers: - assert inspect.getsource(layer.mlp.forward) != inspect.getsource(LigerPhi3SwiGLUMLP.forward) - assert inspect.getsource(layer.input_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) - assert inspect.getsource(layer.post_attention_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.mlp.forward) != inspect.getsource( + LigerPhi3SwiGLUMLP.forward + ) + assert inspect.getsource( + layer.input_layernorm.forward + ) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource( + layer.post_attention_layernorm.forward + ) != inspect.getsource(LigerRMSNorm.forward) # Test applying kernels to the model 
instance _apply_liger_kernel_to_instance(model=dummy_model_instance) # Check that the model's instance variables were correctly patched with Liger modules - assert inspect.getsource(dummy_model_instance.model.norm.forward) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource( + dummy_model_instance.model.norm.forward + ) == inspect.getsource(LigerRMSNorm.forward) for layer in dummy_model_instance.model.layers: - assert inspect.getsource(layer.mlp.forward) == inspect.getsource(LigerPhi3SwiGLUMLP.forward) - assert inspect.getsource(layer.input_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) - assert inspect.getsource(layer.post_attention_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) \ No newline at end of file + assert inspect.getsource(layer.mlp.forward) == inspect.getsource( + LigerPhi3SwiGLUMLP.forward + ) + assert inspect.getsource( + layer.input_layernorm.forward + ) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource( + layer.post_attention_layernorm.forward + ) == inspect.getsource(LigerRMSNorm.forward) From 34967f2e04497e25bcf6f59056392423a16e1359 Mon Sep 17 00:00:00 2001 From: Steven Shimizu Date: Mon, 30 Sep 2024 20:54:51 +0000 Subject: [PATCH 11/11] Fixed tests --- src/liger_kernel/transformers/monkey_patch.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/liger_kernel/transformers/monkey_patch.py b/src/liger_kernel/transformers/monkey_patch.py index 8b8907b52..fa2090ced 100644 --- a/src/liger_kernel/transformers/monkey_patch.py +++ b/src/liger_kernel/transformers/monkey_patch.py @@ -33,14 +33,17 @@ def _bind_method_to_module(module, method_name: str, new_method: Callable): def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama"): module.offset = offset module.casting_mode = casting_mode - module.variance_epsilon = eps + module.variance_epsilon = ( + getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps + ) _bind_method_to_module(module, "forward", LigerRMSNorm.forward) _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr) def _patch_layer_norm_module(module, eps=1e-6): - module.eps = eps - module.variance_epsilon = eps + module.variance_epsilon = ( + getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps + ) module.hidden_size = module.normalized_shape _bind_method_to_module(module, "forward", LigerLayerNorm.forward) _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
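
For reference, the instance-level patching pattern used throughout this series can be reproduced on a toy module. The sketch below is illustrative only and is not part of any patch above: ToyRMSNorm and patched_forward are hypothetical stand-ins for the Hugging Face layer and the Liger forward that monkey_patch.py binds, and the epsilon-resolution line mirrors the getattr fallback added in PATCH 11.

    # Illustrative sketch only (not part of the patch series). ToyRMSNorm and
    # patched_forward are hypothetical stand-ins for the Hugging Face layer and
    # the Liger forward being bound in monkey_patch.py.
    import torch
    import torch.nn as nn


    def _bind_method_to_module(module, method_name, new_method):
        # __get__ turns the plain function into a method bound to this instance,
        # so `self` is passed automatically; writing into __dict__ bypasses
        # nn.Module.__setattr__ and patches only this one instance.
        module.__dict__[method_name] = new_method.__get__(module, module.__class__)


    class ToyRMSNorm(nn.Module):
        def __init__(self, hidden_size, eps=1e-6):
            super().__init__()
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.eps = eps  # some upstream classes name this `variance_epsilon`

        def forward(self, hidden_states):
            return hidden_states  # placeholder implementation to be patched over


    def patched_forward(self, hidden_states):
        # Minimal RMSNorm using the epsilon resolved by the patch helper.
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        return self.weight * hidden_states * torch.rsqrt(variance + self.variance_epsilon)


    norm = ToyRMSNorm(hidden_size=8)
    # Mirror _patch_rms_norm_module: reuse whichever epsilon attribute already
    # exists on the instance, falling back to a default otherwise.
    norm.variance_epsilon = (
        getattr(norm, "variance_epsilon", None) or getattr(norm, "eps", None) or 1e-6
    )
    _bind_method_to_module(norm, "forward", patched_forward)

    output = norm(torch.randn(2, 8))  # nn.Module.__call__ now dispatches to patched_forward
    assert torch.isfinite(output).all()

Because the binding produced by __get__ is a new bound-method object on each instance, a comparison such as layer.input_layernorm.forward == LigerRMSNorm.forward is never True even after patching, which is presumably why PATCH 09 switches the test assertions to comparing inspect.getsource of the two forwards instead of comparing the functions directly.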