Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for patching post-model initialization #199

Merged
merged 17 commits into from
Sep 13, 2024

Conversation

shimizust
Copy link
Collaborator

@shimizust shimizust commented Sep 3, 2024

Summary

  • Currently, calling the patching APIs after the model has been initialized will only partially patch with Liger kernels. For example, the following will still be patched:
    • Model forward() method (e.g. modeling_llama.LlamaForCausalLM.forward = lce_forward)
    • module functions (e.g. modeling_llama.apply_rotary_pos_emb = liger_rotary_pos_emb)
      but not any modules that were already instantiated and set as instance variables on the model:
    • For example: modeling_llama.LlamaRMSNorm = LigerRMSNorm will not affect existing LlamaRMSNorm instances
  • This means that integrations with HF Trainer and SFTTrainer only partially work. In the case of SFTTrainer, the current integration only works fully if the user passes a path to the model (Adds experimental Liger support to SFT script huggingface/trl#1992), and SFTTrainer handles calling AutoLigerKernelForCausalLM.from_pretrained. However, both HF Trainer and SFTTrainer allow passing the model instance directly, in which case we would need a way of patching the model post-init.

API

from liger_kernel.transformers import _apply_liger_kernel_to_instance

llama_model = AutoModelForCausalLM.from_pretrained("/path/to/llama", ...)
_apply_liger_kernel_to_instance(model=llama_model)

# can also pass in model-specific args that will get passed into the correct apply_liger_kernel_to_{model_type}
_apply_liger_kernel_to_instance(model=llama_model, rope=False)

Required changes post-PR:

  • Update HF Trainer and SFTTrainer to pass in model instead of model_type to _apply_liger_kernel(model=model)

Testing Done

  • Tested HF example with no patching, pre-init patching, post-init patching using existing method, post-init patching of model instance variables showing that post-init instance patching results in same performance as pre-init patching:
  • Added unit tests that model instances are actually patched correctly and that each patching API supports model instance patching.

Llama

Scenario Memory Usage (GB) Throughput (tokens/sec)
No patching 68.3 9161
Patch post-init (existing method) 39.9 10939
Patch pre-init 38.4 12313
Patch post-init (instance patching) 38.3 12409

Mistral

Scenario Memory Usage (GB) Throughput (tokens/sec)
No patching 34.5 10976
Patch post-init (existing method) 34.5 10812
Patch pre-init 34.5 12286
Patch post-init (instance patching) 34.5 12086

Mixtral
OOM on testing setup, but confirmed patched post-model init loaded correctly and started training

Gemma

Scenario Memory Usage (GB) Throughput (tokens/sec)
No patching 52.8 7199
Patch post-init (existing method) 40.7 9479
Patch pre-init 40.7 9209
Patch post-init (instance patching) 40.7 9981

Gemma2

Scenario Memory Usage (GB) Throughput (tokens/sec)
No patching 66.9 12256
Patch post-init (existing method) 51.0 14288
Patch pre-init 50.8 17084
Patch post-init (instance patching) 46.8 15893
  • Note: Gemma2 pre-init and post-init (instance patching) are not converging, seeing nan gradnorm. Need to investigate separately.

Qwen2

Scenario Memory Usage (GB) Throughput (tokens/sec)
No patching 41.2 10785
Patch post-init (existing method) 36.8 10491
Patch pre-init 36.8 11908
Patch post-init (instance patching) 36.8 12068

Qwen2 VL

  • Patched, but couldn't test since not yet released in latest transformers

Phi3

Scenario Memory Usage (GB) Throughput (tokens/sec)
No patching 49.4 15601
Patch post-init (existing method) 49.4 16259
Patch pre-init 42.8 20383
Patch post-init (instance patching) 48.6 18922
  • Note: For phi3, post-init patching seems to consistently not match pre-init patching for some reason
  • Hardware Type:
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

Copy link
Collaborator

@ByronHsu ByronHsu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

awesome bug catches!

@@ -295,18 +322,30 @@ def apply_liger_kernel_to_phi3(
}


def _apply_liger_kernel(model_type: str = "", **kwargs) -> None:
def _apply_liger_kernel(model: PreTrainedModel = None, model_type: str = "", **kwargs) -> None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i would prefer separate this as another api. maybe apply_liger_kernel_on_instance. Would love the thoughts from @yundai424 @qingquansong

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seconded on separating instance-patch to a different API. Otherwise apply_liger_kernel_to_llama(model=) is just a liiitle bit cumbersome because the model itself has told what model family is it 🤔

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense, I can make it a separate API

Copy link
Collaborator

@qingquansong qingquansong Sep 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be good to still unified as one entry _apply_liger_kernel(model: Union[str, PreTrainedModel]) and can have apply_liger_kernel_on_instance to be used in side (but feel like not necessarily needed to have a separate instance API? 🤔

@shimizust shimizust force-pushed the sshimizu/post-init-patch branch from 9e0857e to 0e124b3 Compare September 6, 2024 06:42
@ryankert01 ryankert01 mentioned this pull request Sep 8, 2024
@shimizust shimizust force-pushed the sshimizu/post-init-patch branch from d1b748a to 0fbf513 Compare September 10, 2024 22:29
@shimizust shimizust changed the title [WIP] Support for patching post-model initialization Support for patching post-model initialization Sep 12, 2024
@shimizust shimizust marked this pull request as ready for review September 12, 2024 15:47
Copy link
Collaborator

@qingquansong qingquansong left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To confirm my understanding, both _apply_liger_kernel_to_instance and _apply_liger_kernel will be used in the HF api underlying based on different cases and we name it with underline to not recommend for users to directly use it (compared to using single model patch function)? 🤔

@shimizust
Copy link
Collaborator Author

To confirm my understanding, both _apply_liger_kernel_to_instance and _apply_liger_kernel will be used in the HF api underlying based on different cases and we name it with underline to not recommend for users to directly use it (compared to using single model patch function)? 🤔

@qingquansong Actually the HF Trainer API would be switched to using _apply_liger_kernel_to_instance. The _apply_liger_kernel is now meant to only be used internally by AutoLigerKernelForCausalLM.

The previous discussion with @ByronHsu was not to expose these methods publicly hence the leading _, although maybe we should re-export this from transformers/init.py to make internal refactoring less disruptive down the road.

@shimizust shimizust force-pushed the sshimizu/post-init-patch branch from 265bc0a to 905b55d Compare September 13, 2024 16:49
@lancerts lancerts merged commit 7a5d484 into linkedin:main Sep 13, 2024
2 checks passed
wizyoung added a commit to wizyoung/Liger-Kernel that referenced this pull request Sep 26, 2024
commit 8e2bd26f67123d37333488206fb8f36614c9567c
Author: Chiwan Park <chiwanpark@hotmail.com>
Date:   Thu Sep 26 00:34:57 2024 +0900

    Fix sharing a ResBlock layer for each head in Medusa example (#269)

    ## Summary
    <!--- This is a required section; please describe the main purpose of
    this proposed code change. --->
    There is a bug of incorrect weight sharing between layers for each
    Medusa head. Since `nn.Module` is Python reference, the original source
    code creates a list containing references to the same weights. This PR
    fixes the bug.

    <!---
    ## Details
    This is an optional section; is there anything specific that reviewers
    should be aware of?
    --->

    ## Testing Done
    <!--- This is a required section; please describe how this change was
    tested. --->

    <!--
    Replace BLANK with your device type. For example, A100-80G-PCIe

    Complete the following tasks before sending your PR, and replace `[ ]`
    with
    `[x]` to indicate you have done them.
    -->

    - Hardware Type: A100-80G-PCIe
    - [x] run `make test` to ensure correctness
    - [x] run `make checkstyle` to ensure code style
    - [x] run `make test-convergence` to ensure convergence

commit dcc7c9e4137390f4fc1a522b0cb89849d9b947ce
Author: Mark Saroufim <marksaroufim@gmail.com>
Date:   Tue Sep 24 10:10:41 2024 -0700

    rename cuda mode to gpu mode (#267)

    ## Summary

    This is a documentation only change since the server has been recently
    renamed. See this tweet for context
    https://x.com/jeremyphoward/status/1838341110344880637

    Hopefully this is OK to merge :)

commit ed4e60cd1d101410f4cffb06f18d62347b5bffc5
Author: Tyler Romero <tyler.alexander.romero@gmail.com>
Date:   Sun Sep 22 10:01:51 2024 -0700

    chore: Add Qwen2.5 and Phi3.5 to Readme (#265)

    ## Summary
    Qwen2.5 was released recently - it uses the same model architecture as
    Qwen2 (see:
    https://huggingface.co/Qwen/Qwen2.5-7B/blob/main/config.json#L3).
    Likewise for Phi3.5 (see:
    https://huggingface.co/microsoft/Phi-3.5-mini-instruct/blob/main/config.json#L4).
    Also, the `model_type`s in the above configs will allow
    AutoLigerKernelForCausalLM to work correctly for these models out of the
    box.

    Adding them to the readme for clarity / marketing reasons :)

    - Hardware Type: <BLANK>
    - [ x ] run `make test` to ensure correctness
    - [ x ] run `make checkstyle` to ensure code style
    - [ x ] run `make test-convergence` to ensure convergence

commit dd86cbd2092177681acf75643ded1b23a785a816
Author: Shivam Sahni <shivam15800@gmail.com>
Date:   Fri Sep 20 16:30:59 2024 -0700

    Update contributing guide for adding a new model (#260)

commit 1289cc41c2591df6a2c1e7d902f8733239991100
Author: Steven Shimizu <shimizust@gmail.com>
Date:   Fri Sep 20 16:17:12 2024 -0700

    Fix AutoLigerKernelForCausalLM to pass through original kwargs (#263)

    ## Summary
    - Fixes https://github.com/linkedin/Liger-Kernel/issues/250 to correctly
    pass all original kwargs to .from_pretrained(). Previously we were only
    passing args that were part of the model config, but there are
    additional valid kwargs beyond that.
    - We still need to filter out the kwargs passed into the apply_liger_*
    functions, or else will result in model init errors

    ## Testing Done
    Tested on huggingface example with some of the args in
    https://github.com/linkedin/Liger-Kernel/issues/250

    - Hardware Type: A100
    - [x] run `make test` to ensure correctness
    - [x] run `make checkstyle` to ensure code style
    - [x] run `make test-convergence` to ensure convergence

commit ce71d59b0b0894f9f3e7512f5a3bf3780c5a1499
Author: Tcc0403 <76503978+Tcc0403@users.noreply.github.com>
Date:   Thu Sep 19 23:03:41 2024 +0800

    Fix a comment typo in flce (#256)

    ## Summary
    os huge -> is huge

    ## Testing Done

    - Hardware Type: <BLANK>
    - [ ] run `make test` to ensure correctness
    - [ ] run `make checkstyle` to ensure code style
    - [ ] run `make test-convergence` to ensure convergence

commit 58fd2bc85073fdb010164426c9b159cd8a0e9542
Author: Hanson Wang <hansonw@users.noreply.github.com>
Date:   Tue Sep 17 08:48:20 2024 -0700

    [Easy] Cast program_id to int64 in SwiGLU/GeGLU kernels (#251)

    ## Summary

    I hit some memory corruption errors testing large batches of tokens with
    larger models - e.g. with Gemma2-27B and a batch size of 80K tokens you
    will hit 80K * 36864 = 2.949e9 elements in the intermediate dimension,
    greater than (signed) int32!

    `tl.program_id` needs to be casted to int64 like in the fused
    cross-entropy kernel.

    ## Testing Done

    I didn't add a unit test for this because it would require a fair bit of
    VRAM, but can do so if desired.
    Was able to verify that forward+backward works without corruption on a
    Gemma2-27B model.

    - Hardware Type: A100 80GB
    - [x] run `make test` to ensure correctness
    - [x] run `make checkstyle` to ensure code style
    - [x] run `make test-convergence` to ensure convergence

    Co-authored-by: Shao Tang <tangshao28@gmail.com>

commit d1343adcdd9314efe004dbfb431cd7ad91105c12
Author: Edoardo Luciani <edoardo.luciani@gmail.com>
Date:   Sat Sep 14 23:45:29 2024 +0100

    Remove debug print statement (#247)

    ## Summary
    Just a removal of a debug print statement left by accident (I suppose).

    - Hardware Type: NVIDIA RTX 2070
    - [x] run `make test` to ensure correctness (to the extent my GPU can)
    - [x] run `make checkstyle` to ensure code style
    - [x] run `make test-convergence` to ensure convergence

commit 793785f2dc999a2aef78fac58616a6ea93034542
Author: Qingquan Song <ustcsqq@gmail.com>
Date:   Fri Sep 13 14:18:14 2024 -0700

    Release Liger-Kernel version 0.3.0 (#246)

    Release Liger-Kernel version 0.3.0

    ## Summary
    <!--- This is a required section; please describe the main purpose of
    this proposed code change. --->

    <!---
    ## Details
    This is an optional section; is there anything specific that reviewers
    should be aware of?
    --->

    ## Testing Done
    <!--- This is a required section; please describe how this change was
    tested. --->

    <!--
    Replace BLANK with your device type. For example, A100-80G-PCIe

    Complete the following tasks before sending your PR, and replace `[ ]`
    with
    `[x]` to indicate you have done them.
    -->

    - Hardware Type: <BLANK>
    - [ ] run `make test` to ensure correctness
    - [ ] run `make checkstyle` to ensure code style
    - [ ] run `make test-convergence` to ensure convergence

commit 53d5934c796d7ecdcbdf9790dd9fec89a2205149
Author: Shivam Sahni <shivam15800@gmail.com>
Date:   Fri Sep 13 11:53:07 2024 -0700

    Reduction support for CrossEntropy and Division by 0 Fix (#153)

commit 7a5d48425d816fac15195f56861787d80659e16b
Author: Steven Shimizu <shimizust@gmail.com>
Date:   Fri Sep 13 10:39:46 2024 -0700

    Support for patching post-model initialization (#199)

    ## Summary
    - Currently, calling the patching APIs after the model has been
    initialized will only partially patch with Liger kernels. For example,
    the following will still be patched:
    - Model `forward()` method (e.g.
    `modeling_llama.LlamaForCausalLM.forward = lce_forward`)
    - module functions (e.g. `modeling_llama.apply_rotary_pos_emb =
    liger_rotary_pos_emb`)
    but not any modules that were already instantiated and set as instance
    variables on the model:
    - For example: `modeling_llama.LlamaRMSNorm = LigerRMSNorm` will not
    affect existing LlamaRMSNorm instances
    - This means that integrations with HF Trainer and SFTTrainer only
    partially work. In the case of SFTTrainer, the current integration only
    works fully if the user passes a path to the model
    (https://github.com/huggingface/trl/pull/1992), and SFTTrainer handles
    calling `AutoLigerKernelForCausalLM.from_pretrained`. However, both HF
    Trainer and SFTTrainer allow passing the model instance directly, in
    which case we would need a way of patching the model post-init.

    ### API
    ```python
    from liger_kernel.transformers import _apply_liger_kernel_to_instance

    llama_model = AutoModelForCausalLM.from_pretrained("/path/to/llama", ...)
    _apply_liger_kernel_to_instance(model=llama_model)

    # can also pass in model-specific args that will get passed into the correct apply_liger_kernel_to_{model_type}
    _apply_liger_kernel_to_instance(model=llama_model, rope=False)
    ```

    Required changes post-PR:
    - Update HF Trainer and SFTTrainer to pass in model instead of
    model_type to `_apply_liger_kernel(model=model)`

    ## Testing Done
    - Tested HF example with no patching, pre-init patching, post-init
    patching using existing method, post-init patching of model instance
    variables showing that post-init instance patching results in same
    performance as pre-init patching:
    - Added unit tests that model instances are actually patched correctly
    and that each patching API supports model instance patching.

    **Llama**
    | Scenario     | Memory Usage (GB) | Throughput (tokens/sec) |
    |--------------|-------------------|-------------------------|
    | No patching   | 68.3                | 9161                   |
    | Patch post-init (existing method) | 39.9 | 10939 |
    | Patch pre-init   | 38.4                | 12313                   |
    | Patch post-init (instance patching) | 38.3 | 12409 |

    **Mistral**
    | Scenario     | Memory Usage (GB) | Throughput (tokens/sec) |
    |--------------|-------------------|-------------------------|
    | No patching   | 34.5                | 10976                   |
    | Patch post-init (existing method) | 34.5 | 10812 |
    | Patch pre-init   | 34.5                | 12286                   |
    | Patch post-init (instance patching) | 34.5 | 12086 |

    **Mixtral**
    OOM on testing setup, but confirmed patched post-model init loaded
    correctly and started training

    **Gemma**
    | Scenario     | Memory Usage (GB) | Throughput (tokens/sec) |
    |--------------|-------------------|-------------------------|
    | No patching   | 52.8                | 7199                   |
    | Patch post-init (existing method) | 40.7 | 9479 |
    | Patch pre-init   | 40.7                | 9209                   |
    | Patch post-init (instance patching) | 40.7 | 9981 |

    **Gemma2**
    | Scenario     | Memory Usage (GB) | Throughput (tokens/sec) |
    |--------------|-------------------|-------------------------|
    | No patching   |  66.9               |   12256                 |
    | Patch post-init (existing method) | 51.0 | 14288 |
    | Patch pre-init   | 50.8                |  17084                 |
    | Patch post-init (instance patching) | 46.8 | 15893 |

    * Note: Gemma2 pre-init and post-init (instance patching) are not
    converging, seeing nan gradnorm. Need to investigate separately.

    **Qwen2**
    | Scenario     | Memory Usage (GB) | Throughput (tokens/sec) |
    |--------------|-------------------|-------------------------|
    | No patching   | 41.2                | 10785                   |
    | Patch post-init (existing method) | 36.8 | 10491 |
    | Patch pre-init   | 36.8                | 11908                   |
    | Patch post-init (instance patching) | 36.8 | 12068 |

    **Qwen2 VL**
    - Patched, but couldn't test since not yet released in latest
    transformers

    **Phi3**
    | Scenario     | Memory Usage (GB) | Throughput (tokens/sec) |
    |--------------|-------------------|-------------------------|
    | No patching   | 49.4                | 15601                   |
    | Patch post-init (existing method) | 49.4 | 16259 |
    | Patch pre-init   | 42.8                | 20383                   |
    | Patch post-init (instance patching) | 48.6 | 18922 |

    * Note: For phi3, post-init patching seems to consistently not match
    pre-init patching for some reason

    - Hardware Type: <BLANK>
    - [x] run `make test` to ensure correctness
    - [x] run `make checkstyle` to ensure code style
    - [x] run `make test-convergence` to ensure convergence

commit d4879dfe2974d2a4d85019f143c331709f6920f5
Author: Austin Liu <austin362667@gmail.com>
Date:   Fri Sep 13 13:09:00 2024 +0800

    Restore monkey patched modules (#232)

    ## Summary
    <!--- This is a required section; please describe the main purpose of
    this proposed code change. --->

    Fixes https://github.com/linkedin/Liger-Kernel/issues/176

    <!---
    ## Details
    This is an optional section; is there anything specific that reviewers
    should be aware of?
    --->

    There are several ways to restore a monkey-patched library in Python,
    including using context managers, decorators, pytest fixtures, or
    reloading the entire module.

    This PR focuses on reverting monkey-patched modules when `with_liger` is
    disabled in convergence tests.
    ```python
    import target.module
    importlib.reload(target.module)
    ```
    These changes simplify the process of resetting the affected patched
    library and help prevent unintended side effects. And it's easier than
    manually reassigning functions anyway.

    ## Follow-up

    If this PR resolves the
    https://github.com/linkedin/Liger-Kernel/issues/176, it might introduce
    other value mismatch problems. We may need to adjust the convergence
    tolerance accordingly. For instance,
    ```
    ______________________ test_mini_model[mini_mixtral-32-0.0001-dtype11-1e-08-1e-05-0.1-1e-05-0.01-1e-05] _______________________

    model_name = 'mini_mixtral', num_steps = 32, lr = 0.0001, dtype = torch.bfloat16, loss_atol = 1e-08, loss_rtol = 1e-05
    logits_atol = 0.1, logits_rtol = 1e-05, param_atol = 0.01, param_rtol = 1e-05

        @pytest.mark.parametrize(
            "model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logits_atol, logits_rtol, param_atol, param_rtol",
            [
                ("mini_llama3", 32, 1e-4, torch.float32, 1e-8, 2e-5, 1e-4, 1e-5, 5e-3, 1e-5),
                pytest.param(
                    "mini_llama3",
                    32,
                    1e-4,
                    torch.bfloat16,
                    5e-3,
                    1e-5,
                    1e-2,
                    1e-5,
                    1e-2,
                    1e-5,
                    marks=pytest.mark.skipif(
                        not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
                    ),
                ),
                ("mini_qwen2", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
                pytest.param(
                    "mini_qwen2",
                    32,
                    1e-4,
                    torch.bfloat16,
                    1e-8,
                    1e-5,
                    1e-2,
                    1e-5,
                    1e-2,
                    1e-5,
                    marks=pytest.mark.skipif(
                        not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
                    ),
                ),
                pytest.param(
                    "mini_qwen2_vl",
                    32,
                    1e-4,
                    torch.float32,
                    1e-8,
                    1e-5,
                    5e-3,
                    1e-5,
                    5e-3,
                    1e-5,
                    marks=pytest.mark.skipif(
                        not QWEN2_VL_AVAILABLE,
                        reason="Qwen2-VL not available in this version of transformers",
                    ),
                ),
                pytest.param(
                    "mini_qwen2_vl",
                    32,
                    1e-4,
                    torch.bfloat16,
                    1e-8,
                    1e-5,
                    1e-2,
                    1e-5,
                    1e-2,
                    1e-5,
                    marks=[
                        pytest.mark.skipif(
                            not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
                        ),
                        pytest.mark.skipif(
                            not QWEN2_VL_AVAILABLE,
                            reason="Qwen2-VL not available in this version of transformers",
                        ),
                    ],
                ),
                ("mini_phi3", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
                pytest.param(
                    "mini_phi3",
                    32,
                    1e-4,
                    torch.bfloat16,
                    1e-8,
                    1e-5,
                    1e-2,
                    1e-5,
                    1e-2,
                    1e-5,
                    marks=pytest.mark.skipif(
                        not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
                    ),
                ),
                ("mini_mistral", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
                pytest.param(
                    "mini_mistral",
                    32,
                    1e-4,
                    torch.bfloat16,
                    1e-8,
                    1e-5,
                    1e-2,
                    1e-5,
                    1e-2,
                    1e-5,
                    marks=pytest.mark.skipif(
                        not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
                    ),
                ),
                ("mini_mixtral", 32, 1e-4, torch.float32, 5e-4, 1e-4, 5e-3, 1e-5, 1e-2, 1e-5),
                pytest.param(
                    "mini_mixtral",
                    32,
                    1e-4,
                    torch.bfloat16,
                    1e-8,
                    1e-5,
                    1e-1,
                    1e-5,
                    1e-2,
                    1e-5,
                    marks=pytest.mark.skipif(
                        not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
                    ),
                ),
                # Gemma 1.1 and 2 has more tolerance because currently, the kernel is not a perfect match (casts are not done the same way)
                ("mini_gemma1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
                pytest.param(
                    "mini_gemma1",
                    32,
                    1e-4,
                    torch.bfloat16,
                    1e-2,
                    1e-4,
                    2e-1,
                    1e-5,
                    1e-2,
                    1e-5,
                    marks=pytest.mark.skipif(
                        not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
                    ),
                ),
                ("mini_gemma1.1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
                pytest.param(
                    "mini_gemma1.1",
                    32,
                    1e-4,
                    torch.bfloat16,
                    1e-2,
                    1e-4,
                    2e-1,
                    1e-5,
                    1e-2,
                    1e-5,
                    marks=pytest.mark.skipif(
                        not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
                    ),
                ),
                ("mini_gemma2", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
                pytest.param(
                    "mini_gemma2",
                    32,
                    1e-4,
                    torch.bfloat16,
                    1e-2,
                    1e-4,
                    2e-1,
                    1e-5,
                    1e-2,
                    1e-5,
                    marks=pytest.mark.skipif(
                        not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
                    ),
                ),
            ],
        )
        def test_mini_model(
            model_name,
            num_steps,
            lr,
            dtype,
            loss_atol,
            loss_rtol,
            logits_atol,
            logits_rtol,
            param_atol,
            param_rtol,
        ):
            # Non-liger models should be initialized and tested first to avoid the module being overridden

            expected_output = run_mini_model(
                model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr
            )

            actual_output = run_mini_model(
                model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr, with_liger=True
            )

            # Compare every step of the loss
    >       assert_verbose_allclose(
                torch.tensor([expected_output["loss"]]),
                torch.tensor([actual_output["loss"]]),
                atol=loss_atol,
                rtol=loss_rtol,
            )

    test/convergence/test_mini_models_no_logits.py:594:
    _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

    tensor1 = tensor([[10.9374,  7.0162,  4.8162,  3.2886,  2.4254,  1.9993,  1.6753,  1.7743,
              1.4267,  1.4742,  1.4458,  ...  1.0867,  0.8353,  0.9219,  0.8796,
              0.8610,  0.8183,  0.7559,  0.8734,  0.9647,  0.7261,  1.0963,  0.8136]])
    tensor2 = tensor([[10.9383,  7.0052,  4.8145,  3.3515,  2.3853,  2.0174,  1.6758,  1.7778,
              1.4256,  1.4737,  1.4442,  ...  1.0870,  0.8346,  0.9222,  0.8817,
              0.8610,  0.8181,  0.7554,  0.8736,  0.9671,  0.7263,  1.0966,  0.8171]])
    rtol = 1e-05, atol = 1e-08, max_print = 5

        def assert_verbose_allclose(tensor1, tensor2, rtol=1e-05, atol=1e-08, max_print=5):
            """
            Assert that two tensors are element-wise equal within a tolerance, providing detailed information about mismatches.

            Parameters:
            tensor1 (torch.Tensor): First tensor to compare.
            tensor2 (torch.Tensor): Second tensor to compare.
            rtol (float): Relative tolerance.
            atol (float): Absolute tolerance.
            max_print (int): Maximum number of mismatched elements to print.

            Raises:
            AssertionError: If the tensors are not all close within the given tolerance.
            """
            # Check if the shapes of the tensors match
            if tensor1.shape != tensor2.shape:
                raise AssertionError("Input tensors must have the same shape.")

            # Calculate the difference between the tensors
            diff = torch.abs(tensor1 - tensor2)

            # Determine the tolerance
            tolerance = atol + rtol * torch.abs(tensor2)

            # Find mismatched elements
            mismatched = diff > tolerance

            # Get the indices of mismatched elements
            mismatched_indices = torch.nonzero(mismatched)

            # Count the number of mismatched elements
            num_mismatched = mismatched.sum().item()

            # Check if all elements are close
            all_close = num_mismatched == 0

            # Raise AssertionError with detailed information if there are mismatches
            if not all_close and num_mismatched > 1:
                mismatch_details = [f"Number of mismatched elements: {num_mismatched}"]
                print_count = min(max_print, num_mismatched)
                for index in mismatched_indices[:print_count]:
                    i = tuple(index.tolist())
                    mismatch_details.append(
                        f"Mismatch at index {i}: tensor1[{i}] = {tensor1[i]}, tensor2[{i}] = {tensor2[i]}"
                    )
                if num_mismatched > max_print:
                    mismatch_details.append(
                        f"... and {num_mismatched - max_print} more mismatched elements."
                    )

    >           raise AssertionError("\n".join(mismatch_details))
    E           AssertionError: Number of mismatched elements: 32
    E           Mismatch at index (0, 0): tensor1[(0, 0)] = 10.937411308288574, tensor2[(0, 0)] = 10.938319206237793
    E           Mismatch at index (0, 1): tensor1[(0, 1)] = 7.016175270080566, tensor2[(0, 1)] = 7.0052409172058105
    E           Mismatch at index (0, 2): tensor1[(0, 2)] = 4.8161821365356445, tensor2[(0, 2)] = 4.814478397369385
    E           Mismatch at index (0, 3): tensor1[(0, 3)] = 3.288573980331421, tensor2[(0, 3)] = 3.3514533042907715
    E           Mismatch at index (0, 4): tensor1[(0, 4)] = 2.425377368927002, tensor2[(0, 4)] = 2.3853368759155273
    E           ... and 27 more mismatched elements.

    test/utils.py:83: AssertionError
    ---------------------------------------------------- Captured stdout call -----------------------------------------------------
    Liger kernel patches have been reverted.
    Step 0, Loss: 10.937411308288574
    Step 1, Loss: 7.016175270080566
    Step 2, Loss: 4.8161821365356445
    Step 3, Loss: 3.288573980331421
    Step 4, Loss: 2.425377368927002
    Step 5, Loss: 1.999261736869812
    Step 6, Loss: 1.675323486328125
    Step 7, Loss: 1.7742501497268677
    Step 8, Loss: 1.4266773462295532
    Step 9, Loss: 1.474155068397522
    Step 10, Loss: 1.4458246231079102
    Step 11, Loss: 1.1540931463241577
    Step 12, Loss: 1.3520232439041138
    Step 13, Loss: 1.311019778251648
    Step 14, Loss: 1.219789981842041
    Step 15, Loss: 1.3071205615997314
    Step 16, Loss: 1.2621395587921143
    Step 17, Loss: 1.3119654655456543
    Step 18, Loss: 1.1880946159362793
    Step 19, Loss: 1.2357648611068726
    Step 20, Loss: 1.0867037773132324
    Step 21, Loss: 0.8352738618850708
    Step 22, Loss: 0.9218576550483704
    Step 23, Loss: 0.879619836807251
    Step 24, Loss: 0.8610480427742004
    Step 25, Loss: 0.8182975053787231
    Step 26, Loss: 0.7558884620666504
    Step 27, Loss: 0.8734312057495117
    Step 28, Loss: 0.9646832942962646
    Step 29, Loss: 0.7261283993721008
    Step 30, Loss: 1.0963469743728638
    Step 31, Loss: 0.8136419057846069
    Step 0, Loss: 10.938319206237793
    Step 1, Loss: 7.0052409172058105
    Step 2, Loss: 4.814478397369385
    Step 3, Loss: 3.3514533042907715
    Step 4, Loss: 2.3853368759155273
    Step 5, Loss: 2.0173795223236084
    Step 6, Loss: 1.6758073568344116
    Step 7, Loss: 1.777788519859314
    Step 8, Loss: 1.4255633354187012
    Step 9, Loss: 1.4737187623977661
    Step 10, Loss: 1.4441752433776855
    Step 11, Loss: 1.1313129663467407
    Step 12, Loss: 1.3452619314193726
    Step 13, Loss: 1.299330234527588
    Step 14, Loss: 1.2130300998687744
    Step 15, Loss: 1.3027563095092773
    Step 16, Loss: 1.2582926750183105
    Step 17, Loss: 1.3112103939056396
    Step 18, Loss: 1.1886006593704224
    Step 19, Loss: 1.235780954360962
    Step 20, Loss: 1.0869864225387573
    Step 21, Loss: 0.8346381187438965
    Step 22, Loss: 0.9222478866577148
    Step 23, Loss: 0.8816985487937927
    Step 24, Loss: 0.8609745502471924
    Step 25, Loss: 0.81810462474823
    Step 26, Loss: 0.7554237246513367
    Step 27, Loss: 0.8736312389373779
    Step 28, Loss: 0.967080295085907
    Step 29, Loss: 0.7262533903121948
    Step 30, Loss: 1.0965538024902344
    Step 31, Loss: 0.8171141147613525
    =================================================== short test summary info ===================================================
    FAILED test/convergence/test_mini_models.py::test_mini_model[mini_gemma1-32-0.0001-dtype1-1e-08-1e-05-0.01-1e-05-0.01-1e-05] - AssertionError: Number of mismatched elements: 27
    FAILED test/convergence/test_mini_models.py::test_mini_model[mini_gemma1.1-32-0.0001-dtype3-1e-08-1e-05-0.01-1e-05-0.01-1e-05] - AssertionError: Number of mismatched elements: 27
    FAILED test/convergence/test_mini_models.py::test_mini_model[mini_llama3-32-0.0001-dtype7-1e-08-1e-05-0.01-1e-05-0.01-1e-05] - AssertionError: Number of mismatched elements: 29
    FAILED test/convergence/test_mini_models.py::test_mini_model[mini_mistral-32-0.0001-dtype9-1e-08-1e-05-0.01-1e-05-0.01-1e-05] - AssertionError: Number of mismatched elements: 27
    FAILED test/convergence/test_mini_models.py::test_mini_model[mini_qwen2-32-0.0001-dtype11-1e-08-1e-05-0.01-1e-05-0.01-1e-05] - AssertionError: Number of mismatched elements: 31
    FAILED test/convergence/test_mini_models.py::test_mini_model[mini_phi3-32-0.0001-dtype13-1e-08-1e-05-0.01-1e-05-0.01-1e-05] - AssertionError: Number of mismatched elements: 31
    FAILED test/convergence/test_mini_models_no_logits.py::test_mini_model[mini_qwen2-32-0.0001-dtype3-1e-08-1e-05-0.01-1e-05-0.01-1e-05] - AssertionError: Number of mismatched elements: 27
    FAILED test/convergence/test_mini_models_no_logits.py::test_mini_model[mini_phi3-32-0.0001-dtype7-1e-08-1e-05-0.01-1e-05-0.01-1e-05] - AssertionError: Number of mismatched elements: 31
    FAILED test/convergence/test_mini_models_no_logits.py::test_mini_model[mini_mistral-32-0.0001-dtype9-1e-08-1e-05-0.01-1e-05-0.01-1e-05] - AssertionError: Number of mismatched elements: 27
    FAILED test/convergence/test_mini_models_no_logits.py::test_mini_model[mini_mixtral-32-0.0001-dtype11-1e-08-1e-05-0.1-1e-05-0.01-1e-05] - AssertionError: Number of mismatched elements: 32
    ============================== 10 failed, 20 passed, 4 skipped, 4 warnings in 226.37s (0:03:46) ===============================
    make: *** [Makefile:23: test-convergence] Error 1
    ```

    ## Testing Done
    <!--- This is a required section; please describe how this change was
    tested. --->

    <!--
    Replace BLANK with your device type. For example, A100-80G-PCIe

    Complete the following tasks before sending your PR, and replace `[ ]`
    with
    `[x]` to indicate you have done them.
    -->

    - Hardware Type: Nvidia A100
    - [x] run `make test` to ensure correctness
    - [x] run `make checkstyle` to ensure code style
    - [ ] run `make test-convergence` to ensure convergence

    ---------

    Signed-off-by: Austin Liu <austin362667@gmail.com>
    Co-authored-by: ByronHsu <byronhsu1230@gmail.com>

commit 3d0653b035222cbb845435a1994854e4fd219107
Author: Tcc0403 <76503978+Tcc0403@users.noreply.github.com>
Date:   Thu Sep 12 00:45:39 2024 +0800

    Add label smoothing to FLCE and unit tests (#244)

    ## Summary
    Fix #243

    ## Testing Done

    - Hardware Type: RTX-3080
    - [x] run `make test` to ensure correctness
    - [x] run `make checkstyle` to ensure code style
    - [x] run `make test-convergence` to ensure convergence

commit 83a66d85b9c409ad6f9b17f751886c7936e40290
Author: tastelikefeet <58414341+tastelikefeet@users.noreply.github.com>
Date:   Wed Sep 11 00:42:09 2024 +0800

    SWIFT Trainer Integration (#240)

    ## Summary
    <!--- This is a required section; please describe the main purpose of
    this proposed code change. --->

    <!---
    ## Details
    This is an optional section; is there anything specific that reviewers
    should be aware of?
    --->

    ## Testing Done
    <!--- This is a required section; please describe how this change was
    tested. --->

    <!--
    Replace BLANK with your device type. For example, A100-80G-PCIe

    Complete the following tasks before sending your PR, and replace `[ ]`
    with
    `[x]` to indicate you have done them.
    -->

    - Hardware Type: <BLANK>
    - [ ] run `make test` to ensure correctness
    - [x] run `make checkstyle` to ensure code style
    - [ ] run `make test-convergence` to ensure convergence

commit acd82728207ebafad28d448640502c108901a967
Author: Hanson Wang <hansonw@users.noreply.github.com>
Date:   Mon Sep 9 15:30:09 2024 -0700

    Optimize fused_linear_cross_entropy when weight does not require grads (#237)

    ## Summary

    Add some easy checks for `weight.requires_grad` to skip allocating +
    calculating weight gradients if they're not needed. The weight gradient
    matrix can be pretty large, so this can also be a significant memory
    savings.

    Also, a small micro-optimization: skip the `.item()` call on
    `total_n_non_ignore` (the subsequent calculations work fine with the
    tensor form) to defer CUDA synchronization (otherwise it will wait for
    all the `torch.zeros` initializations on the preceding lines to
    synchronize, which may take a non-trivial amount of time.)

    ## Testing Done

    The existing unit test already has a case where the weight does not have
    gradients enabled, and it still passes forwards/backwards:
    https://github.com/linkedin/Liger-Kernel/blob/main/test/transformers/test_fused_linear_cross_entropy.py#L165

    And the preceding test verifies the 'normal' case where the weight
    gradients are needed.

    - Hardware Type: A100 80G
    - [x] run `make test` to ensure correctness
    - [x] run `make checkstyle` to ensure code style
    - [x] run `make test-convergence` to ensure convergence

commit b5d8cbf90d338ea2eda4e2e1863dcf0722599197
Author: Tyler Romero <tyler.alexander.romero@gmail.com>
Date:   Sun Sep 8 14:14:45 2024 -0700

    Monkeypatch for Qwen2-VL (#175)

    ## Summary
    Monkeypatch for the recently-published
    [Qwen2-VL](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct).
    HF `transformers` modeling code:
    https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py

    Feature Request: https://github.com/linkedin/Liger-Kernel/issues/165

    ## Details
    Qwen2-VL in `transformers` is available on `transformers` main but is
    yet to be published in a release.

    ## Testing Done
    - Hardware Type: 4090
    - [x] run `make test` to ensure correctness
    - [x] run `make checkstyle` to ensure code style
    - [x] run `make test-convergence` to ensure convergence

    ---------

    Co-authored-by: Shao Tang <tangshao28@gmail.com>

commit 9250546513c8549d51f62284610d04077e9589f4
Author: S1ro <54212263+S1ro1@users.noreply.github.com>
Date:   Sun Sep 8 03:44:04 2024 +0200

    Feat: add kl div to readme (#229)

    ## Summary
    Adds newly implemented kl divergence loss to readme. Closes #188
    finally.

    ## Testing Done
    No code changes

    ---------

    Co-authored-by: Shao Tang <tangshao28@gmail.com>
    Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>

commit 1cdb7f0d63701065ffb92399ed12f4206f95566b
Author: S1ro <54212263+S1ro1@users.noreply.github.com>
Date:   Sun Sep 8 03:19:19 2024 +0200

    Refactor/benchmarking visualizer (#212)

    ## Summary
    Implements a new script, `benchmark/benchmarks_visualizer.py`, that
    substitues the functionality provided by current
    `benchmark/benchmarks_visualizer.ipynb`. Resolves #211 .

    ## Details
    ```console
    $ python3 benchmarks_visualizer.py --help
    usage: benchmarks_visualizer.py [-h] --kernel-name KERNEL_NAME --metric-name METRIC_NAME --kernel-operation-mode KERNEL_OPERATION_MODE [--display] [--overwrite]

    options:
      -h, --help            show this help message and exit
      --kernel-name KERNEL_NAME
                            Kernel name to benchmark
      --metric-name METRIC_NAME
                            Metric name to visualize (speed/memory)
      --kernel-operation-mode KERNEL_OPERATION_MODE
                            Kernel operation mode to visualize (forward/backward/full)
      --display             Display the visualization
      --overwrite           Overwrite existing visualization, if none exist this flag has no effect as one are always created
      ```

    ## Testing Done
    <!--- This is a required section; please describe how this change was tested. --->

    - Hardware Type: <BLANK>
    - [ ] run `make test` to ensure correctness
    - [ ] run `make checkstyle` to ensure code style
    - [ ] run `make test-convergence` to ensure convergence

    ---------

    Co-authored-by: Shao Tang <tangshao28@gmail.com>

commit 18fd280b9a5681d489eae5354e14001751e2464f
Author: Wizyoung <happyyanghehe@gmail.com>
Date:   Sun Sep 8 07:56:04 2024 +0800

    (fix) fix pyproject.toml (#226)

    ## Summary
    In https://github.com/linkedin/Liger-Kernel/pull/218, I fixed the
    `tool.setuptools.packages.find` field and tested it only in editable
    mode with `pip install -e .`. However, in production mode with `pip
    install .`, only the env_report.py file is copied to the Python
    site-packages directory. To fix this, adding "liger_kernel.*" to the
    include list will ensure that setuptools correctly includes all
    subpackages within liger_kernel.

    <!---
    ## Details
    This is an optional section; is there anything specific that reviewers
    should be aware of?
    --->

    ## Testing Done
    <!--- This is a required section; please describe how this change was
    tested. --->

    <!--
    Replace BLANK with your device type. For example, A100-80G-PCIe

    Complete the following tasks before sending your PR, and replace `[ ]`
    with
    `[x]` to indicate you have done them.
    -->

    - Hardware Type: <BLANK>
    - [ ] run `make test` to ensure correctness
    - [ ] run `make checkstyle` to ensure code style
    - [ ] run `make test-convergence` to ensure convergence

    ---------

    Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>

commit 638b31057d283a0d841a1795f742068a63b7dcdd
Author: Wizyoung <happyyanghehe@gmail.com>
Date:   Sat Sep 7 11:53:15 2024 +0800

    add repr infomation for layer_norm and rms_norm (#220)

    ## Summary
    Add repr information for layernorm and rmsnorm class so that the useful
    layer information can be displayed after the model is printed. Other
    classes are not modified because they inherit from related torch.nn
    classes, or there are torch.nn sub-modules.

    ## Testing Done
    <!--- This is a required section; please describe how this change was
    tested. --->

    <!--
    Replace BLANK with your device type. For example, A100-80G-PCIe

    Complete the following tasks before sending your PR, and replace `[ ]`
    with
    `[x]` to indicate you have done them.
    -->

    - Hardware Type: <BLANK>
    - [x] run `make test` to ensure correctness
    - [x] run `make checkstyle` to ensure code style
    - [x] run `make test-convergence` to ensure convergence

    ---------

    Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>
    Co-authored-by: Shao Tang <tangshao28@gmail.com>

commit 07804e43a5e6e019a829c37c9cb022a4c2aa4bed
Author: Ivan Yashchuk <IvanYashchuk@users.noreply.github.com>
Date:   Sat Sep 7 06:30:32 2024 +0300

    Update swiglu and geglu forward: zeros_like -> empty_like (#217)

    ## Summary
    <!--- This is a required section; please describe the main purpose of
    this proposed code change. --->
    This PR improves the performance of swiglu and geglu forward by
    replacing `zeros_like` with `empty_like`. The difference is that
    `empty_like` doesn't require a separate kernel launch.

    <!---
    ## Details
    This is an optional section; is there anything specific that reviewers
    should be aware of?
    --->

    ## Testing Done
    <!--- This is a required section; please describe how this change was
    tested. --->
    Testing is covered by existing `test_geglu.py` and `test_swiglu.py`.

    <!--
    Replace BLANK with your device type. For example, A100-80G-PCIe

    Complete the following tasks before sending your PR, and replace `[ ]`
    with
    `[x]` to indicate you have done them.
    -->

    - Hardware Type: A100-80G-PCIe
    - [x] run `make test` to ensure correctness
    - [x] run `make checkstyle` to ensure code style
    - [x] run `make test-convergence` to ensure convergence

    ---------

    Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>
    Co-authored-by: Shao Tang <tangshao28@gmail.com>

commit 6a75ddcaf4757003c4424338af85a70b0805db81
Author: Byron Hsu <byronhsu1230@gmail.com>
Date:   Fri Sep 6 20:16:05 2024 -0700

    Update README.md

commit 8cf49e2830e44fa4ef845ebf1f9e6d229dbf1aae
Author: Byron Hsu <byronhsu1230@gmail.com>
Date:   Fri Sep 6 20:13:51 2024 -0700

    Update README.md

commit 53dcf02cd2c1efd8d32a15101755388e401df091
Author: Wizyoung <happyyanghehe@gmail.com>
Date:   Sat Sep 7 07:13:28 2024 +0800

    (fix) fix pyproject.toml (#218)

    ## Summary
    Fix `tool.setuptools.packages.find` field in pyproject.toml. Otherwise
    in local build mode with `pip install .`, python system fails to locate
    liger_kernel.

    Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>

commit b42a27bd7006e84b01994ae429c6ae47fa3d07b4
Author: Steven Shimizu <shimizust@gmail.com>
Date:   Fri Sep 6 14:16:41 2024 -0700

    Added HF use-case benchmark script (#223)

    ## Summary
    - Added Hugging Face training benchmarking script used for tech report
    - Writes files to
    `/results/${MODEL_TYPE}_use_liger_${USE_LIGER}_batch_size_${BATCH_SIZE}_rep_${i}.log`

    ## Testing Done
    - Ran benchmarking script

    - Hardware Type: A100
    - [x] run `make test` to ensure correctness
    - [x] run `make checkstyle` to ensure code style
    - [x] run `make test-convergence` to ensure convergence

commit 43cbd4e6b250218b2008cf81504b5dc9763ac228
Author: Tcc0403 <76503978+Tcc0403@users.noreply.github.com>
Date:   Sat Sep 7 05:07:01 2024 +0800

    Add label smoothing for cross entropy (#198)

    ## Summary
    Aim to solve #81.

    ## Details

    ### For loss:
    Label smoothing regularization ( LSR ) by replacing the label
    distribution $q(k) = \delta_{k,y}$ with
    ```math
    q'(k) = (1 - \epsilon)\delta_{k,y} + \frac{\epsilon}{K}
    ```
    Considering cross entropy with LSR is

    ```math
    \begin{align}
    L' = H(q', p) &= -\sum^K_{k=1}log\ {p(k)}q'(k) = -\sum^K_{k=1}log\ {p(k)}((1 - \epsilon)\delta_{k,y} + \frac{\epsilon}{K})\\
                        &=  -\sum^K_{k=1}log\ {p(k)}(1 - \epsilon)q(k)   -\sum^K_{k=1}log\ {p(k)}\frac{\epsilon}{K} \\
                        &= (1 - \epsilon)H(q,p) + \frac{\epsilon}{K} \sum^K_{k=1} log\ softmax(x_k)\\
                        &= (1- \epsilon)L + \frac{\epsilon}{K}\ SmoothLoss,

    \end{align}
    ```
    where $L = H(q,p)$ is the original loss and $\sum^K_{k=1} log\
    softmax(x_k)$ is smooth loss.

    ### For gradients:
    The original:
    ```math
    \begin{align}
    \frac{\partial L}{\partial x_i} &= p(k) - q(k)\\
                                                &= \begin{cases}
                                                       softmax(x_i) ,                        &  i \neq y \\
                                                       softmax(x_i) - 1,                    &  i = y

    \end{cases}
    \end{align}
    ```
    With LSR:
    ```math
    \begin{align}
    \frac{\partial L'}{\partial x_i} &= p(k) - q'(k)\\
                                                &= softmax(x_i) - (1 - \epsilon)\delta_{k,y} + \frac{\epsilon}{K}\\
                                                &= \begin{cases} softmax(x_i) - \frac{\epsilon}{K},                        &  i \neq y \\
                                                       softmax(x_i) - \frac{\epsilon}{K} - (1 - \epsilon) &  i = y

    \end{cases}
    \end{align}
    ```

    We can handle the $i = y$ case by simply adding $-(1-\epsilon)$ after
    computing all $i$.

    Reference:
    [Rethinking the Inception Architecture for Computer
    Vision](https://arxiv.org/abs/1512.00567)

    ## Testing Done
    Add a unit test for label smoothing.

    - Hardware Type: RTX-3080
    - [x] run `make test` to ensure correctness
    - [x] run `make checkstyle` to ensure code style
    - [x] run `make test-convergence` to ensure convergence
    ```bash
    ❯ python3 -m pytest test/transformers/test_cross_entropy.py
    ============================================ test session starts =============================================
    platform linux -- Python 3.10.12, pytest-8.3.2, pluggy-1.5.0
    rootdir: /home/tcc/Liger-Kernel
    collected 94 items

    test/transformers/test_cross_entropy.py .............................................................. [ 65%]
    ...............................F                                                                       [100%]

    ================================================== FAILURES ==================================================
    __________________________________ test_large_no_exception[8-16384-128256] ___________________________________

    B = 8, T = 16384, V = 128256

        @pytest.mark.parametrize(
            "B, T, V",
            [
                (
                    8,
                    8192,
                    128256,
                ),  # _input = 16GB, total = ~32GB, 8405385216 > 2,147,483,647, so we need int64
                (8, 16384, 128256),  # _input = 32GB, total = ~64GB
            ],
        )
        # @pytest.mark.skipif(
        #     torch.cuda.get_device_properties(0).total_memory < 64 * 1000 * 1000 * 1000,
        #     reason="Needs 64GB+ GPU memory.",
        # )
        def test_large_no_exception(B, T, V):
            # The large inputs were hitting cuda illegal memory access because of
            # https://github.com/triton-lang/triton/issues/1058
    >       _full_pass_once(B, T, V)

    test/transformers/test_cross_entropy.py:401:
    _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

    B = 8, T = 16384, V = 128256

        def _full_pass_once(B, T, V):
            torch.manual_seed(0)
            liger_ce = LigerCrossEntropyLoss()

    >       _input = torch.randn(
                B * T, V, requires_grad=True, device="cuda", dtype=torch.bfloat16
            )
    E       torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 31.31 GiB. GPU 0 has a total capacity of 10.00 GiB of which 8.84 GiB is free. Including non-PyTorch memory, this process has 17179869184.00 GiB memory in use. Of the allocated memory 0 bytes is allocated by PyTorch, and 0 bytes is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

    test/transformers/test_cross_entropy.py:374: OutOfMemoryError
    ========================================== short test summary info ===========================================
    FAILED test/transformers/test_cross_entropy.py::test_large_no_exception[8-16384-128256] - torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 31.31 GiB. GPU 0 has a total capacity of 10...
    ================================== 1 failed, 93 passed in 130.88s (0:02:10) ==================================
    ```
    ```bash
    ❯ make test
    python -m pytest --disable-warnings test/ --ignore=test/convergence
    ============================================ test session starts =============================================
    platform linux -- Python 3.10.12, pytest-8.3.2, pluggy-1.5.0
    rootdir: /home/tcc/Liger-Kernel
    collected 256 items

    test/transformers/test_auto_model.py .                                                                 [  0%]
    test/transformers/test_cross_entropy.py ssssssssssssssssssssssss............ssssssssssssssssssssssssss [ 24%]
    ssssssssssssssssssssssssssssssss                                                                       [ 37%]
    test/transformers/test_embedding.py ...........                                                        [ 41%]
    test/transformers/test_fused_linear_cross_entropy.py ................                                  [ 47%]
    test/transformers/test_geglu.py ............                                                           [ 52%]
    test/transformers/test_layer_norm.py ................                                                  [ 58%]
    test/transformers/test_monkey_patch.py .....                                                           [ 60%]
    test/transformers/test_rms_norm.py ............................................................        [ 83%]
    test/transformers/test_rope.py ..................                                                      [ 91%]
    test/transformers/test_swiglu.py ....................                                                  [ 98%]
    test/transformers/test_trainer_integration.py .                                                        [ 99%]
    test/triton/test_triton_monkey_patch.py ..                                                             [100%]

    ================================ 174 passed, 82 skipped in 123.06s (0:02:03) =================================
    ```
    ```bash
    ❯ make checkstyle
    flake8 .; flake8_status=$?; \
    isort .; isort_status=$?; \
    black .; black_status=$?; \
    if [ $flake8_status -ne 0 ] || [ $isort_status -ne 0 ] || [ $black_status -ne 0 ]; then \
            exit 1; \
    fi
    Skipped 2 files
    All done! ✨ 🍰 ✨
    68 files left unchanged.
    ```
    ```bash
    ❯ make test-convergence
    HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence
    ============================================ test session starts =============================================
    platform linux -- Python 3.10.12, pytest-8.3.2, pluggy-1.5.0
    rootdir: /home/tcc/Liger-Kernel
    collected 30 items

    test/convergence/test_mini_models.py ..............                                                    [ 46%]
    test/convergence/test_mini_models_no_logits.py ................                                        [100%]

    ======================================= 30 passed in 223.18s (0:03:43) =======================================
    ```

commit 376fe0c2af65ff4d716dc36eb6fe5231662920a7
Author: Yanning Chen <momochenonline@gmail.com>
Date:   Fri Sep 6 13:10:02 2024 -0700

    Reference Unsloth in header (#216)

    ## Summary
    Reference Unsloth in header section

    <!---
    ## Details
    This is an optional section; is there anything specific that reviewers
    should be aware of?
    --->

    ## Testing Done
    <!--- This is a required section; please describe how this change was
    tested. --->

    <!--
    Replace BLANK with your device type. For example, A100-80G-PCIe

    Complete the following tasks before sending your PR, and replace `[ ]`
    with
    `[x]` to indicate you have done them.
    -->

    - Hardware Type: <BLANK>
    - [x] run `make test` to ensure correctness
    - [x] run `make checkstyle` to ensure code style
    - [x] run `make test-convergence` to ensure convergence

commit c844f787e9828e69cf18016f018bf793d1823ea3
Author: Byron Hsu <byronhsu1230@gmail.com>
Date:   Fri Sep 6 13:08:22 2024 -0700

    Update README.md

commit ec68ac0a0725d37d30d22596f1fedf7e67382367
Author: Byron Hsu <byronhsu1230@gmail.com>
Date:   Fri Sep 6 13:07:18 2024 -0700

    Add license in ack section (#224)

    ## Summary
    <!--- This is a required section; please describe the main purpose of
    this proposed code change. --->

    <!---
    ## Details
    This is an optional section; is there anything specific that reviewers
    should be aware of?
    --->

    ## Testing Done
    <!--- This is a required section; please describe how this change was
    tested. --->

    <!--
    Replace BLANK with your device type. For example, A100-80G-PCIe

    Complete the following tasks before sending your PR, and replace `[ ]`
    with
    `[x]` to indicate you have done them.
    -->

    - Hardware Type: <BLANK>
    - [ ] run `make test` to ensure correctness
    - [ ] run `make checkstyle` to ensure code style
    - [ ] run `make test-convergence` to ensure convergence

commit ec6320096a823b3107c70a81babca1dff6589191
Author: Byron Hsu <byronhsu1230@gmail.com>
Date:   Fri Sep 6 12:58:33 2024 -0700

    Elaborate ack section (#222)

    ## Summary
    <!--- This is a required section; please describe the main purpose of
    this proposed code change. --->

    <!---
    ## Details
    This is an optional section; is there anything specific that reviewers
    should be aware of?
    --->

    ## Testing Done
    <!--- This is a required section; please describe how this change was
    tested. --->

    <!--
    Replace BLANK with your device type. For example, A100-80G-PCIe

    Complete the following tasks before sending your PR, and replace `[ ]`
    with
    `[x]` to indicate you have done them.
    -->

    - Hardware Type: <BLANK>
    - [ ] run `make test` to ensure correctness
    - [ ] run `make checkstyle` to ensure code style
    - [ ] run `make test-convergence` to ensure convergence
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants