From 713fa4108c24c10de11ef3c0471ed3431802e814 Mon Sep 17 00:00:00 2001 From: ByronHsu Date: Tue, 12 Nov 2024 19:34:15 +0000 Subject: [PATCH 1/4] support CE after grad acc fix --- .github/workflows/ci.yml | 25 + Makefile | 5 +- dev/modal/tests_bwd.py | 28 + src/liger_kernel/transformers/functional.py | 33 +- src/liger_kernel/transformers/monkey_patch.py | 54 +- .../test_mini_models_with_logits.py | 705 ++++++++++++++++++ test/transformers/test_cross_entropy.py | 11 +- 7 files changed, 850 insertions(+), 11 deletions(-) create mode 100644 dev/modal/tests_bwd.py create mode 100644 test/convergence/test_mini_models_with_logits.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 15a7db41a..86ec2b581 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -64,3 +64,28 @@ jobs: - name: Run unit tests run: | modal run dev.modal.tests + + tests-bwd: + runs-on: ubuntu-latest + needs: [checkstyle] + env: + MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }} + MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }} + + steps: + - name: Checkout code + uses: actions/checkout@v3 + + - name: Set up Python + uses: actions/setup-python@v3 + with: + python-version: '3.10' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install modal + + - name: Run unit tests + run: | + modal run dev.modal.tests_bwd \ No newline at end of file diff --git a/Makefile b/Makefile index f0120bd21..00b677d3e 100644 --- a/Makefile +++ b/Makefile @@ -20,8 +20,9 @@ checkstyle: # Command to run pytest for convergence tests # We have to explicitly set HF_DATASETS_OFFLINE=1, or dataset will silently try to send metrics and timeout (80s) https://github.com/huggingface/datasets/blob/37a603679f451826cfafd8aae00738b01dcb9d58/src/datasets/load.py#L286 test-convergence: - HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence - + HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/test_mini_models.py + HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/test_mini_models_multimodal.py + HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/test_mini_models_with_logits.py # Command to run all benchmark scripts and update benchmarking data file # By default this doesn't overwrite existing data for the same benchmark experiment diff --git a/dev/modal/tests_bwd.py b/dev/modal/tests_bwd.py new file mode 100644 index 000000000..183d6eccf --- /dev/null +++ b/dev/modal/tests_bwd.py @@ -0,0 +1,28 @@ +from pathlib import Path + +import modal + +ROOT_PATH = Path(__file__).parent.parent.parent + +# tests_bwd is to ensure the backward compatibility of liger with older transformers +image = ( + modal.Image.debian_slim() + .pip_install_from_pyproject( + ROOT_PATH / "pyproject.toml", optional_dependencies=["dev"] + ) + .pip_install("transformers=4.44.2") +) + +app = modal.App("liger_tests", image=image) + +# mount: add local files to the remote container +repo = modal.Mount.from_local_dir(ROOT_PATH, remote_path="/root/liger-kernel") + + +@app.function(gpu="A10G", mounts=[repo], timeout=60 * 10) +def liger_tests(): + import subprocess + + subprocess.run(["pip", "install", "-e", "."], check=True, cwd="/root/liger-kernel") + subprocess.run(["make", "test"], check=True, cwd="/root/liger-kernel") + subprocess.run(["make", "test-convergence"], check=True, cwd="/root/liger-kernel") diff --git a/src/liger_kernel/transformers/functional.py b/src/liger_kernel/transformers/functional.py index 292c0dba7..6a040b51b 100644 --- a/src/liger_kernel/transformers/functional.py +++ b/src/liger_kernel/transformers/functional.py @@ -1,3 +1,5 @@ +from typing import Optional + from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction from liger_kernel.ops.fused_linear_cross_entropy import ( LigerFusedLinearCrossEntropyFunction, @@ -13,7 +15,6 @@ from liger_kernel.ops.swiglu import LigerSiLUMulFunction liger_swiglu = LigerSiLUMulFunction.apply -liger_cross_entropy = LigerCrossEntropyFunction.apply liger_fused_linear_cross_entropy = LigerFusedLinearCrossEntropyFunction.apply liger_geglu = LigerGELUMulFunction.apply liger_rms_norm = LigerRMSNormFunction.apply @@ -23,3 +24,33 @@ liger_jsd = LigerJSDFunction.apply liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply liger_group_norm = LigerGroupNormFunction.apply + + +# conform to the function signature in https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html +# `weight` and `size_average` are placeholders and not implemented yet +def liger_cross_entropy( + input, + target, + weight=None, + size_average=None, + ignore_index: int = -100, + reduce=None, + reduction: str = "mean", + label_smoothing: float = 0.0, + lse_square_scale: float = 0.0, + softcap: Optional[float] = None, + return_z_loss: bool = False, +): + loss, z_loss = LigerCrossEntropyFunction.apply( + input, + target, + ignore_index, + lse_square_scale, + label_smoothing, + reduction, + softcap, + return_z_loss, + ) + if not return_z_loss: + return loss + return loss, z_loss diff --git a/src/liger_kernel/transformers/monkey_patch.py b/src/liger_kernel/transformers/monkey_patch.py index fb1a8db91..8bfeb524e 100644 --- a/src/liger_kernel/transformers/monkey_patch.py +++ b/src/liger_kernel/transformers/monkey_patch.py @@ -8,6 +8,7 @@ from transformers import PreTrainedModel from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss +from liger_kernel.transformers.functional import liger_cross_entropy from liger_kernel.transformers.geglu import LigerGEGLUMLP from liger_kernel.transformers.layer_norm import LigerLayerNorm from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward @@ -102,6 +103,7 @@ def apply_liger_kernel_to_llama( cross_entropy and fused_linear_cross_entropy ), "cross_entropy and fused_linear_cross_entropy cannot both be True." + from transformers.loss.loss_utils import nn from transformers.models.llama import modeling_llama from transformers.models.llama.modeling_llama import LlamaModel @@ -111,8 +113,14 @@ def apply_liger_kernel_to_llama( modeling_llama.LlamaRMSNorm = LigerRMSNorm if swiglu: modeling_llama.LlamaMLP = LigerSwiGLUMLP + if cross_entropy: - modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss + if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION): + nn.functional.cross_entropy = liger_cross_entropy + else: + logger.warning(TRANSFORMER_DEPRECATION_WARNING) + modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss + if fused_linear_cross_entropy: if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION): modeling_llama.LlamaForCausalLM.forward = llama_lce_forward @@ -170,6 +178,7 @@ def apply_liger_kernel_to_mllama( cross_entropy and fused_linear_cross_entropy ), "cross_entropy and fused_linear_cross_entropy cannot both be True." + from transformers.loss.loss_utils import nn from transformers.models.mllama import modeling_mllama from transformers.models.mllama.modeling_mllama import ( MllamaForCausalLM, @@ -192,7 +201,11 @@ def apply_liger_kernel_to_mllama( if swiglu: modeling_mllama.MllamaTextMLP = LigerSwiGLUMLP if cross_entropy: - modeling_mllama.CrossEntropyLoss = LigerCrossEntropyLoss + if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION): + nn.functional.cross_entropy = liger_cross_entropy + else: + logger.warning(TRANSFORMER_DEPRECATION_WARNING) + modeling_mllama.CrossEntropyLoss = LigerCrossEntropyLoss if fused_linear_cross_entropy: if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION): modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward @@ -334,6 +347,7 @@ def apply_liger_kernel_to_mixtral( cross_entropy and fused_linear_cross_entropy ), "cross_entropy and fused_linear_cross_entropy cannot both be True." + from transformers.loss.loss_utils import nn from transformers.models.mixtral import modeling_mixtral from transformers.models.mixtral.modeling_mixtral import MixtralModel @@ -342,7 +356,12 @@ def apply_liger_kernel_to_mixtral( if rms_norm: modeling_mixtral.MixtralRMSNorm = LigerRMSNorm if cross_entropy: - modeling_mixtral.CrossEntropyLoss = LigerCrossEntropyLoss + if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION): + nn.functional.cross_entropy = liger_cross_entropy + else: + logger.warning(TRANSFORMER_DEPRECATION_WARNING) + modeling_mixtral.CrossEntropyLoss = LigerCrossEntropyLoss + if fused_linear_cross_entropy: if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION): modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward @@ -401,6 +420,7 @@ def apply_liger_kernel_to_gemma( cross_entropy and fused_linear_cross_entropy ), "cross_entropy and fused_linear_cross_entropy cannot both be True." + from transformers.loss.loss_utils import nn from transformers.models.gemma import modeling_gemma from transformers.models.gemma.modeling_gemma import GemmaModel @@ -417,7 +437,11 @@ def apply_liger_kernel_to_gemma( if rms_norm: modeling_gemma.GemmaRMSNorm = LigerRMSNormForGemma if cross_entropy: - modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss + if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION): + nn.functional.cross_entropy = liger_cross_entropy + else: + logger.warning(TRANSFORMER_DEPRECATION_WARNING) + modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss if geglu: modeling_gemma.GemmaMLP = LigerGEGLUMLP if fused_linear_cross_entropy: @@ -474,6 +498,7 @@ def apply_liger_kernel_to_gemma2( assert not ( cross_entropy and fused_linear_cross_entropy ), "cross_entropy and fused_linear_cross_entropy cannot both be True." + from transformers.loss.loss_utils import nn from transformers.models.gemma2 import modeling_gemma2 from transformers.models.gemma2.modeling_gemma2 import Gemma2Model @@ -490,7 +515,11 @@ def apply_liger_kernel_to_gemma2( # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109 modeling_gemma2.Gemma2RMSNorm = LigerRMSNormForGemma2 if cross_entropy: - modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss + if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION): + nn.functional.cross_entropy = liger_cross_entropy + else: + logger.warning(TRANSFORMER_DEPRECATION_WARNING) + modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss if fused_linear_cross_entropy: if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION): modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward @@ -555,6 +584,7 @@ def apply_liger_kernel_to_qwen2( cross_entropy and fused_linear_cross_entropy ), "cross_entropy and fused_linear_cross_entropy cannot both be True." + from transformers.loss.loss_utils import nn from transformers.models.qwen2 import modeling_qwen2 from transformers.models.qwen2.modeling_qwen2 import Qwen2Model @@ -562,8 +592,13 @@ def apply_liger_kernel_to_qwen2( modeling_qwen2.apply_rotary_pos_emb = liger_rotary_pos_emb if rms_norm: modeling_qwen2.Qwen2RMSNorm = LigerRMSNorm + if cross_entropy: - modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss + if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION): + nn.functional.cross_entropy = liger_cross_entropy + else: + logger.warning(TRANSFORMER_DEPRECATION_WARNING) + modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss # import pdb; pdb.set_trace() if fused_linear_cross_entropy: @@ -700,6 +735,7 @@ def apply_liger_kernel_to_phi3( cross_entropy and fused_linear_cross_entropy ), "cross_entropy and fused_linear_cross_entropy cannot both be True." + from transformers.loss.loss_utils import nn from transformers.models.phi3 import modeling_phi3 from transformers.models.phi3.modeling_phi3 import Phi3Model @@ -710,7 +746,11 @@ def apply_liger_kernel_to_phi3( if swiglu: modeling_phi3.Phi3MLP = LigerPhi3SwiGLUMLP if cross_entropy: - modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss + if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION): + nn.functional.cross_entropy = liger_cross_entropy + else: + logger.warning(TRANSFORMER_DEPRECATION_WARNING) + modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss if fused_linear_cross_entropy: if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION): modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward diff --git a/test/convergence/test_mini_models_with_logits.py b/test/convergence/test_mini_models_with_logits.py new file mode 100644 index 000000000..d142e4903 --- /dev/null +++ b/test/convergence/test_mini_models_with_logits.py @@ -0,0 +1,705 @@ +from test.utils import ( + DEFAULT_DATASET_PATH, + MiniModelConfig, + assert_verbose_allclose, + revert_liger_kernel_to_gemma, + revert_liger_kernel_to_gemma2, + revert_liger_kernel_to_llama, + revert_liger_kernel_to_mistral, + revert_liger_kernel_to_mixtral, + revert_liger_kernel_to_mllama, + revert_liger_kernel_to_phi3, + revert_liger_kernel_to_qwen2, + revert_liger_kernel_to_qwen2_vl, + set_seed, + simple_collate_fn, +) + +import pytest +import torch +from datasets import load_from_disk +from torch.utils.data import DataLoader +from transformers.models.gemma import GemmaConfig, GemmaForCausalLM +from transformers.models.gemma2 import Gemma2Config, Gemma2ForCausalLM +from transformers.models.llama import LlamaConfig, LlamaForCausalLM +from transformers.models.mistral import MistralConfig, MistralForCausalLM +from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM +from transformers.models.phi3 import Phi3Config, Phi3ForCausalLM +from transformers.models.qwen2 import Qwen2Config, Qwen2ForCausalLM + +from liger_kernel.transformers import ( + apply_liger_kernel_to_gemma, + apply_liger_kernel_to_gemma2, + apply_liger_kernel_to_llama, + apply_liger_kernel_to_mistral, + apply_liger_kernel_to_mixtral, + apply_liger_kernel_to_mllama, + apply_liger_kernel_to_phi3, + apply_liger_kernel_to_qwen2, + apply_liger_kernel_to_qwen2_vl, +) + +try: + # Mllama is only available in transformers>=4.45.0 + from transformers.models.mllama.configuration_mllama import MllamaTextConfig + from transformers.models.mllama.modeling_mllama import MllamaForCausalLM + + MLLAMA_AVAILABLE = True +except ImportError: + MLLAMA_AVAILABLE = False + +try: + # Qwen2-VL is only available in transformers>4.44.2 + from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig + from transformers.models.qwen2_vl.modeling_qwen2_vl import ( + Qwen2VLForConditionalGeneration, + ) + + QWEN2_VL_AVAILABLE = True +except ImportError: + QWEN2_VL_AVAILABLE = False + +MINI_MODEL_SETUPS = { + "mini_llama3": MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_llama, + liger_kernel_patch_revert_func=revert_liger_kernel_to_llama, + model_class=LlamaForCausalLM, + mini_model_config=LlamaConfig( + attention_bias=False, + attention_dropout=0.0, + # Special token ids/vocab size to match Mistral-7B tokenizer used to create the tokenized dataset + # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json + bos_token_id=1, # 128000 + eos_token_id=2, # 128001 + hidden_act="silu", + hidden_size=1024, # 4096 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=8192, + num_attention_heads=8, # 32 + num_hidden_layers=4, # 32 + num_key_value_heads=2, # 8 + pretraining_tp=1, + rms_norm_eps=1e-5, + rope_scaling=None, + rope_theta=500000.0, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, # 128256, + # At rope backward + # Eager produces incontiguous dq and dk + # SDPA produces contiguous dq and incontiguous dk + # Flash_attn produces contiguous dq and dk + attn_implementation="sdpa", # default value, pytorch native attention + ), + ), + "mini_qwen2": MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_qwen2, + liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen2, + model_class=Qwen2ForCausalLM, + mini_model_config=Qwen2Config( + attention_dropout=0.0, + bos_token_id=1, # 151643 + eos_token_id=2, # 151643 + hidden_act="silu", + hidden_size=896, + initializer_range=0.02, + intermediate_size=4864, + max_position_embeddings=32768, # 131072 + num_attention_heads=8, + num_hidden_layers=4, + num_key_value_heads=2, + rms_norm_eps=1e-6, + rope_theta=1000000.0, + sliding_window=131072, + tie_word_embeddings=True, + use_cache=True, + vocab_size=32000, # 151936 + # At rope backward + # Eager produces incontiguous dq and dk + # SDPA produces contiguous dq and incontiguous dk + # Flash_attn produces contiguous dq and dk + attn_implementation="sdpa", # default value, pytorch native attention + ), + ), + "mini_phi3": MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_phi3, + liger_kernel_patch_revert_func=revert_liger_kernel_to_phi3, + model_class=Phi3ForCausalLM, + mini_model_config=Phi3Config( + attention_dropout=0.0, + bos_token_id=1, + eos_token_id=2, # 32000 + hidden_act="silu", + hidden_size=896, # 3072 + initializer_range=0.02, + intermediate_size=4864, # 8192 + max_position_embeddings=4096, + num_attention_heads=8, # 32 + num_hidden_layers=4, # 32 + num_key_value_heads=None, # defaults to num_attention_heads + rms_norm_eps=1e-5, + rope_theta=10000.0, + sliding_window=None, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32064, + attn_implementation="eager", + ), + ), + "mini_mistral": MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_mistral, + liger_kernel_patch_revert_func=revert_liger_kernel_to_mistral, + model_class=MistralForCausalLM, + mini_model_config=MistralConfig( + attention_dropout=0.0, + bos_token_id=1, + eos_token_id=2, + hidden_act="silu", + hidden_size=1024, + initializer_range=0.02, + intermediate_size=2048, + max_position_embeddings=32768, + num_attention_heads=8, + num_hidden_layers=4, + num_key_value_heads=2, + rms_norm_eps=1e-5, + rope_theta=10000.0, + sliding_window=4096, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, + attn_implementation="sdpa", + ), + ), + "mini_mixtral": MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_mixtral, + liger_kernel_patch_revert_func=revert_liger_kernel_to_mixtral, + model_class=MixtralForCausalLM, + mini_model_config=MixtralConfig( + attention_dropout=0.0, + bos_token_id=1, + eos_token_id=2, + hidden_act="silu", + hidden_size=512, # 4096 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=32768, # 32768 + num_attention_heads=8, # 32 + num_hidden_layers=4, # 32 + num_key_value_heads=2, # 8 + rms_norm_eps=1e-5, + rope_theta=10000.0, + sliding_window=4096, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, + attn_implementation="sdpa", + ), + ), + "mini_gemma1": MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_gemma, + liger_kernel_patch_revert_func=revert_liger_kernel_to_gemma, + model_class=GemmaForCausalLM, + mini_model_config=GemmaConfig( + vocab_size=32000, # 256000 + hidden_size=1024, # 3072 + intermediate_size=2048, # 24576 + num_hidden_layers=4, # 28 + num_attention_heads=4, # 16 + num_key_value_heads=4, # 16 + head_dim=256, + # gemma1 model config uses `hidden_act` and point it to gelu, + # https://huggingface.co/google/gemma-7b/blob/main/config.json#L10 + # but in reality it's ignored and HuggingFace will use tanh approximation: + # https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/models/gemma/modeling_gemma.py#L175 + hidden_act="gelu", + max_position_embeddings=8192, + initializer_range=0.02, + rms_norm_eps=1e-06, + use_cache=True, + pad_token_id=0, + # Special token ids/vocab size to match Mistral-7B tokenizer used to create the tokenized dataset + # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json + bos_token_id=1, # 128000 + eos_token_id=2, # 128001 + tie_word_embeddings=True, + rope_theta=10000.0, + attention_bias=False, + attention_dropout=0.0, + ), + ), + "mini_gemma1.1": MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_gemma, + liger_kernel_patch_revert_func=revert_liger_kernel_to_gemma, + model_class=GemmaForCausalLM, + mini_model_config=GemmaConfig( + vocab_size=32000, # 256000 + hidden_size=1024, # 3072 + intermediate_size=2048, # 24576 + num_hidden_layers=4, # 28 + num_attention_heads=4, # 16 + num_key_value_heads=4, # 16 + head_dim=256, + hidden_activation="gelu_pytorch_tanh", + max_position_embeddings=8192, + initializer_range=0.02, + rms_norm_eps=1e-06, + use_cache=True, + pad_token_id=0, + # Special token ids/vocab size to match Mistral-7B tokenizer used to create the tokenized dataset + # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json + bos_token_id=1, # 128000 + eos_token_id=2, # 128001 + tie_word_embeddings=True, + rope_theta=10000.0, + attention_bias=False, + attention_dropout=0.0, + ), + ), + "mini_gemma2": MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_gemma2, + liger_kernel_patch_revert_func=revert_liger_kernel_to_gemma2, + model_class=Gemma2ForCausalLM, + mini_model_config=Gemma2Config( + vocab_size=32000, # 256000 + hidden_size=1024, # 3072 + intermediate_size=2048, # 24576 + num_hidden_layers=4, # 28 + num_attention_heads=4, # 16 + num_key_value_heads=4, # 16 + head_dim=256, + hidden_activation="gelu_pytorch_tanh", + max_position_embeddings=8192, + initializer_range=0.02, + rms_norm_eps=1e-06, + use_cache=True, + pad_token_id=0, + # Special token ids/vocab size to match Mistral-7B tokenizer used to create the tokenized dataset + # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json + bos_token_id=1, # 128000 + eos_token_id=2, # 128001 + tie_word_embeddings=True, + rope_theta=10000.0, + attention_bias=False, + attention_dropout=0.0, + attn_implementation="eager", + ), + ), +} + +if MLLAMA_AVAILABLE: + MINI_MODEL_SETUPS["mini_mllama"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_mllama, + liger_kernel_patch_revert_func=revert_liger_kernel_to_mllama, + model_class=MllamaForCausalLM, + mini_model_config=MllamaTextConfig( + bos_token_id=1, # 128000 + eos_token_id=2, # 128001 + pad_token_id=2, + cross_attention_layers=None, + dropout=0, + hidden_act="silu", + hidden_size=1024, # 4096 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=131_072, + num_attention_heads=8, # 32 + num_hidden_layers=4, # 40 + num_key_value_heads=2, # 8 + rms_norm_eps=1e-5, + rope_scaling=dict( + factor=8.0, + high_freq_factor=4.0, + low_freq_factor=1.0, + original_max_position_embeddings=8192, + rope_type="llama3", + ), + rope_theta=500_000, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, # 128256, + attn_implementation="sdpa", # default value, pytorch native attention + ), + ) + +if QWEN2_VL_AVAILABLE: + MINI_MODEL_SETUPS["mini_qwen2_vl"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_qwen2_vl, + liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen2_vl, + model_class=Qwen2VLForConditionalGeneration, + mini_model_config=Qwen2VLConfig( + attention_dropout=0.0, + bos_token_id=1, # 151643 + eos_token_id=2, # 151645 + hidden_act="silu", + hidden_size=1536, # 8192 + initializer_range=0.02, + intermediate_size=4864, # 29568 + max_position_embeddings=32768, + max_window_layers=4, # 80 + num_attention_heads=12, # 64 + num_hidden_layers=4, # 80 + num_key_value_heads=2, # 8 + rms_norm_eps=1e-6, # 1e-5 + rope_theta=1000000.0, + rope_scaling=dict( + type="mrope", + mrope_section=[16, 24, 24], # (temporal, height, width) + ), + sliding_window=4096, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, # 152064 + use_sliding_window=False, + vision_config={ + "depth": 4, # 32 + "embed_dim": 1280, + "mlp_ratio": 4, + "num_heads": 16, + "in_chans": 3, + "hidden_size": 128, # 1536 + "patch_size": 14, + "spatial_merge_size": 2, + "spatial_patch_size": 14, + "temporal_patch_size": 2, + }, + attn_implementation="sdpa", + ), + ) + + +def create_model(model_name="mini_llama3"): + """ + Create a mini version model + The commented values are the original values + """ + model_config = MINI_MODEL_SETUPS[model_name].mini_model_config + model_class = MINI_MODEL_SETUPS[model_name].model_class + return model_class(model_config) + + +def run_mini_model( + model_name="mini_llama3", + num_steps=100, + dtype=torch.bfloat16, + lr=1e-5, + with_liger=False, +): + # If we move it to the beginning of test_mini_model, the two runs are initialized with different weights. + # This is due to RNG (Random Number Generator). The formula of RNG progression is x_(n+1) = (a * x_n + c) % m + # Everytime RNG is used, like randomly initialzing weight, the RNG progresses to the next state. + # Therefore, we have to reset RNG before we create the model to ensure the weight initialization started from the same RNG state. + + set_seed(42) + + if with_liger is True: + kwargs = { + "rms_norm": True, + } + model_supports_rope = "qwen2_vl" not in model_name + if model_supports_rope: + kwargs["rope"] = True + + model_supports_layer_norm = "qwen2_vl" in model_name + if model_supports_layer_norm: + kwargs["layer_norm"] = True + + if "gemma" in model_name: + kwargs["geglu"] = True + else: + kwargs["swiglu"] = True + + kwargs["fused_linear_cross_entropy"] = False + kwargs["cross_entropy"] = True + + MINI_MODEL_SETUPS[model_name].liger_kernel_patch_func(**kwargs) + else: + ... + # FIXME: disable revert because it will cause flce to not be patched + # MINI_MODEL_SETUPS[model_name].liger_kernel_patch_revert_func() + + model = create_model(model_name).to(dtype).to("cuda") + train_dataset = load_from_disk(DEFAULT_DATASET_PATH) + loader = DataLoader( + train_dataset, batch_size=16, shuffle=False, collate_fn=simple_collate_fn + ) + loader_iter = iter(loader) + optimizer = torch.optim.AdamW(model.parameters(), lr=lr) + + loss_list = [] + + for i in range(num_steps): + batch = next(loader_iter).to(model.device) + optimizer.zero_grad() + output = model(**batch) + output.loss.backward() + optimizer.step() + print(f"Step {i}, Loss: {output.loss.item()}") + loss_list.append(output.loss.item()) + + # MINI_MODEL_SETUPS[model_name].liger_kernel_patch_revert_func() + return {"loss": loss_list, "logits": output.logits, "model": model} + + +@pytest.mark.parametrize( + # FIXME enable bf16 tests after revert is fixed + "model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logits_atol, logits_rtol, param_atol, param_rtol", + [ + ("mini_llama3", 32, 1e-4, torch.float32, 1e-8, 2e-5, 1e-4, 1e-5, 5e-3, 1e-5), + # pytest.param( + # "mini_llama3", + # 32, + # 1e-4, + # torch.bfloat16, + # 1e-3, + # 1e-2, + # 1e-1, + # 1e-2, + # 1e-2, + # 1e-2, + # marks=pytest.mark.skipif( + # not supports_bfloat16(), reason="bfloat16 not supported on this GPU" + # ), + # ), + pytest.param( + "mini_mllama", + 32, + 1e-4, + torch.float32, + 1e-8, + 1e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=pytest.mark.skipif( + not MLLAMA_AVAILABLE, + reason="Mllama not available in this version of transformers", + ), + ), + # pytest.param( + # "mini_mllama", + # 32, + # 1e-4, + # torch.bfloat16, + # 1e-3, + # 1e-2, + # 1e-1, + # 1e-2, + # 1e-2, + # 1e-2, + # marks=[ + # pytest.mark.skipif( + # not supports_bfloat16(), reason="bfloat16 not supported on this GPU" + # ), + # pytest.mark.skipif( + # not MLLAMA_AVAILABLE, + # reason="Mllama not available in this version of transformers", + # ), + # ], + # ), + ("mini_qwen2", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), + # pytest.param( + # "mini_qwen2", + # 32, + # 1e-4, + # torch.bfloat16, + # 1e-3, + # 1e-2, + # 1e-1, + # 1e-2, + # 1e-2, + # 1e-2, + # marks=pytest.mark.skipif( + # not supports_bfloat16(), reason="bfloat16 not supported on this GPU" + # ), + # ), + # FIXME qwen2 is broken and needs fix + # pytest.param( + # "mini_qwen2_vl", + # 32, + # 1e-4, + # torch.float32, + # 1e-8, + # 1e-5, + # 5e-3, + # 1e-5, + # 5e-3, + # 1e-5, + # marks=pytest.mark.skipif( + # not QWEN2_VL_AVAILABLE, + # reason="Qwen2-VL not available in this version of transformers", + # ), + # ), + # pytest.param( + # "mini_qwen2_vl", + # 32, + # 1e-4, + # torch.bfloat16, + # 1e-3, + # 1e-2, + # 1e-1, + # 1e-2, + # 1e-2, + # 1e-2, + # marks=[ + # pytest.mark.skipif( + # not supports_bfloat16(), reason="bfloat16 not supported on this GPU" + # ), + # pytest.mark.skipif( + # not QWEN2_VL_AVAILABLE, + # reason="Qwen2-VL not available in this version of transformers", + # ), + # ], + # ), + ("mini_phi3", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), + # pytest.param( + # "mini_phi3", + # 32, + # 1e-4, + # torch.bfloat16, + # 1e-3, + # 1e-2, + # 1e-1, + # 1e-2, + # 1e-2, + # 1e-2, + # marks=pytest.mark.skipif( + # not supports_bfloat16(), reason="bfloat16 not supported on this GPU" + # ), + # ), + ("mini_mistral", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), + # pytest.param( + # "mini_mistral", + # 32, + # 1e-4, + # torch.bfloat16, + # 1e-3, + # 1e-2, + # 1e-1, + # 1e-2, + # 1e-2, + # 1e-2, + # marks=pytest.mark.skipif( + # not supports_bfloat16(), reason="bfloat16 not supported on this GPU" + # ), + # ), + # TODO: mixtral is flaky so disable the test for now + ("mini_mixtral", 32, 1e-4, torch.float32, 5e-4, 1e-4, 5e-3, 1e-5, 1e-2, 1e-5), + # pytest.param( + # "mini_mixtral", + # 32, + # 1e-4, + # torch.bfloat16, + # 1e-3, + # 1e-2, + # 1e-1, + # 1e-2, + # 1e-1, + # 1e-2, + # marks=pytest.mark.skipif( + # not supports_bfloat16(), reason="bfloat16 not supported on this GPU" + # ), + # ), + # Gemma 1.1 and 2 has more tolerance because currently, the kernel is not a perfect match + ("mini_gemma1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5), + # pytest.param( + # "mini_gemma1", + # 32, + # 1e-4, + # torch.bfloat16, + # 1e-3, + # 1e-2, + # 1e-1, + # 1e-2, + # 1e-2, + # 1e-2, + # marks=pytest.mark.skipif( + # not supports_bfloat16(), reason="bfloat16 not supported on this GPU" + # ), + # ), + ("mini_gemma1.1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5), + # pytest.param( + # "mini_gemma1.1", + # 32, + # 1e-4, + # torch.bfloat16, + # 1e-3, + # 1e-2, + # 1e-1, + # 1e-2, + # 1e-2, + # 1e-2, + # marks=pytest.mark.skipif( + # not supports_bfloat16(), reason="bfloat16 not supported on this GPU" + # ), + # ), + # TODO: Gemma2 tests are not passing within the tolerance range, need to investigate + # ("mini_gemma2", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5), + # pytest.param( + # "mini_gemma2", + # 32, + # 1e-4, + # torch.bfloat16, + # 1e-3, + # 1e-2, + # 1e-1, + # 1e-2, + # 1e-2, + # 1e-2, + # marks=pytest.mark.skipif( + # not supports_bfloat16(), reason="bfloat16 not supported on this GPU" + # ), + # ), + ], +) +def test_mini_model( + model_name, + num_steps, + lr, + dtype, + loss_atol, + loss_rtol, + logits_atol, + logits_rtol, + param_atol, + param_rtol, +): + # Non-liger models should be initialized and tested first to avoid the module being overridden + + expected_output = run_mini_model( + model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr + ) + + actual_output = run_mini_model( + model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr, with_liger=True + ) + + # Compare every step of the loss + assert_verbose_allclose( + torch.tensor([expected_output["loss"]]), + torch.tensor([actual_output["loss"]]), + atol=loss_atol, + rtol=loss_rtol, + ) + + # No logits are materialized + # import pdb; pdb.set_trace() + # Compare the logits from the last step + assert_verbose_allclose( + expected_output["logits"], + actual_output["logits"], + atol=logits_atol, + rtol=logits_rtol, + ) + + # Compare the params from the last step + # Iterate over the model's parameters and compare them + for expected_param, actual_param in zip( + expected_output["model"].named_parameters(), + actual_output["model"].named_parameters(), + ): + assert_verbose_allclose( + expected_param[1], actual_param[1], atol=param_atol, rtol=param_rtol + ) diff --git a/test/transformers/test_cross_entropy.py b/test/transformers/test_cross_entropy.py index a505e6fcd..6ec73a1a3 100644 --- a/test/transformers/test_cross_entropy.py +++ b/test/transformers/test_cross_entropy.py @@ -337,7 +337,16 @@ def _test_correctness_functional( target = torch.randint(0, V, (B * T,), device="cuda", dtype=torch.long) - y1, y1_z = liger_cross_entropy(x1, target, 0, 1e-4, 0.1, "mean", 30.0, True) + y1, y1_z = liger_cross_entropy( + x1, + target, + ignore_index=0, + lse_square_scale=1e-4, + label_smoothing=0.1, + reduction="mean", + softcap=30.0, + return_z_loss=True, + ) y2, y2_z = LigerCrossEntropyFunction.apply( x2, target, 0, 1e-4, 0.1, "mean", 30.0, True ) From a146592a7ffb7f333b98c0db90b90ede813e0990 Mon Sep 17 00:00:00 2001 From: ByronHsu Date: Tue, 12 Nov 2024 19:39:10 +0000 Subject: [PATCH 2/4] fix modal code --- dev/modal/tests_bwd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev/modal/tests_bwd.py b/dev/modal/tests_bwd.py index 183d6eccf..13b7c59ad 100644 --- a/dev/modal/tests_bwd.py +++ b/dev/modal/tests_bwd.py @@ -10,7 +10,7 @@ .pip_install_from_pyproject( ROOT_PATH / "pyproject.toml", optional_dependencies=["dev"] ) - .pip_install("transformers=4.44.2") + .pip_install("transformers==4.44.2") ) app = modal.App("liger_tests", image=image) From b5d3bc3e7cb5a4304bc8d35b52228098bb4cafe9 Mon Sep 17 00:00:00 2001 From: ByronHsu Date: Tue, 12 Nov 2024 20:10:25 +0000 Subject: [PATCH 3/4] fix backward comp --- src/liger_kernel/transformers/monkey_patch.py | 22 +++++++++++++------ .../test_mini_models_with_logits.py | 2 +- test/transformers/test_monkey_patch.py | 22 +++++++++++++++++++ 3 files changed, 38 insertions(+), 8 deletions(-) diff --git a/src/liger_kernel/transformers/monkey_patch.py b/src/liger_kernel/transformers/monkey_patch.py index 8bfeb524e..df622118e 100644 --- a/src/liger_kernel/transformers/monkey_patch.py +++ b/src/liger_kernel/transformers/monkey_patch.py @@ -103,7 +103,6 @@ def apply_liger_kernel_to_llama( cross_entropy and fused_linear_cross_entropy ), "cross_entropy and fused_linear_cross_entropy cannot both be True." - from transformers.loss.loss_utils import nn from transformers.models.llama import modeling_llama from transformers.models.llama.modeling_llama import LlamaModel @@ -116,6 +115,8 @@ def apply_liger_kernel_to_llama( if cross_entropy: if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION): + from transformers.loss.loss_utils import nn + nn.functional.cross_entropy = liger_cross_entropy else: logger.warning(TRANSFORMER_DEPRECATION_WARNING) @@ -178,7 +179,6 @@ def apply_liger_kernel_to_mllama( cross_entropy and fused_linear_cross_entropy ), "cross_entropy and fused_linear_cross_entropy cannot both be True." - from transformers.loss.loss_utils import nn from transformers.models.mllama import modeling_mllama from transformers.models.mllama.modeling_mllama import ( MllamaForCausalLM, @@ -202,6 +202,8 @@ def apply_liger_kernel_to_mllama( modeling_mllama.MllamaTextMLP = LigerSwiGLUMLP if cross_entropy: if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION): + from transformers.loss.loss_utils import nn + nn.functional.cross_entropy = liger_cross_entropy else: logger.warning(TRANSFORMER_DEPRECATION_WARNING) @@ -347,7 +349,6 @@ def apply_liger_kernel_to_mixtral( cross_entropy and fused_linear_cross_entropy ), "cross_entropy and fused_linear_cross_entropy cannot both be True." - from transformers.loss.loss_utils import nn from transformers.models.mixtral import modeling_mixtral from transformers.models.mixtral.modeling_mixtral import MixtralModel @@ -357,6 +358,8 @@ def apply_liger_kernel_to_mixtral( modeling_mixtral.MixtralRMSNorm = LigerRMSNorm if cross_entropy: if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION): + from transformers.loss.loss_utils import nn + nn.functional.cross_entropy = liger_cross_entropy else: logger.warning(TRANSFORMER_DEPRECATION_WARNING) @@ -420,7 +423,6 @@ def apply_liger_kernel_to_gemma( cross_entropy and fused_linear_cross_entropy ), "cross_entropy and fused_linear_cross_entropy cannot both be True." - from transformers.loss.loss_utils import nn from transformers.models.gemma import modeling_gemma from transformers.models.gemma.modeling_gemma import GemmaModel @@ -438,6 +440,8 @@ def apply_liger_kernel_to_gemma( modeling_gemma.GemmaRMSNorm = LigerRMSNormForGemma if cross_entropy: if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION): + from transformers.loss.loss_utils import nn + nn.functional.cross_entropy = liger_cross_entropy else: logger.warning(TRANSFORMER_DEPRECATION_WARNING) @@ -498,7 +502,7 @@ def apply_liger_kernel_to_gemma2( assert not ( cross_entropy and fused_linear_cross_entropy ), "cross_entropy and fused_linear_cross_entropy cannot both be True." - from transformers.loss.loss_utils import nn + from transformers.models.gemma2 import modeling_gemma2 from transformers.models.gemma2.modeling_gemma2 import Gemma2Model @@ -516,6 +520,8 @@ def apply_liger_kernel_to_gemma2( modeling_gemma2.Gemma2RMSNorm = LigerRMSNormForGemma2 if cross_entropy: if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION): + from transformers.loss.loss_utils import nn + nn.functional.cross_entropy = liger_cross_entropy else: logger.warning(TRANSFORMER_DEPRECATION_WARNING) @@ -584,7 +590,6 @@ def apply_liger_kernel_to_qwen2( cross_entropy and fused_linear_cross_entropy ), "cross_entropy and fused_linear_cross_entropy cannot both be True." - from transformers.loss.loss_utils import nn from transformers.models.qwen2 import modeling_qwen2 from transformers.models.qwen2.modeling_qwen2 import Qwen2Model @@ -595,6 +600,8 @@ def apply_liger_kernel_to_qwen2( if cross_entropy: if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION): + from transformers.loss.loss_utils import nn + nn.functional.cross_entropy = liger_cross_entropy else: logger.warning(TRANSFORMER_DEPRECATION_WARNING) @@ -735,7 +742,6 @@ def apply_liger_kernel_to_phi3( cross_entropy and fused_linear_cross_entropy ), "cross_entropy and fused_linear_cross_entropy cannot both be True." - from transformers.loss.loss_utils import nn from transformers.models.phi3 import modeling_phi3 from transformers.models.phi3.modeling_phi3 import Phi3Model @@ -747,6 +753,8 @@ def apply_liger_kernel_to_phi3( modeling_phi3.Phi3MLP = LigerPhi3SwiGLUMLP if cross_entropy: if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION): + from transformers.loss.loss_utils import nn + nn.functional.cross_entropy = liger_cross_entropy else: logger.warning(TRANSFORMER_DEPRECATION_WARNING) diff --git a/test/convergence/test_mini_models_with_logits.py b/test/convergence/test_mini_models_with_logits.py index d142e4903..80eeb5330 100644 --- a/test/convergence/test_mini_models_with_logits.py +++ b/test/convergence/test_mini_models_with_logits.py @@ -586,7 +586,7 @@ def run_mini_model( # ), # ), # TODO: mixtral is flaky so disable the test for now - ("mini_mixtral", 32, 1e-4, torch.float32, 5e-4, 1e-4, 5e-3, 1e-5, 1e-2, 1e-5), + # ("mini_mixtral", 32, 1e-4, torch.float32, 5e-4, 1e-4, 5e-3, 1e-5, 1e-2, 1e-5), # pytest.param( # "mini_mixtral", # 32, diff --git a/test/transformers/test_monkey_patch.py b/test/transformers/test_monkey_patch.py index 7ce1aacb7..4ccd08dae 100644 --- a/test/transformers/test_monkey_patch.py +++ b/test/transformers/test_monkey_patch.py @@ -23,6 +23,25 @@ ) +# Check if optional modules are available +def is_mllama_available(): + try: + import transformers.models.mllama # noqa: F401 + + return True + except ImportError: + return False + + +def is_qwen2_vl_available(): + try: + import transformers.models.qwen2_vl # noqa: F401 + + return True + except ImportError: + return False + + def test_import_from_root(): try: from liger_kernel.transformers import ( # noqa: F401 @@ -250,6 +269,7 @@ def test_apply_liger_kernel_to_instance_for_llama(): ) == inspect.getsource(LigerRMSNorm.forward) +@pytest.mark.skipif(not is_mllama_available(), reason="mllama module not available") def test_apply_liger_kernel_to_instance_for_mllama_for_conditional_generation(): # Ensure any monkey patching is cleaned up for subsequent tests with patch("transformers.models.mllama.modeling_mllama"): @@ -363,6 +383,7 @@ def test_apply_liger_kernel_to_instance_for_mllama_for_conditional_generation(): ) == inspect.getsource(LigerLayerNorm.forward) +@pytest.mark.skipif(not is_mllama_available(), reason="mllama module not available") def test_apply_liger_kernel_to_instance_for_mllama_for_causal_lm(): # Ensure any monkey patching is cleaned up for subsequent tests with patch("transformers.models.mllama.modeling_mllama"): @@ -676,6 +697,7 @@ def test_apply_liger_kernel_to_instance_for_qwen2(): ) == inspect.getsource(LigerRMSNorm.forward) +@pytest.mark.skipif(not is_qwen2_vl_available(), reason="qwen2_vl module not available") def test_apply_liger_kernel_to_instance_for_qwen2_vl(): # Ensure any monkey patching is cleaned up for subsequent tests with patch("transformers.models.qwen2_vl.modeling_qwen2_vl"): From f7fb0aac8dfd26b6be34687003a75fc5e76c048f Mon Sep 17 00:00:00 2001 From: ByronHsu Date: Tue, 12 Nov 2024 20:14:54 +0000 Subject: [PATCH 4/4] improve ci test name --- .github/workflows/ci.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 86ec2b581..16d319862 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -61,7 +61,7 @@ jobs: python -m pip install --upgrade pip pip install modal - - name: Run unit tests + - name: Run tests run: | modal run dev.modal.tests @@ -86,6 +86,6 @@ jobs: python -m pip install --upgrade pip pip install modal - - name: Run unit tests + - name: Run tests run: | modal run dev.modal.tests_bwd \ No newline at end of file