Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support CE after grad acc fix #375

Merged
merged 4 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,31 @@ jobs:
python -m pip install --upgrade pip
pip install modal

- name: Run unit tests
- name: Run tests
run: |
modal run dev.modal.tests

# Backward-compatibility test job: runs the dev.modal.tests_bwd Modal entry point.
# It only starts after the `checkstyle` job, as declared in `needs`.
tests-bwd:
runs-on: ubuntu-latest
needs: [checkstyle]
# Modal credentials are read from repository secrets and exposed to every step.
env:
MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}

steps:
- name: Checkout code
uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v3
with:
python-version: '3.10'

# Only the `modal` client is installed on the runner; the test commands
# themselves are launched through `modal run` below.
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install modal

- name: Run tests
run: |
modal run dev.modal.tests_bwd
5 changes: 3 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ checkstyle:
# Command to run pytest for convergence tests
# We must explicitly set HF_DATASETS_OFFLINE=1; otherwise `datasets` silently tries to send metrics and times out (80s). See https://github.com/huggingface/datasets/blob/37a603679f451826cfafd8aae00738b01dcb9d58/src/datasets/load.py#L286
test-convergence:
HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence

HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/test_mini_models.py
HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/test_mini_models_multimodal.py
HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/test_mini_models_with_logits.py

# Command to run all benchmark scripts and update benchmarking data file
# By default this doesn't overwrite existing data for the same benchmark experiment
Expand Down
28 changes: 28 additions & 0 deletions dev/modal/tests_bwd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from pathlib import Path

import modal

# Three directory levels above this file (dev/modal/tests_bwd.py), i.e. the
# repository root that contains pyproject.toml.
ROOT_PATH = Path(__file__).parent.parent.parent

# tests_bwd exists to check that liger stays backward compatible with older
# transformers releases: the image installs the project's dependencies from
# pyproject.toml (including the "dev" optional group) and then installs the
# pinned transformers==4.44.2.
image = (
modal.Image.debian_slim()
.pip_install_from_pyproject(
ROOT_PATH / "pyproject.toml", optional_dependencies=["dev"]
)
.pip_install("transformers==4.44.2")
)

# Modal app named "liger_tests"; functions registered on it use the image above.
app = modal.App("liger_tests", image=image)

# mount: add local files to the remote container.
# The repository root is made available at /root/liger-kernel, which is the
# working directory used by the test commands in this file.
repo = modal.Mount.from_local_dir(ROOT_PATH, remote_path="/root/liger-kernel")


@app.function(gpu="A10G", mounts=[repo], timeout=60 * 10)
def liger_tests():
    """Install the mounted repo in editable mode, then run the make test targets.

    The commands run one after another inside /root/liger-kernel; because
    each is invoked with ``check=True``, the first non-zero exit status
    raises ``subprocess.CalledProcessError`` and the remaining commands
    are not run.
    """
    import subprocess

    workdir = "/root/liger-kernel"
    commands = (
        ["pip", "install", "-e", "."],
        ["make", "test"],
        ["make", "test-convergence"],
    )
    for command in commands:
        subprocess.run(command, check=True, cwd=workdir)
33 changes: 32 additions & 1 deletion src/liger_kernel/transformers/functional.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
from liger_kernel.ops.fused_linear_cross_entropy import (
LigerFusedLinearCrossEntropyFunction,
Expand All @@ -13,7 +15,6 @@
from liger_kernel.ops.swiglu import LigerSiLUMulFunction

liger_swiglu = LigerSiLUMulFunction.apply
liger_cross_entropy = LigerCrossEntropyFunction.apply
liger_fused_linear_cross_entropy = LigerFusedLinearCrossEntropyFunction.apply
liger_geglu = LigerGELUMulFunction.apply
liger_rms_norm = LigerRMSNormFunction.apply
Expand All @@ -23,3 +24,33 @@
liger_jsd = LigerJSDFunction.apply
liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply
liger_group_norm = LigerGroupNormFunction.apply


# The signature conforms to torch.nn.functional.cross_entropy:
# https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html
def liger_cross_entropy(
    input,
    target,
    weight=None,
    size_average=None,
    ignore_index: int = -100,
    reduce=None,
    reduction: str = "mean",
    label_smoothing: float = 0.0,
    lse_square_scale: float = 0.0,
    softcap: Optional[float] = None,
    return_z_loss: bool = False,
):
    """Functional wrapper around ``LigerCrossEntropyFunction``.

    ``weight``, ``size_average`` and ``reduce`` are accepted only so the
    signature lines up with ``torch.nn.functional.cross_entropy``; they are
    placeholders, are not implemented yet, and are never forwarded to the
    kernel.

    Args:
        input: Forwarded to the kernel as its first argument.
        target: Forwarded to the kernel as its second argument.
        weight: Placeholder; ignored.
        size_average: Placeholder; ignored.
        ignore_index: Forwarded to the kernel.
        reduce: Placeholder; ignored.
        reduction: Forwarded to the kernel.
        label_smoothing: Forwarded to the kernel.
        lse_square_scale: Forwarded to the kernel.
        softcap: Forwarded to the kernel.
        return_z_loss: Forwarded to the kernel; also selects the return shape.

    Returns:
        ``loss`` alone when ``return_z_loss`` is false, otherwise the tuple
        ``(loss, z_loss)``.
    """
    # The positional order below is the kernel's order, which differs from
    # this function's own parameter order.
    kernel_args = (
        input,
        target,
        ignore_index,
        lse_square_scale,
        label_smoothing,
        reduction,
        softcap,
        return_z_loss,
    )
    loss, z_loss = LigerCrossEntropyFunction.apply(*kernel_args)
    return (loss, z_loss) if return_z_loss else loss
62 changes: 55 additions & 7 deletions src/liger_kernel/transformers/monkey_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from transformers import PreTrainedModel

from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.geglu import LigerGEGLUMLP
from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward
Expand Down Expand Up @@ -111,8 +112,16 @@ def apply_liger_kernel_to_llama(
modeling_llama.LlamaRMSNorm = LigerRMSNorm
if swiglu:
modeling_llama.LlamaMLP = LigerSwiGLUMLP

if cross_entropy:
modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
from transformers.loss.loss_utils import nn

nn.functional.cross_entropy = liger_cross_entropy
else:
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss

if fused_linear_cross_entropy:
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
Expand Down Expand Up @@ -192,7 +201,13 @@ def apply_liger_kernel_to_mllama(
if swiglu:
modeling_mllama.MllamaTextMLP = LigerSwiGLUMLP
if cross_entropy:
modeling_mllama.CrossEntropyLoss = LigerCrossEntropyLoss
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
from transformers.loss.loss_utils import nn

nn.functional.cross_entropy = liger_cross_entropy
else:
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
modeling_mllama.CrossEntropyLoss = LigerCrossEntropyLoss
if fused_linear_cross_entropy:
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward
Expand Down Expand Up @@ -342,7 +357,14 @@ def apply_liger_kernel_to_mixtral(
if rms_norm:
modeling_mixtral.MixtralRMSNorm = LigerRMSNorm
if cross_entropy:
modeling_mixtral.CrossEntropyLoss = LigerCrossEntropyLoss
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
from transformers.loss.loss_utils import nn

nn.functional.cross_entropy = liger_cross_entropy
else:
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
modeling_mixtral.CrossEntropyLoss = LigerCrossEntropyLoss

if fused_linear_cross_entropy:
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward
Expand Down Expand Up @@ -417,7 +439,13 @@ def apply_liger_kernel_to_gemma(
if rms_norm:
modeling_gemma.GemmaRMSNorm = LigerRMSNormForGemma
if cross_entropy:
modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
from transformers.loss.loss_utils import nn

nn.functional.cross_entropy = liger_cross_entropy
else:
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss
if geglu:
modeling_gemma.GemmaMLP = LigerGEGLUMLP
if fused_linear_cross_entropy:
Expand Down Expand Up @@ -474,6 +502,7 @@ def apply_liger_kernel_to_gemma2(
assert not (
cross_entropy and fused_linear_cross_entropy
), "cross_entropy and fused_linear_cross_entropy cannot both be True."

from transformers.models.gemma2 import modeling_gemma2
from transformers.models.gemma2.modeling_gemma2 import Gemma2Model

Expand All @@ -490,7 +519,13 @@ def apply_liger_kernel_to_gemma2(
# https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109
modeling_gemma2.Gemma2RMSNorm = LigerRMSNormForGemma2
if cross_entropy:
modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
from transformers.loss.loss_utils import nn

nn.functional.cross_entropy = liger_cross_entropy
else:
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
if fused_linear_cross_entropy:
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward
Expand Down Expand Up @@ -562,8 +597,15 @@ def apply_liger_kernel_to_qwen2(
modeling_qwen2.apply_rotary_pos_emb = liger_rotary_pos_emb
if rms_norm:
modeling_qwen2.Qwen2RMSNorm = LigerRMSNorm

if cross_entropy:
modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
from transformers.loss.loss_utils import nn

nn.functional.cross_entropy = liger_cross_entropy
else:
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss

# import pdb; pdb.set_trace()
if fused_linear_cross_entropy:
Expand Down Expand Up @@ -710,7 +752,13 @@ def apply_liger_kernel_to_phi3(
if swiglu:
modeling_phi3.Phi3MLP = LigerPhi3SwiGLUMLP
if cross_entropy:
modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
from transformers.loss.loss_utils import nn

nn.functional.cross_entropy = liger_cross_entropy
else:
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
if fused_linear_cross_entropy:
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
Expand Down
Loading
Loading