Squashed commit of the following:
commit acd8272
Author: Hanson Wang <hansonw@users.noreply.github.com>
Date:   Mon Sep 9 15:30:09 2024 -0700

    Optimize fused_linear_cross_entropy when weight does not require grads (linkedin#237)

    ## Summary

    Add some easy checks for `weight.requires_grad` to skip allocating and
    calculating weight gradients when they're not needed. The weight-gradient
    matrix can be pretty large, so this can also be a significant memory
    saving.

    Also, a small micro-optimization: skip the `.item()` call on
    `total_n_non_ignore` (the subsequent calculations work fine with the
    tensor form) to defer CUDA synchronization; otherwise the host would wait
    for all the `torch.zeros` initializations on the preceding lines to
    finish, which may take a non-trivial amount of time.
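
    A rough sketch of both changes (tensor shapes here are illustrative; the
    real logic lives in `fused_linear_cross_entropy_forward`):

    ```python
    import torch

    weight = torch.randn(8, 4, requires_grad=False)  # e.g. a frozen lm_head weight
    target = torch.tensor([1, 2, -100, 3])
    ignore_index = -100

    # 1) Only allocate the (potentially huge) weight-gradient buffer when needed.
    grad_weight = torch.zeros_like(weight) if weight.requires_grad else None

    # 2) Keep the count as a 0-dim tensor; calling .item() here would force the
    #    host to synchronize with the CUDA stream before later kernels launch.
    total_n_non_ignore = (target != ignore_index).sum()
    ```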

    ## Testing Done

    The existing unit test already has a case where the weight does not have
    gradients enabled, and the forward/backward passes still succeed:
    https://github.com/linkedin/Liger-Kernel/blob/main/test/transformers/test_fused_linear_cross_entropy.py#L165

    And the preceding test verifies the 'normal' case where the weight
    gradients are needed.

    - Hardware Type: A100 80G
    - [x] run `make test` to ensure correctness
    - [x] run `make checkstyle` to ensure code style
    - [x] run `make test-convergence` to ensure convergence

commit b5d8cbf
Author: Tyler Romero <tyler.alexander.romero@gmail.com>
Date:   Sun Sep 8 14:14:45 2024 -0700

    Monkeypatch for Qwen2-VL (linkedin#175)

    ## Summary
    Monkeypatch for the recently-published
    [Qwen2-VL](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct).
    HF `transformers` modeling code:
    https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py

    Feature Request: linkedin#165

    ## Details
    Qwen2-VL is available on `transformers` main but has not yet been
    published in a release.
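
    A minimal usage sketch (requires a `transformers` build that includes
    Qwen2-VL; the patch must be applied before the model is instantiated):

    ```python
    import transformers

    from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl

    # Patches the HF Qwen2-VL modeling code in place (RMSNorm, LayerNorm,
    # SwiGLU, and fused linear cross entropy by default).
    apply_liger_kernel_to_qwen2_vl()

    model = transformers.Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-7B-Instruct"
    )
    ```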

    ## Testing Done
    - Hardware Type: 4090
    - [x] run `make test` to ensure correctness
    - [x] run `make checkstyle` to ensure code style
    - [x] run `make test-convergence` to ensure convergence

    ---------

    Co-authored-by: Shao Tang <tangshao28@gmail.com>

commit 9250546
Author: S1ro <54212263+S1ro1@users.noreply.github.com>
Date:   Sun Sep 8 03:44:04 2024 +0200

    Feat: add kl div to readme (linkedin#229)

    ## Summary
    Adds the newly implemented KL divergence loss to the README. Finally
    closes linkedin#188.

    ## Testing Done
    No code changes

    ---------

    Co-authored-by: Shao Tang <tangshao28@gmail.com>
    Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>

commit 1cdb7f0
Author: S1ro <54212263+S1ro1@users.noreply.github.com>
Date:   Sun Sep 8 03:19:19 2024 +0200

    Refactor/benchmarking visualizer (linkedin#212)

    ## Summary
    Implements a new script, `benchmark/benchmarks_visualizer.py`, that
    replaces the functionality provided by the current
    `benchmark/benchmarks_visualizer.ipynb`. Resolves linkedin#211.

    ## Details
    ```console
    $ python3 benchmarks_visualizer.py --help
    usage: benchmarks_visualizer.py [-h] --kernel-name KERNEL_NAME --metric-name METRIC_NAME --kernel-operation-mode KERNEL_OPERATION_MODE [--display] [--overwrite]

    options:
      -h, --help            show this help message and exit
      --kernel-name KERNEL_NAME
                            Kernel name to benchmark
      --metric-name METRIC_NAME
                            Metric name to visualize (speed/memory)
      --kernel-operation-mode KERNEL_OPERATION_MODE
                            Kernel operation mode to visualize (forward/backward/full)
      --display             Display the visualization
      --overwrite           Overwrite the existing visualization; if none exists, this flag has no effect, as one is always created
      ```
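
    For example, a typical invocation might look like this (the kernel name is
    illustrative):

    ```console
    $ python3 benchmarks_visualizer.py --kernel-name fused_linear_cross_entropy --metric-name speed --kernel-operation-mode full --display
    ```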

    ## Testing Done

    - Hardware Type: <BLANK>
    - [ ] run `make test` to ensure correctness
    - [ ] run `make checkstyle` to ensure code style
    - [ ] run `make test-convergence` to ensure convergence

    ---------

    Co-authored-by: Shao Tang <tangshao28@gmail.com>

commit 18fd280
Author: Wizyoung <happyyanghehe@gmail.com>
Date:   Sun Sep 8 07:56:04 2024 +0800

    (fix) fix pyproject.toml (linkedin#226)

    ## Summary
    In linkedin#218, I fixed the `tool.setuptools.packages.find` field and
    tested it only in editable mode with `pip install -e .`. However, in
    production mode with `pip install .`, only the `env_report.py` file is
    copied to the Python site-packages directory. To fix this, adding
    `"liger_kernel.*"` to the include list ensures that setuptools correctly
    includes all subpackages within `liger_kernel`, as sketched below.
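
    As a sketch, the resulting field looks roughly like this (assuming the
    repository's `src/` layout; surrounding fields omitted):

    ```toml
    [tool.setuptools.packages.find]
    where = ["src"]
    include = ["liger_kernel", "liger_kernel.*"]
    ```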

    ## Testing Done

    - Hardware Type: <BLANK>
    - [ ] run `make test` to ensure correctness
    - [ ] run `make checkstyle` to ensure code style
    - [ ] run `make test-convergence` to ensure convergence

    ---------

    Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>

commit 638b310
Author: Wizyoung <happyyanghehe@gmail.com>
Date:   Sat Sep 7 11:53:15 2024 +0800

    add repr information for layer_norm and rms_norm (linkedin#220)

    ## Summary
    Add repr information to the LayerNorm and RMSNorm classes so that useful
    layer information is displayed when the model is printed. Other classes
    are not modified because they either inherit from the related `torch.nn`
    classes or contain `torch.nn` sub-modules.
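
    A sketch of the idea via `torch.nn.Module.extra_repr` (the fields shown
    are illustrative, not the actual implementation):

    ```python
    import torch


    class LigerRMSNorm(torch.nn.Module):
        def __init__(self, hidden_size, eps=1e-6):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.ones(hidden_size))
            self.variance_epsilon = eps

        def extra_repr(self):
            # Included when this module (or a model containing it) is printed.
            return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


    print(LigerRMSNorm(64))  # LigerRMSNorm((64,), eps=1e-06)
    ```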

    ## Testing Done

    - Hardware Type: <BLANK>
    - [x] run `make test` to ensure correctness
    - [x] run `make checkstyle` to ensure code style
    - [x] run `make test-convergence` to ensure convergence

    ---------

    Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>
    Co-authored-by: Shao Tang <tangshao28@gmail.com>

commit 07804e4
Author: Ivan Yashchuk <IvanYashchuk@users.noreply.github.com>
Date:   Sat Sep 7 06:30:32 2024 +0300

    Update swiglu and geglu forward: zeros_like -> empty_like (linkedin#217)

    ## Summary
    This PR improves the performance of the SwiGLU and GeGLU forward passes by
    replacing `zeros_like` with `empty_like`, which doesn't require a separate
    kernel launch.
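
    Conceptually, the change is just the following (a minimal sketch assuming
    a CUDA device; it is safe here because the forward kernel writes every
    element of the buffer before it is read):

    ```python
    import torch

    x = torch.randn(4, 8, device="cuda")

    # Before: zero-fills the buffer, costing an extra kernel launch.
    out = torch.zeros_like(x)

    # After: only allocates; contents are uninitialized until the forward
    # kernel overwrites them, which it always does.
    out = torch.empty_like(x)
    ```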

    ## Testing Done
    Testing is covered by existing `test_geglu.py` and `test_swiglu.py`.

    - Hardware Type: A100-80G-PCIe
    - [x] run `make test` to ensure correctness
    - [x] run `make checkstyle` to ensure code style
    - [x] run `make test-convergence` to ensure convergence

    ---------

    Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>
    Co-authored-by: Shao Tang <tangshao28@gmail.com>

commit 6a75ddc
Author: Byron Hsu <byronhsu1230@gmail.com>
Date:   Fri Sep 6 20:16:05 2024 -0700

    Update README.md

commit 8cf49e2
Author: Byron Hsu <byronhsu1230@gmail.com>
Date:   Fri Sep 6 20:13:51 2024 -0700

    Update README.md

commit 53dcf02
Author: Wizyoung <happyyanghehe@gmail.com>
Date:   Sat Sep 7 07:13:28 2024 +0800

    (fix) fix pyproject.toml (linkedin#218)

    ## Summary
    Fix the `tool.setuptools.packages.find` field in pyproject.toml.
    Otherwise, in local build mode with `pip install .`, Python fails to
    locate `liger_kernel`.

    Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>

commit b42a27b
Author: Steven Shimizu <shimizust@gmail.com>
Date:   Fri Sep 6 14:16:41 2024 -0700

    Added HF use-case benchmark script (linkedin#223)

    ## Summary
    - Added the Hugging Face training benchmarking script used for the tech report
    - Writes files to
    `/results/${MODEL_TYPE}_use_liger_${USE_LIGER}_batch_size_${BATCH_SIZE}_rep_${i}.log`

    ## Testing Done
    - Ran benchmarking script

    - Hardware Type: A100
    - [x] run `make test` to ensure correctness
    - [x] run `make checkstyle` to ensure code style
    - [x] run `make test-convergence` to ensure convergence

commit 43cbd4e
Author: Tcc0403 <76503978+Tcc0403@users.noreply.github.com>
Date:   Sat Sep 7 05:07:01 2024 +0800

    Add label smoothing for cross entropy (linkedin#198)

    ## Summary
    Aims to solve linkedin#81.

    ## Details

    ### For loss:
    Label smoothing regularization (LSR) replaces the label distribution
    $q(k) = \delta_{k,y}$ with
    ```math
    q'(k) = (1 - \epsilon)\delta_{k,y} + \frac{\epsilon}{K}.
    ```
    The cross entropy with LSR is then

    ```math
    \begin{align}
    L' = H(q', p) &= -\sum^K_{k=1} \log p(k)\, q'(k) = -\sum^K_{k=1} \log p(k) \left( (1 - \epsilon)\delta_{k,y} + \frac{\epsilon}{K} \right)\\
                  &= -(1 - \epsilon) \sum^K_{k=1} \log p(k)\, q(k) - \frac{\epsilon}{K} \sum^K_{k=1} \log p(k) \\
                  &= (1 - \epsilon) H(q, p) + \frac{\epsilon}{K} \left( -\sum^K_{k=1} \log \mathrm{softmax}(x_k) \right)\\
                  &= (1 - \epsilon) L + \frac{\epsilon}{K}\, \mathrm{SmoothLoss},
    \end{align}
    ```
    where $L = H(q,p)$ is the original loss and $\mathrm{SmoothLoss} = -\sum^K_{k=1} \log \mathrm{softmax}(x_k)$ is the smooth loss.

    ### For gradients:
    The original:
    ```math
    \begin{align}
    \frac{\partial L}{\partial x_i} &= p(i) - q(i)\\
                                    &= \begin{cases}
                                           \mathrm{softmax}(x_i),     & i \neq y \\
                                           \mathrm{softmax}(x_i) - 1, & i = y
                                       \end{cases}
    \end{align}
    ```
    With LSR:
    ```math
    \begin{align}
    \frac{\partial L'}{\partial x_i} &= p(i) - q'(i)\\
                                     &= \mathrm{softmax}(x_i) - (1 - \epsilon)\delta_{i,y} - \frac{\epsilon}{K}\\
                                     &= \begin{cases}
                                            \mathrm{softmax}(x_i) - \frac{\epsilon}{K},                  & i \neq y \\
                                            \mathrm{softmax}(x_i) - \frac{\epsilon}{K} - (1 - \epsilon), & i = y
                                        \end{cases}
    \end{align}
    ```

    We can handle the $i = y$ case by simply adding $-(1-\epsilon)$ to the
    $y$-th entry after computing all $i$.
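
    A quick numerical check of this decomposition against PyTorch's built-in
    `label_smoothing` option (shapes and seed are illustrative):

    ```python
    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    K, eps = 10, 0.1
    x = torch.randn(4, K)          # logits
    y = torch.randint(0, K, (4,))  # targets

    log_p = F.log_softmax(x, dim=-1)
    ce = F.nll_loss(log_p, y)             # L = H(q, p)
    smooth = -log_p.sum(dim=-1).mean()    # SmoothLoss
    manual = (1 - eps) * ce + (eps / K) * smooth

    builtin = F.cross_entropy(x, y, label_smoothing=eps)
    assert torch.allclose(manual, builtin)
    ```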

    Reference:
    [Rethinking the Inception Architecture for Computer
    Vision](https://arxiv.org/abs/1512.00567)

    ## Testing Done
    Add a unit test for label smoothing.

    - Hardware Type: RTX-3080
    - [x] run `make test` to ensure correctness
    - [x] run `make checkstyle` to ensure code style
    - [x] run `make test-convergence` to ensure convergence
    ```bash
    ❯ python3 -m pytest test/transformers/test_cross_entropy.py
    ============================================ test session starts =============================================
    platform linux -- Python 3.10.12, pytest-8.3.2, pluggy-1.5.0
    rootdir: /home/tcc/Liger-Kernel
    collected 94 items

    test/transformers/test_cross_entropy.py .............................................................. [ 65%]
    ...............................F                                                                       [100%]

    ================================================== FAILURES ==================================================
    __________________________________ test_large_no_exception[8-16384-128256] ___________________________________

    B = 8, T = 16384, V = 128256

        @pytest.mark.parametrize(
            "B, T, V",
            [
                (
                    8,
                    8192,
                    128256,
                ),  # _input = 16GB, total = ~32GB, 8405385216 > 2,147,483,647, so we need int64
                (8, 16384, 128256),  # _input = 32GB, total = ~64GB
            ],
        )
        # @pytest.mark.skipif(
        #     torch.cuda.get_device_properties(0).total_memory < 64 * 1000 * 1000 * 1000,
        #     reason="Needs 64GB+ GPU memory.",
        # )
        def test_large_no_exception(B, T, V):
            # The large inputs were hitting cuda illegal memory access because of
            # triton-lang/triton#1058
    >       _full_pass_once(B, T, V)

    test/transformers/test_cross_entropy.py:401:
    _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

    B = 8, T = 16384, V = 128256

        def _full_pass_once(B, T, V):
            torch.manual_seed(0)
            liger_ce = LigerCrossEntropyLoss()

    >       _input = torch.randn(
                B * T, V, requires_grad=True, device="cuda", dtype=torch.bfloat16
            )
    E       torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 31.31 GiB. GPU 0 has a total capacity of 10.00 GiB of which 8.84 GiB is free. Including non-PyTorch memory, this process has 17179869184.00 GiB memory in use. Of the allocated memory 0 bytes is allocated by PyTorch, and 0 bytes is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

    test/transformers/test_cross_entropy.py:374: OutOfMemoryError
    ========================================== short test summary info ===========================================
    FAILED test/transformers/test_cross_entropy.py::test_large_no_exception[8-16384-128256] - torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 31.31 GiB. GPU 0 has a total capacity of 10...
    ================================== 1 failed, 93 passed in 130.88s (0:02:10) ==================================
    ```
    ```bash
    ❯ make test
    python -m pytest --disable-warnings test/ --ignore=test/convergence
    ============================================ test session starts =============================================
    platform linux -- Python 3.10.12, pytest-8.3.2, pluggy-1.5.0
    rootdir: /home/tcc/Liger-Kernel
    collected 256 items

    test/transformers/test_auto_model.py .                                                                 [  0%]
    test/transformers/test_cross_entropy.py ssssssssssssssssssssssss............ssssssssssssssssssssssssss [ 24%]
    ssssssssssssssssssssssssssssssss                                                                       [ 37%]
    test/transformers/test_embedding.py ...........                                                        [ 41%]
    test/transformers/test_fused_linear_cross_entropy.py ................                                  [ 47%]
    test/transformers/test_geglu.py ............                                                           [ 52%]
    test/transformers/test_layer_norm.py ................                                                  [ 58%]
    test/transformers/test_monkey_patch.py .....                                                           [ 60%]
    test/transformers/test_rms_norm.py ............................................................        [ 83%]
    test/transformers/test_rope.py ..................                                                      [ 91%]
    test/transformers/test_swiglu.py ....................                                                  [ 98%]
    test/transformers/test_trainer_integration.py .                                                        [ 99%]
    test/triton/test_triton_monkey_patch.py ..                                                             [100%]

    ================================ 174 passed, 82 skipped in 123.06s (0:02:03) =================================
    ```
    ```bash
    ❯ make checkstyle
    flake8 .; flake8_status=$?; \
    isort .; isort_status=$?; \
    black .; black_status=$?; \
    if [ $flake8_status -ne 0 ] || [ $isort_status -ne 0 ] || [ $black_status -ne 0 ]; then \
            exit 1; \
    fi
    Skipped 2 files
    All done! ✨ 🍰 ✨
    68 files left unchanged.
    ```
    ```bash
    ❯ make test-convergence
    HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence
    ============================================ test session starts =============================================
    platform linux -- Python 3.10.12, pytest-8.3.2, pluggy-1.5.0
    rootdir: /home/tcc/Liger-Kernel
    collected 30 items

    test/convergence/test_mini_models.py ..............                                                    [ 46%]
    test/convergence/test_mini_models_no_logits.py ................                                        [100%]

    ======================================= 30 passed in 223.18s (0:03:43) =======================================
    ```

commit 376fe0c
Author: Yanning Chen <momochenonline@gmail.com>
Date:   Fri Sep 6 13:10:02 2024 -0700

    Reference Unsloth in header (linkedin#216)

    ## Summary
    Reference Unsloth in header section

    ## Testing Done

    - Hardware Type: <BLANK>
    - [x] run `make test` to ensure correctness
    - [x] run `make checkstyle` to ensure code style
    - [x] run `make test-convergence` to ensure convergence

commit c844f78
Author: Byron Hsu <byronhsu1230@gmail.com>
Date:   Fri Sep 6 13:08:22 2024 -0700

    Update README.md

commit ec68ac0
Author: Byron Hsu <byronhsu1230@gmail.com>
Date:   Fri Sep 6 13:07:18 2024 -0700

    Add license in ack section (linkedin#224)

    ## Summary

    ## Testing Done

    - Hardware Type: <BLANK>
    - [ ] run `make test` to ensure correctness
    - [ ] run `make checkstyle` to ensure code style
    - [ ] run `make test-convergence` to ensure convergence

commit ec63200
Author: Byron Hsu <byronhsu1230@gmail.com>
Date:   Fri Sep 6 12:58:33 2024 -0700

    Elaborate ack section (linkedin#222)

    ## Summary

    ## Testing Done

    - Hardware Type: <BLANK>
    - [ ] run `make test` to ensure correctness
    - [ ] run `make checkstyle` to ensure code style
    - [ ] run `make test-convergence` to ensure convergence
wizyoung committed Sep 10, 2024 · 1 parent a774cbe · commit c67720f
Showing 11 changed files with 696 additions and 27 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -12,5 +12,8 @@ site/
 build/
 dist/
 
+# Lockfiles
+uv.lock
+
 # Benchmark images
 benchmark/visualizations
3 changes: 2 additions & 1 deletion README.md
@@ -91,7 +91,7 @@ With one line of code, Liger Kernel can increase throughput by more than 20% and
 - **Exact:** Computation is exact—no approximations! Both forward and backward passes are implemented with rigorous unit tests and undergo convergence testing against training runs without Liger Kernel to ensure accuracy.
 - **Lightweight:** Liger Kernel has minimal dependencies, requiring only Torch and Triton—no extra libraries needed! Say goodbye to dependency headaches!
 - **Multi-GPU supported:** Compatible with multi-GPU setups (PyTorch FSDP, DeepSpeed, DDP, etc.).
-- **Trainer Framework Integration**: [Axolotl](https://github.com/axolotl-ai-cloud/axolotl), [LLaMa-Factory](https://github.com/hiyouga/LLaMA-Factory), [SFTTrainer](https://github.com/huggingface/trl/releases/tag/v0.10.1), [Hugging Face Trainer](https://github.com/huggingface/transformers/pull/32860)
+- **Trainer Framework Integration**: [Axolotl](https://github.com/axolotl-ai-cloud/axolotl), [LLaMa-Factory](https://github.com/hiyouga/LLaMA-Factory), [SFTTrainer](https://github.com/huggingface/trl/releases/tag/v0.10.1), [Hugging Face Trainer](https://github.com/huggingface/transformers/pull/32860)
 
 ## Target Audiences
 
@@ -227,6 +227,7 @@ loss.backward()
 | Gemma1 | `liger_kernel.transformers.apply_liger_kernel_to_gemma` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Gemma2 | `liger_kernel.transformers.apply_liger_kernel_to_gemma2` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss |
 | Qwen2 | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+| Qwen2-VL | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Phi3 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
54 changes: 30 additions & 24 deletions src/liger_kernel/ops/fused_linear_cross_entropy.py
@@ -13,7 +13,7 @@
 
 
 def fused_linear_cross_entropy_forward(
-    _input, weight, target, bias=None, ignore_index=-100, label_smoothing=0.0, softcap_value=None
+    _input, weight, target, bias=None, ignore_index=-100, label_smoothing=0.0
 ):
     dtype = (
         torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else _input.dtype
@@ -37,13 +37,16 @@ def fused_linear_cross_entropy_forward(
     )  # (BT + inc_factor - 1) // inc_factor
     num_chunks = triton.cdiv(BT, chunk_size)  # (BT + chunk_size - 1) // chunk_size
 
-    grad_weight = torch.zeros_like(weight, device=device)
+    grad_weight = (
+        torch.zeros_like(weight, device=device) if weight.requires_grad else None
+    )
     grad_input = torch.zeros_like(_input, device=device)
     grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None
     # we use fp32 for loss accumulator
     loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
 
-    total_n_non_ignore = (target != ignore_index).sum().item()
+    # NOTE: skip .item() here to avoid CUDA synchronization
+    total_n_non_ignore = (target != ignore_index).sum()
 
     for chunk_id in range(num_chunks):
         start_idx = chunk_id * chunk_size
@@ -103,14 +106,16 @@ def fused_linear_cross_entropy_forward(
             n_non_ignore / total_n_non_ignore
         )  # chunk_size x V
         grad_input[start_idx:end_idx] = grad_logits_chunk @ weight
-        torch.addmm(
-            input=grad_weight,
-            mat1=logits_chunk.t(),
-            mat2=_input_chunk,
-            out=grad_weight,
-            alpha=n_non_ignore / total_n_non_ignore,
-            beta=1.0,
-        )
+
+        if grad_weight is not None:
+            torch.addmm(
+                input=grad_weight,
+                mat1=logits_chunk.t(),
+                mat2=_input_chunk,
+                out=grad_weight,
+                alpha=n_non_ignore / total_n_non_ignore,
+                beta=1.0,
+            )
 
         if bias is not None:
             torch.add(
@@ -145,17 +150,18 @@ def fused_linear_cross_entropy_backward(
     )
 
     # handle grad_weight
-    V, H = grad_weight.shape
-    n_rows = V
+    if grad_weight is not None:
+        V, H = grad_weight.shape
+        n_rows = V
 
-    element_mul_kernel[(n_rows,)](
-        grad_weight,
-        grad_weight.stride(-2),
-        grad_output,
-        H,
-        BLOCK_SIZE=BLOCK_SIZE,
-        num_warps=32,
-    )
+        element_mul_kernel[(n_rows,)](
+            grad_weight,
+            grad_weight.stride(-2),
+            grad_output,
+            H,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=32,
+        )
 
     if grad_bias is not None:
         V = grad_bias.shape[0]
@@ -175,7 +181,7 @@
 class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
     @staticmethod
     def forward(
-        ctx, _input, weight, target, bias=None, ignore_index=-100, label_smoothing=0.0, softcap_value=None
+        ctx, _input, weight, target, bias=None, ignore_index=-100, label_smoothing=0.0
     ):
         """
         Fusing the last linear layer with cross-entropy loss
@@ -193,12 +199,12 @@ def forward(
             ignore_index: the index to ignore in the target
         """
         loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
-            _input, weight, target, bias, ignore_index, label_smoothing, softcap_value
+            _input, weight, target, bias, ignore_index, label_smoothing
         )
         # downcast to dtype and store for backward
         ctx.save_for_backward(
             grad_input.detach(),
-            grad_weight.detach(),
+            grad_weight.detach() if grad_weight is not None else None,
             grad_bias.detach() if bias is not None else None,
         )
         return loss
1 change: 1 addition & 0 deletions src/liger_kernel/transformers/__init__.py
@@ -15,6 +15,7 @@
     apply_liger_kernel_to_mixtral,
     apply_liger_kernel_to_phi3,
     apply_liger_kernel_to_qwen2,
+    apply_liger_kernel_to_qwen2_vl,
 )
 from liger_kernel.transformers.rms_norm import LigerRMSNorm  # noqa: F401
 from liger_kernel.transformers.rope import liger_rotary_pos_emb  # noqa: F401
172 changes: 172 additions & 0 deletions src/liger_kernel/transformers/model/qwen2_vl.py
@@ -0,0 +1,172 @@
+from typing import List, Optional, Tuple, Union
+
+import torch
+from torch.nn import CrossEntropyLoss
+from transformers.models.qwen2_vl.modeling_qwen2_vl import (
+    _CONFIG_FOR_DOC,
+    QWEN2_VL_INPUTS_DOCSTRING,
+    Qwen2VLCausalLMOutputWithPast,
+)
+from transformers.utils import (
+    add_start_docstrings_to_model_forward,
+    replace_return_docstrings,
+)
+
+from liger_kernel.transformers.fused_linear_cross_entropy import (
+    LigerFusedLinearCrossEntropyLoss,
+)
+
+
+@add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING)
+@replace_return_docstrings(
+    output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
+)
+def lce_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    pixel_values: Optional[torch.Tensor] = None,
+    pixel_values_videos: Optional[torch.FloatTensor] = None,
+    image_grid_thw: Optional[torch.LongTensor] = None,
+    video_grid_thw: Optional[torch.LongTensor] = None,
+    rope_deltas: Optional[torch.LongTensor] = None,
+) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
+    r"""
+    Copy paste Qwen2VL's forward but replace torch cross entropy with liger fused linear cross entropy
+
+    Args:
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+    Returns:
+
+    Example:
+
+    ```python
+    >>> from PIL import Image
+    >>> import requests
+    >>> from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
+
+    >>> model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
+    >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
+
+    >>> messages = [
+        {
+            "role": "user",
+            "content": [
+                {"type": "image"},
+                {"type": "text", "text": "What is shown in this image?"},
+            ],
+        },
+    ]
+    >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
+    >>> image = Image.open(requests.get(url, stream=True).raw)
+
+    >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    >>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos])
+
+    >>> # Generate
+    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
+    ```"""
+
+    output_attentions = (
+        output_attentions
+        if output_attentions is not None
+        else self.config.output_attentions
+    )
+    output_hidden_states = (
+        output_hidden_states
+        if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+    return_dict = (
+        return_dict if return_dict is not None else self.config.use_return_dict
+    )
+
+    if inputs_embeds is None:
+        inputs_embeds = self.model.embed_tokens(input_ids)
+        if pixel_values is not None:
+            pixel_values = pixel_values.type(self.visual.get_dtype())
+            image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).to(
+                inputs_embeds.device
+            )
+            image_mask = input_ids == self.config.image_token_id
+            if self.training:
+                inputs_embeds = inputs_embeds.clone()
+            inputs_embeds[image_mask] = image_embeds
+        if pixel_values_videos is not None:
+            pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
+            video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw).to(
+                inputs_embeds.device
+            )
+            video_mask = input_ids == self.config.video_token_id
+            inputs_embeds[video_mask] = video_embeds
+        if attention_mask is not None:
+            attention_mask = attention_mask.to(inputs_embeds.device)
+
+    outputs = self.model(
+        input_ids=None,
+        position_ids=position_ids,
+        attention_mask=attention_mask,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+    )
+
+    hidden_states = outputs[0]
+
+    loss = None
+    logits = None
+
+    if self.training and (labels is not None):
+        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+
+        # Flatten tokens
+        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
+        shift_labels = shift_labels.view(-1)
+
+        lce = LigerFusedLinearCrossEntropyLoss()
+        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
+    else:
+        logits = self.lm_head(hidden_states)
+        logits = logits.float()
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+            shift_logits = shift_logits.view(-1, self.config.vocab_size)
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            loss = loss_fct(shift_logits, shift_labels)
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        return (loss,) + output if loss is not None else output
+
+    return Qwen2VLCausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+        rope_deltas=rope_deltas,
+    )
50 changes: 50 additions & 0 deletions src/liger_kernel/transformers/monkey_patch.py
@@ -4,6 +4,7 @@
 
 from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
 from liger_kernel.transformers.geglu import LigerGEGLUMLP
+from liger_kernel.transformers.layer_norm import LigerLayerNorm
 from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward
 from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
 from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward
@@ -245,6 +246,54 @@ def apply_liger_kernel_to_qwen2(
     modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP
 
 
+def apply_liger_kernel_to_qwen2_vl(
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    layer_norm: bool = True,
+    swiglu: bool = True,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Qwen2-VL models.
+    NOTE: Qwen2-VL is not available in transformers<=4.44.2
+
+    Args:
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+    """
+    assert not (
+        cross_entropy and fused_linear_cross_entropy
+    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
+
+    from transformers.models.qwen2_vl import modeling_qwen2_vl
+
+    from liger_kernel.transformers.model.qwen2_vl import (
+        lce_forward as qwen2_vl_lce_forward,
+    )
+
+    # TODO: Support Qwen2-VL's multimodal RoPE implementation
+
+    if rms_norm:
+        # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L439
+        modeling_qwen2_vl.Qwen2RMSNorm = partial(
+            LigerRMSNorm, init_fn="ones", casting_mode="gemma"
+        )
+    if layer_norm:
+        modeling_qwen2_vl.LayerNorm = LigerLayerNorm
+    if cross_entropy:
+        modeling_qwen2_vl.CrossEntropyLoss = LigerCrossEntropyLoss
+    if fused_linear_cross_entropy:
+        modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = qwen2_vl_lce_forward
+    if swiglu:
+        modeling_qwen2_vl.Qwen2MLP = LigerSwiGLUMLP
+
+
 def apply_liger_kernel_to_phi3(
     rope: bool = True,
     cross_entropy: bool = False,
@@ -291,6 +340,7 @@ def apply_liger_kernel_to_phi3(
     "mistral": apply_liger_kernel_to_mistral,
     "mixtral": apply_liger_kernel_to_mixtral,
     "qwen2": apply_liger_kernel_to_qwen2,
+    "qwen2_vl": apply_liger_kernel_to_qwen2_vl,
     "phi3": apply_liger_kernel_to_phi3,
 }
