-
Notifications
You must be signed in to change notification settings - Fork 231
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support for patching post-model initialization #199
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
awesome bug catches!
@@ -295,18 +322,30 @@ def apply_liger_kernel_to_phi3( | |||
} | |||
|
|||
|
|||
def _apply_liger_kernel(model_type: str = "", **kwargs) -> None: | |||
def _apply_liger_kernel(model: PreTrainedModel = None, model_type: str = "", **kwargs) -> None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i would prefer separate this as another api. maybe apply_liger_kernel_on_instance
. Would love the thoughts from @yundai424 @qingquansong
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
seconded on separating instance-patch to a different API. Otherwise apply_liger_kernel_to_llama(model=) is just a liiitle bit cumbersome because the model itself has told what model family is it 🤔
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense, I can make it a separate API
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would it be good to still unified as one entry _apply_liger_kernel(model: Union[str, PreTrainedModel])
and can have apply_liger_kernel_on_instance
to be used in side (but feel like not necessarily needed to have a separate instance API? 🤔
9e0857e
to
0e124b3
Compare
d1b748a
to
0fbf513
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To confirm my understanding, both _apply_liger_kernel_to_instance
and _apply_liger_kernel
will be used in the HF api underlying based on different cases and we name it with underline to not recommend for users to directly use it (compared to using single model patch function)? 🤔
@qingquansong Actually the HF Trainer API would be switched to using The previous discussion with @ByronHsu was not to expose these methods publicly hence the leading _, although maybe we should re-export this from transformers/init.py to make internal refactoring less disruptive down the road. |
Some data, patching post init not working with fsdp More things Working monkey patch
265bc0a
to
905b55d
Compare
commit 8e2bd26f67123d37333488206fb8f36614c9567c Author: Chiwan Park <chiwanpark@hotmail.com> Date: Thu Sep 26 00:34:57 2024 +0900 Fix sharing a ResBlock layer for each head in Medusa example (#269) ## Summary <!--- This is a required section; please describe the main purpose of this proposed code change. ---> There is a bug of incorrect weight sharing between layers for each Medusa head. Since `nn.Module` is Python reference, the original source code creates a list containing references to the same weights. This PR fixes the bug. <!--- ## Details This is an optional section; is there anything specific that reviewers should be aware of? ---> ## Testing Done <!--- This is a required section; please describe how this change was tested. ---> <!-- Replace BLANK with your device type. For example, A100-80G-PCIe Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. --> - Hardware Type: A100-80G-PCIe - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence commit dcc7c9e4137390f4fc1a522b0cb89849d9b947ce Author: Mark Saroufim <marksaroufim@gmail.com> Date: Tue Sep 24 10:10:41 2024 -0700 rename cuda mode to gpu mode (#267) ## Summary This is a documentation only change since the server has been recently renamed. See this tweet for context https://x.com/jeremyphoward/status/1838341110344880637 Hopefully this is OK to merge :) commit ed4e60cd1d101410f4cffb06f18d62347b5bffc5 Author: Tyler Romero <tyler.alexander.romero@gmail.com> Date: Sun Sep 22 10:01:51 2024 -0700 chore: Add Qwen2.5 and Phi3.5 to Readme (#265) ## Summary Qwen2.5 was released recently - it uses the same model architecture as Qwen2 (see: https://huggingface.co/Qwen/Qwen2.5-7B/blob/main/config.json#L3). Likewise for Phi3.5 (see: https://huggingface.co/microsoft/Phi-3.5-mini-instruct/blob/main/config.json#L4). Also, the `model_type`s in the above configs will allow AutoLigerKernelForCausalLM to work correctly for these models out of the box. Adding them to the readme for clarity / marketing reasons :) - Hardware Type: <BLANK> - [ x ] run `make test` to ensure correctness - [ x ] run `make checkstyle` to ensure code style - [ x ] run `make test-convergence` to ensure convergence commit dd86cbd2092177681acf75643ded1b23a785a816 Author: Shivam Sahni <shivam15800@gmail.com> Date: Fri Sep 20 16:30:59 2024 -0700 Update contributing guide for adding a new model (#260) commit 1289cc41c2591df6a2c1e7d902f8733239991100 Author: Steven Shimizu <shimizust@gmail.com> Date: Fri Sep 20 16:17:12 2024 -0700 Fix AutoLigerKernelForCausalLM to pass through original kwargs (#263) ## Summary - Fixes https://github.com/linkedin/Liger-Kernel/issues/250 to correctly pass all original kwargs to .from_pretrained(). Previously we were only passing args that were part of the model config, but there are additional valid kwargs beyond that. - We still need to filter out the kwargs passed into the apply_liger_* functions, or else will result in model init errors ## Testing Done Tested on huggingface example with some of the args in https://github.com/linkedin/Liger-Kernel/issues/250 - Hardware Type: A100 - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence commit ce71d59b0b0894f9f3e7512f5a3bf3780c5a1499 Author: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Thu Sep 19 23:03:41 2024 +0800 Fix a comment typo in flce (#256) ## Summary os huge -> is huge ## Testing Done - Hardware Type: <BLANK> - [ ] run `make test` to ensure correctness - [ ] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence commit 58fd2bc85073fdb010164426c9b159cd8a0e9542 Author: Hanson Wang <hansonw@users.noreply.github.com> Date: Tue Sep 17 08:48:20 2024 -0700 [Easy] Cast program_id to int64 in SwiGLU/GeGLU kernels (#251) ## Summary I hit some memory corruption errors testing large batches of tokens with larger models - e.g. with Gemma2-27B and a batch size of 80K tokens you will hit 80K * 36864 = 2.949e9 elements in the intermediate dimension, greater than (signed) int32! `tl.program_id` needs to be casted to int64 like in the fused cross-entropy kernel. ## Testing Done I didn't add a unit test for this because it would require a fair bit of VRAM, but can do so if desired. Was able to verify that forward+backward works without corruption on a Gemma2-27B model. - Hardware Type: A100 80GB - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence Co-authored-by: Shao Tang <tangshao28@gmail.com> commit d1343adcdd9314efe004dbfb431cd7ad91105c12 Author: Edoardo Luciani <edoardo.luciani@gmail.com> Date: Sat Sep 14 23:45:29 2024 +0100 Remove debug print statement (#247) ## Summary Just a removal of a debug print statement left by accident (I suppose). - Hardware Type: NVIDIA RTX 2070 - [x] run `make test` to ensure correctness (to the extent my GPU can) - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence commit 793785f2dc999a2aef78fac58616a6ea93034542 Author: Qingquan Song <ustcsqq@gmail.com> Date: Fri Sep 13 14:18:14 2024 -0700 Release Liger-Kernel version 0.3.0 (#246) Release Liger-Kernel version 0.3.0 ## Summary <!--- This is a required section; please describe the main purpose of this proposed code change. ---> <!--- ## Details This is an optional section; is there anything specific that reviewers should be aware of? ---> ## Testing Done <!--- This is a required section; please describe how this change was tested. ---> <!-- Replace BLANK with your device type. For example, A100-80G-PCIe Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. --> - Hardware Type: <BLANK> - [ ] run `make test` to ensure correctness - [ ] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence commit 53d5934c796d7ecdcbdf9790dd9fec89a2205149 Author: Shivam Sahni <shivam15800@gmail.com> Date: Fri Sep 13 11:53:07 2024 -0700 Reduction support for CrossEntropy and Division by 0 Fix (#153) commit 7a5d48425d816fac15195f56861787d80659e16b Author: Steven Shimizu <shimizust@gmail.com> Date: Fri Sep 13 10:39:46 2024 -0700 Support for patching post-model initialization (#199) ## Summary - Currently, calling the patching APIs after the model has been initialized will only partially patch with Liger kernels. For example, the following will still be patched: - Model `forward()` method (e.g. `modeling_llama.LlamaForCausalLM.forward = lce_forward`) - module functions (e.g. `modeling_llama.apply_rotary_pos_emb = liger_rotary_pos_emb`) but not any modules that were already instantiated and set as instance variables on the model: - For example: `modeling_llama.LlamaRMSNorm = LigerRMSNorm` will not affect existing LlamaRMSNorm instances - This means that integrations with HF Trainer and SFTTrainer only partially work. In the case of SFTTrainer, the current integration only works fully if the user passes a path to the model (https://github.com/huggingface/trl/pull/1992), and SFTTrainer handles calling `AutoLigerKernelForCausalLM.from_pretrained`. However, both HF Trainer and SFTTrainer allow passing the model instance directly, in which case we would need a way of patching the model post-init. ### API ```python from liger_kernel.transformers import _apply_liger_kernel_to_instance llama_model = AutoModelForCausalLM.from_pretrained("/path/to/llama", ...) _apply_liger_kernel_to_instance(model=llama_model) # can also pass in model-specific args that will get passed into the correct apply_liger_kernel_to_{model_type} _apply_liger_kernel_to_instance(model=llama_model, rope=False) ``` Required changes post-PR: - Update HF Trainer and SFTTrainer to pass in model instead of model_type to `_apply_liger_kernel(model=model)` ## Testing Done - Tested HF example with no patching, pre-init patching, post-init patching using existing method, post-init patching of model instance variables showing that post-init instance patching results in same performance as pre-init patching: - Added unit tests that model instances are actually patched correctly and that each patching API supports model instance patching. **Llama** | Scenario | Memory Usage (GB) | Throughput (tokens/sec) | |--------------|-------------------|-------------------------| | No patching | 68.3 | 9161 | | Patch post-init (existing method) | 39.9 | 10939 | | Patch pre-init | 38.4 | 12313 | | Patch post-init (instance patching) | 38.3 | 12409 | **Mistral** | Scenario | Memory Usage (GB) | Throughput (tokens/sec) | |--------------|-------------------|-------------------------| | No patching | 34.5 | 10976 | | Patch post-init (existing method) | 34.5 | 10812 | | Patch pre-init | 34.5 | 12286 | | Patch post-init (instance patching) | 34.5 | 12086 | **Mixtral** OOM on testing setup, but confirmed patched post-model init loaded correctly and started training **Gemma** | Scenario | Memory Usage (GB) | Throughput (tokens/sec) | |--------------|-------------------|-------------------------| | No patching | 52.8 | 7199 | | Patch post-init (existing method) | 40.7 | 9479 | | Patch pre-init | 40.7 | 9209 | | Patch post-init (instance patching) | 40.7 | 9981 | **Gemma2** | Scenario | Memory Usage (GB) | Throughput (tokens/sec) | |--------------|-------------------|-------------------------| | No patching | 66.9 | 12256 | | Patch post-init (existing method) | 51.0 | 14288 | | Patch pre-init | 50.8 | 17084 | | Patch post-init (instance patching) | 46.8 | 15893 | * Note: Gemma2 pre-init and post-init (instance patching) are not converging, seeing nan gradnorm. Need to investigate separately. **Qwen2** | Scenario | Memory Usage (GB) | Throughput (tokens/sec) | |--------------|-------------------|-------------------------| | No patching | 41.2 | 10785 | | Patch post-init (existing method) | 36.8 | 10491 | | Patch pre-init | 36.8 | 11908 | | Patch post-init (instance patching) | 36.8 | 12068 | **Qwen2 VL** - Patched, but couldn't test since not yet released in latest transformers **Phi3** | Scenario | Memory Usage (GB) | Throughput (tokens/sec) | |--------------|-------------------|-------------------------| | No patching | 49.4 | 15601 | | Patch post-init (existing method) | 49.4 | 16259 | | Patch pre-init | 42.8 | 20383 | | Patch post-init (instance patching) | 48.6 | 18922 | * Note: For phi3, post-init patching seems to consistently not match pre-init patching for some reason - Hardware Type: <BLANK> - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence commit d4879dfe2974d2a4d85019f143c331709f6920f5 Author: Austin Liu <austin362667@gmail.com> Date: Fri Sep 13 13:09:00 2024 +0800 Restore monkey patched modules (#232) ## Summary <!--- This is a required section; please describe the main purpose of this proposed code change. ---> Fixes https://github.com/linkedin/Liger-Kernel/issues/176 <!--- ## Details This is an optional section; is there anything specific that reviewers should be aware of? ---> There are several ways to restore a monkey-patched library in Python, including using context managers, decorators, pytest fixtures, or reloading the entire module. This PR focuses on reverting monkey-patched modules when `with_liger` is disabled in convergence tests. ```python import target.module importlib.reload(target.module) ``` These changes simplify the process of resetting the affected patched library and help prevent unintended side effects. And it's easier than manually reassigning functions anyway. ## Follow-up If this PR resolves the https://github.com/linkedin/Liger-Kernel/issues/176, it might introduce other value mismatch problems. We may need to adjust the convergence tolerance accordingly. For instance, ``` ______________________ test_mini_model[mini_mixtral-32-0.0001-dtype11-1e-08-1e-05-0.1-1e-05-0.01-1e-05] _______________________ model_name = 'mini_mixtral', num_steps = 32, lr = 0.0001, dtype = torch.bfloat16, loss_atol = 1e-08, loss_rtol = 1e-05 logits_atol = 0.1, logits_rtol = 1e-05, param_atol = 0.01, param_rtol = 1e-05 @pytest.mark.parametrize( "model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logits_atol, logits_rtol, param_atol, param_rtol", [ ("mini_llama3", 32, 1e-4, torch.float32, 1e-8, 2e-5, 1e-4, 1e-5, 5e-3, 1e-5), pytest.param( "mini_llama3", 32, 1e-4, torch.bfloat16, 5e-3, 1e-5, 1e-2, 1e-5, 1e-2, 1e-5, marks=pytest.mark.skipif( not supports_bfloat16(), reason="bfloat16 not supported on this GPU" ), ), ("mini_qwen2", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), pytest.param( "mini_qwen2", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-2, 1e-5, 1e-2, 1e-5, marks=pytest.mark.skipif( not supports_bfloat16(), reason="bfloat16 not supported on this GPU" ), ), pytest.param( "mini_qwen2_vl", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5, marks=pytest.mark.skipif( not QWEN2_VL_AVAILABLE, reason="Qwen2-VL not available in this version of transformers", ), ), pytest.param( "mini_qwen2_vl", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-2, 1e-5, 1e-2, 1e-5, marks=[ pytest.mark.skipif( not supports_bfloat16(), reason="bfloat16 not supported on this GPU" ), pytest.mark.skipif( not QWEN2_VL_AVAILABLE, reason="Qwen2-VL not available in this version of transformers", ), ], ), ("mini_phi3", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), pytest.param( "mini_phi3", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-2, 1e-5, 1e-2, 1e-5, marks=pytest.mark.skipif( not supports_bfloat16(), reason="bfloat16 not supported on this GPU" ), ), ("mini_mistral", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), pytest.param( "mini_mistral", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-2, 1e-5, 1e-2, 1e-5, marks=pytest.mark.skipif( not supports_bfloat16(), reason="bfloat16 not supported on this GPU" ), ), ("mini_mixtral", 32, 1e-4, torch.float32, 5e-4, 1e-4, 5e-3, 1e-5, 1e-2, 1e-5), pytest.param( "mini_mixtral", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-1, 1e-5, 1e-2, 1e-5, marks=pytest.mark.skipif( not supports_bfloat16(), reason="bfloat16 not supported on this GPU" ), ), # Gemma 1.1 and 2 has more tolerance because currently, the kernel is not a perfect match (casts are not done the same way) ("mini_gemma1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5), pytest.param( "mini_gemma1", 32, 1e-4, torch.bfloat16, 1e-2, 1e-4, 2e-1, 1e-5, 1e-2, 1e-5, marks=pytest.mark.skipif( not supports_bfloat16(), reason="bfloat16 not supported on this GPU" ), ), ("mini_gemma1.1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5), pytest.param( "mini_gemma1.1", 32, 1e-4, torch.bfloat16, 1e-2, 1e-4, 2e-1, 1e-5, 1e-2, 1e-5, marks=pytest.mark.skipif( not supports_bfloat16(), reason="bfloat16 not supported on this GPU" ), ), ("mini_gemma2", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5), pytest.param( "mini_gemma2", 32, 1e-4, torch.bfloat16, 1e-2, 1e-4, 2e-1, 1e-5, 1e-2, 1e-5, marks=pytest.mark.skipif( not supports_bfloat16(), reason="bfloat16 not supported on this GPU" ), ), ], ) def test_mini_model( model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logits_atol, logits_rtol, param_atol, param_rtol, ): # Non-liger models should be initialized and tested first to avoid the module being overridden expected_output = run_mini_model( model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr ) actual_output = run_mini_model( model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr, with_liger=True ) # Compare every step of the loss > assert_verbose_allclose( torch.tensor([expected_output["loss"]]), torch.tensor([actual_output["loss"]]), atol=loss_atol, rtol=loss_rtol, ) test/convergence/test_mini_models_no_logits.py:594: _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ tensor1 = tensor([[10.9374, 7.0162, 4.8162, 3.2886, 2.4254, 1.9993, 1.6753, 1.7743, 1.4267, 1.4742, 1.4458, ... 1.0867, 0.8353, 0.9219, 0.8796, 0.8610, 0.8183, 0.7559, 0.8734, 0.9647, 0.7261, 1.0963, 0.8136]]) tensor2 = tensor([[10.9383, 7.0052, 4.8145, 3.3515, 2.3853, 2.0174, 1.6758, 1.7778, 1.4256, 1.4737, 1.4442, ... 1.0870, 0.8346, 0.9222, 0.8817, 0.8610, 0.8181, 0.7554, 0.8736, 0.9671, 0.7263, 1.0966, 0.8171]]) rtol = 1e-05, atol = 1e-08, max_print = 5 def assert_verbose_allclose(tensor1, tensor2, rtol=1e-05, atol=1e-08, max_print=5): """ Assert that two tensors are element-wise equal within a tolerance, providing detailed information about mismatches. Parameters: tensor1 (torch.Tensor): First tensor to compare. tensor2 (torch.Tensor): Second tensor to compare. rtol (float): Relative tolerance. atol (float): Absolute tolerance. max_print (int): Maximum number of mismatched elements to print. Raises: AssertionError: If the tensors are not all close within the given tolerance. """ # Check if the shapes of the tensors match if tensor1.shape != tensor2.shape: raise AssertionError("Input tensors must have the same shape.") # Calculate the difference between the tensors diff = torch.abs(tensor1 - tensor2) # Determine the tolerance tolerance = atol + rtol * torch.abs(tensor2) # Find mismatched elements mismatched = diff > tolerance # Get the indices of mismatched elements mismatched_indices = torch.nonzero(mismatched) # Count the number of mismatched elements num_mismatched = mismatched.sum().item() # Check if all elements are close all_close = num_mismatched == 0 # Raise AssertionError with detailed information if there are mismatches if not all_close and num_mismatched > 1: mismatch_details = [f"Number of mismatched elements: {num_mismatched}"] print_count = min(max_print, num_mismatched) for index in mismatched_indices[:print_count]: i = tuple(index.tolist()) mismatch_details.append( f"Mismatch at index {i}: tensor1[{i}] = {tensor1[i]}, tensor2[{i}] = {tensor2[i]}" ) if num_mismatched > max_print: mismatch_details.append( f"... and {num_mismatched - max_print} more mismatched elements." ) > raise AssertionError("\n".join(mismatch_details)) E AssertionError: Number of mismatched elements: 32 E Mismatch at index (0, 0): tensor1[(0, 0)] = 10.937411308288574, tensor2[(0, 0)] = 10.938319206237793 E Mismatch at index (0, 1): tensor1[(0, 1)] = 7.016175270080566, tensor2[(0, 1)] = 7.0052409172058105 E Mismatch at index (0, 2): tensor1[(0, 2)] = 4.8161821365356445, tensor2[(0, 2)] = 4.814478397369385 E Mismatch at index (0, 3): tensor1[(0, 3)] = 3.288573980331421, tensor2[(0, 3)] = 3.3514533042907715 E Mismatch at index (0, 4): tensor1[(0, 4)] = 2.425377368927002, tensor2[(0, 4)] = 2.3853368759155273 E ... and 27 more mismatched elements. test/utils.py:83: AssertionError ---------------------------------------------------- Captured stdout call ----------------------------------------------------- Liger kernel patches have been reverted. Step 0, Loss: 10.937411308288574 Step 1, Loss: 7.016175270080566 Step 2, Loss: 4.8161821365356445 Step 3, Loss: 3.288573980331421 Step 4, Loss: 2.425377368927002 Step 5, Loss: 1.999261736869812 Step 6, Loss: 1.675323486328125 Step 7, Loss: 1.7742501497268677 Step 8, Loss: 1.4266773462295532 Step 9, Loss: 1.474155068397522 Step 10, Loss: 1.4458246231079102 Step 11, Loss: 1.1540931463241577 Step 12, Loss: 1.3520232439041138 Step 13, Loss: 1.311019778251648 Step 14, Loss: 1.219789981842041 Step 15, Loss: 1.3071205615997314 Step 16, Loss: 1.2621395587921143 Step 17, Loss: 1.3119654655456543 Step 18, Loss: 1.1880946159362793 Step 19, Loss: 1.2357648611068726 Step 20, Loss: 1.0867037773132324 Step 21, Loss: 0.8352738618850708 Step 22, Loss: 0.9218576550483704 Step 23, Loss: 0.879619836807251 Step 24, Loss: 0.8610480427742004 Step 25, Loss: 0.8182975053787231 Step 26, Loss: 0.7558884620666504 Step 27, Loss: 0.8734312057495117 Step 28, Loss: 0.9646832942962646 Step 29, Loss: 0.7261283993721008 Step 30, Loss: 1.0963469743728638 Step 31, Loss: 0.8136419057846069 Step 0, Loss: 10.938319206237793 Step 1, Loss: 7.0052409172058105 Step 2, Loss: 4.814478397369385 Step 3, Loss: 3.3514533042907715 Step 4, Loss: 2.3853368759155273 Step 5, Loss: 2.0173795223236084 Step 6, Loss: 1.6758073568344116 Step 7, Loss: 1.777788519859314 Step 8, Loss: 1.4255633354187012 Step 9, Loss: 1.4737187623977661 Step 10, Loss: 1.4441752433776855 Step 11, Loss: 1.1313129663467407 Step 12, Loss: 1.3452619314193726 Step 13, Loss: 1.299330234527588 Step 14, Loss: 1.2130300998687744 Step 15, Loss: 1.3027563095092773 Step 16, Loss: 1.2582926750183105 Step 17, Loss: 1.3112103939056396 Step 18, Loss: 1.1886006593704224 Step 19, Loss: 1.235780954360962 Step 20, Loss: 1.0869864225387573 Step 21, Loss: 0.8346381187438965 Step 22, Loss: 0.9222478866577148 Step 23, Loss: 0.8816985487937927 Step 24, Loss: 0.8609745502471924 Step 25, Loss: 0.81810462474823 Step 26, Loss: 0.7554237246513367 Step 27, Loss: 0.8736312389373779 Step 28, Loss: 0.967080295085907 Step 29, Loss: 0.7262533903121948 Step 30, Loss: 1.0965538024902344 Step 31, Loss: 0.8171141147613525 =================================================== short test summary info =================================================== FAILED test/convergence/test_mini_models.py::test_mini_model[mini_gemma1-32-0.0001-dtype1-1e-08-1e-05-0.01-1e-05-0.01-1e-05] - AssertionError: Number of mismatched elements: 27 FAILED test/convergence/test_mini_models.py::test_mini_model[mini_gemma1.1-32-0.0001-dtype3-1e-08-1e-05-0.01-1e-05-0.01-1e-05] - AssertionError: Number of mismatched elements: 27 FAILED test/convergence/test_mini_models.py::test_mini_model[mini_llama3-32-0.0001-dtype7-1e-08-1e-05-0.01-1e-05-0.01-1e-05] - AssertionError: Number of mismatched elements: 29 FAILED test/convergence/test_mini_models.py::test_mini_model[mini_mistral-32-0.0001-dtype9-1e-08-1e-05-0.01-1e-05-0.01-1e-05] - AssertionError: Number of mismatched elements: 27 FAILED test/convergence/test_mini_models.py::test_mini_model[mini_qwen2-32-0.0001-dtype11-1e-08-1e-05-0.01-1e-05-0.01-1e-05] - AssertionError: Number of mismatched elements: 31 FAILED test/convergence/test_mini_models.py::test_mini_model[mini_phi3-32-0.0001-dtype13-1e-08-1e-05-0.01-1e-05-0.01-1e-05] - AssertionError: Number of mismatched elements: 31 FAILED test/convergence/test_mini_models_no_logits.py::test_mini_model[mini_qwen2-32-0.0001-dtype3-1e-08-1e-05-0.01-1e-05-0.01-1e-05] - AssertionError: Number of mismatched elements: 27 FAILED test/convergence/test_mini_models_no_logits.py::test_mini_model[mini_phi3-32-0.0001-dtype7-1e-08-1e-05-0.01-1e-05-0.01-1e-05] - AssertionError: Number of mismatched elements: 31 FAILED test/convergence/test_mini_models_no_logits.py::test_mini_model[mini_mistral-32-0.0001-dtype9-1e-08-1e-05-0.01-1e-05-0.01-1e-05] - AssertionError: Number of mismatched elements: 27 FAILED test/convergence/test_mini_models_no_logits.py::test_mini_model[mini_mixtral-32-0.0001-dtype11-1e-08-1e-05-0.1-1e-05-0.01-1e-05] - AssertionError: Number of mismatched elements: 32 ============================== 10 failed, 20 passed, 4 skipped, 4 warnings in 226.37s (0:03:46) =============================== make: *** [Makefile:23: test-convergence] Error 1 ``` ## Testing Done <!--- This is a required section; please describe how this change was tested. ---> <!-- Replace BLANK with your device type. For example, A100-80G-PCIe Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. --> - Hardware Type: Nvidia A100 - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence --------- Signed-off-by: Austin Liu <austin362667@gmail.com> Co-authored-by: ByronHsu <byronhsu1230@gmail.com> commit 3d0653b035222cbb845435a1994854e4fd219107 Author: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Thu Sep 12 00:45:39 2024 +0800 Add label smoothing to FLCE and unit tests (#244) ## Summary Fix #243 ## Testing Done - Hardware Type: RTX-3080 - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence commit 83a66d85b9c409ad6f9b17f751886c7936e40290 Author: tastelikefeet <58414341+tastelikefeet@users.noreply.github.com> Date: Wed Sep 11 00:42:09 2024 +0800 SWIFT Trainer Integration (#240) ## Summary <!--- This is a required section; please describe the main purpose of this proposed code change. ---> <!--- ## Details This is an optional section; is there anything specific that reviewers should be aware of? ---> ## Testing Done <!--- This is a required section; please describe how this change was tested. ---> <!-- Replace BLANK with your device type. For example, A100-80G-PCIe Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. --> - Hardware Type: <BLANK> - [ ] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence commit acd82728207ebafad28d448640502c108901a967 Author: Hanson Wang <hansonw@users.noreply.github.com> Date: Mon Sep 9 15:30:09 2024 -0700 Optimize fused_linear_cross_entropy when weight does not require grads (#237) ## Summary Add some easy checks for `weight.requires_grad` to skip allocating + calculating weight gradients if they're not needed. The weight gradient matrix can be pretty large, so this can also be a significant memory savings. Also, a small micro-optimization: skip the `.item()` call on `total_n_non_ignore` (the subsequent calculations work fine with the tensor form) to defer CUDA synchronization (otherwise it will wait for all the `torch.zeros` initializations on the preceding lines to synchronize, which may take a non-trivial amount of time.) ## Testing Done The existing unit test already has a case where the weight does not have gradients enabled, and it still passes forwards/backwards: https://github.com/linkedin/Liger-Kernel/blob/main/test/transformers/test_fused_linear_cross_entropy.py#L165 And the preceding test verifies the 'normal' case where the weight gradients are needed. - Hardware Type: A100 80G - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence commit b5d8cbf90d338ea2eda4e2e1863dcf0722599197 Author: Tyler Romero <tyler.alexander.romero@gmail.com> Date: Sun Sep 8 14:14:45 2024 -0700 Monkeypatch for Qwen2-VL (#175) ## Summary Monkeypatch for the recently-published [Qwen2-VL](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct). HF `transformers` modeling code: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py Feature Request: https://github.com/linkedin/Liger-Kernel/issues/165 ## Details Qwen2-VL in `transformers` is available on `transformers` main but is yet to be published in a release. ## Testing Done - Hardware Type: 4090 - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence --------- Co-authored-by: Shao Tang <tangshao28@gmail.com> commit 9250546513c8549d51f62284610d04077e9589f4 Author: S1ro <54212263+S1ro1@users.noreply.github.com> Date: Sun Sep 8 03:44:04 2024 +0200 Feat: add kl div to readme (#229) ## Summary Adds newly implemented kl divergence loss to readme. Closes #188 finally. ## Testing Done No code changes --------- Co-authored-by: Shao Tang <tangshao28@gmail.com> Co-authored-by: Byron Hsu <byronhsu1230@gmail.com> commit 1cdb7f0d63701065ffb92399ed12f4206f95566b Author: S1ro <54212263+S1ro1@users.noreply.github.com> Date: Sun Sep 8 03:19:19 2024 +0200 Refactor/benchmarking visualizer (#212) ## Summary Implements a new script, `benchmark/benchmarks_visualizer.py`, that substitues the functionality provided by current `benchmark/benchmarks_visualizer.ipynb`. Resolves #211 . ## Details ```console $ python3 benchmarks_visualizer.py --help usage: benchmarks_visualizer.py [-h] --kernel-name KERNEL_NAME --metric-name METRIC_NAME --kernel-operation-mode KERNEL_OPERATION_MODE [--display] [--overwrite] options: -h, --help show this help message and exit --kernel-name KERNEL_NAME Kernel name to benchmark --metric-name METRIC_NAME Metric name to visualize (speed/memory) --kernel-operation-mode KERNEL_OPERATION_MODE Kernel operation mode to visualize (forward/backward/full) --display Display the visualization --overwrite Overwrite existing visualization, if none exist this flag has no effect as one are always created ``` ## Testing Done <!--- This is a required section; please describe how this change was tested. ---> - Hardware Type: <BLANK> - [ ] run `make test` to ensure correctness - [ ] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence --------- Co-authored-by: Shao Tang <tangshao28@gmail.com> commit 18fd280b9a5681d489eae5354e14001751e2464f Author: Wizyoung <happyyanghehe@gmail.com> Date: Sun Sep 8 07:56:04 2024 +0800 (fix) fix pyproject.toml (#226) ## Summary In https://github.com/linkedin/Liger-Kernel/pull/218, I fixed the `tool.setuptools.packages.find` field and tested it only in editable mode with `pip install -e .`. However, in production mode with `pip install .`, only the env_report.py file is copied to the Python site-packages directory. To fix this, adding "liger_kernel.*" to the include list will ensure that setuptools correctly includes all subpackages within liger_kernel. <!--- ## Details This is an optional section; is there anything specific that reviewers should be aware of? ---> ## Testing Done <!--- This is a required section; please describe how this change was tested. ---> <!-- Replace BLANK with your device type. For example, A100-80G-PCIe Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. --> - Hardware Type: <BLANK> - [ ] run `make test` to ensure correctness - [ ] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence --------- Co-authored-by: Byron Hsu <byronhsu1230@gmail.com> commit 638b31057d283a0d841a1795f742068a63b7dcdd Author: Wizyoung <happyyanghehe@gmail.com> Date: Sat Sep 7 11:53:15 2024 +0800 add repr infomation for layer_norm and rms_norm (#220) ## Summary Add repr information for layernorm and rmsnorm class so that the useful layer information can be displayed after the model is printed. Other classes are not modified because they inherit from related torch.nn classes, or there are torch.nn sub-modules. ## Testing Done <!--- This is a required section; please describe how this change was tested. ---> <!-- Replace BLANK with your device type. For example, A100-80G-PCIe Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. --> - Hardware Type: <BLANK> - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence --------- Co-authored-by: Byron Hsu <byronhsu1230@gmail.com> Co-authored-by: Shao Tang <tangshao28@gmail.com> commit 07804e43a5e6e019a829c37c9cb022a4c2aa4bed Author: Ivan Yashchuk <IvanYashchuk@users.noreply.github.com> Date: Sat Sep 7 06:30:32 2024 +0300 Update swiglu and geglu forward: zeros_like -> empty_like (#217) ## Summary <!--- This is a required section; please describe the main purpose of this proposed code change. ---> This PR improves the performance of swiglu and geglu forward by replacing `zeros_like` with `empty_like`. The difference is that `empty_like` doesn't require a separate kernel launch. <!--- ## Details This is an optional section; is there anything specific that reviewers should be aware of? ---> ## Testing Done <!--- This is a required section; please describe how this change was tested. ---> Testing is covered by existing `test_geglu.py` and `test_swiglu.py`. <!-- Replace BLANK with your device type. For example, A100-80G-PCIe Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. --> - Hardware Type: A100-80G-PCIe - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence --------- Co-authored-by: Byron Hsu <byronhsu1230@gmail.com> Co-authored-by: Shao Tang <tangshao28@gmail.com> commit 6a75ddcaf4757003c4424338af85a70b0805db81 Author: Byron Hsu <byronhsu1230@gmail.com> Date: Fri Sep 6 20:16:05 2024 -0700 Update README.md commit 8cf49e2830e44fa4ef845ebf1f9e6d229dbf1aae Author: Byron Hsu <byronhsu1230@gmail.com> Date: Fri Sep 6 20:13:51 2024 -0700 Update README.md commit 53dcf02cd2c1efd8d32a15101755388e401df091 Author: Wizyoung <happyyanghehe@gmail.com> Date: Sat Sep 7 07:13:28 2024 +0800 (fix) fix pyproject.toml (#218) ## Summary Fix `tool.setuptools.packages.find` field in pyproject.toml. Otherwise in local build mode with `pip install .`, python system fails to locate liger_kernel. Co-authored-by: Byron Hsu <byronhsu1230@gmail.com> commit b42a27bd7006e84b01994ae429c6ae47fa3d07b4 Author: Steven Shimizu <shimizust@gmail.com> Date: Fri Sep 6 14:16:41 2024 -0700 Added HF use-case benchmark script (#223) ## Summary - Added Hugging Face training benchmarking script used for tech report - Writes files to `/results/${MODEL_TYPE}_use_liger_${USE_LIGER}_batch_size_${BATCH_SIZE}_rep_${i}.log` ## Testing Done - Ran benchmarking script - Hardware Type: A100 - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence commit 43cbd4e6b250218b2008cf81504b5dc9763ac228 Author: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Sat Sep 7 05:07:01 2024 +0800 Add label smoothing for cross entropy (#198) ## Summary Aim to solve #81. ## Details ### For loss: Label smoothing regularization ( LSR ) by replacing the label distribution $q(k) = \delta_{k,y}$ with ```math q'(k) = (1 - \epsilon)\delta_{k,y} + \frac{\epsilon}{K} ``` Considering cross entropy with LSR is ```math \begin{align} L' = H(q', p) &= -\sum^K_{k=1}log\ {p(k)}q'(k) = -\sum^K_{k=1}log\ {p(k)}((1 - \epsilon)\delta_{k,y} + \frac{\epsilon}{K})\\ &= -\sum^K_{k=1}log\ {p(k)}(1 - \epsilon)q(k) -\sum^K_{k=1}log\ {p(k)}\frac{\epsilon}{K} \\ &= (1 - \epsilon)H(q,p) + \frac{\epsilon}{K} \sum^K_{k=1} log\ softmax(x_k)\\ &= (1- \epsilon)L + \frac{\epsilon}{K}\ SmoothLoss, \end{align} ``` where $L = H(q,p)$ is the original loss and $\sum^K_{k=1} log\ softmax(x_k)$ is smooth loss. ### For gradients: The original: ```math \begin{align} \frac{\partial L}{\partial x_i} &= p(k) - q(k)\\ &= \begin{cases} softmax(x_i) , & i \neq y \\ softmax(x_i) - 1, & i = y \end{cases} \end{align} ``` With LSR: ```math \begin{align} \frac{\partial L'}{\partial x_i} &= p(k) - q'(k)\\ &= softmax(x_i) - (1 - \epsilon)\delta_{k,y} + \frac{\epsilon}{K}\\ &= \begin{cases} softmax(x_i) - \frac{\epsilon}{K}, & i \neq y \\ softmax(x_i) - \frac{\epsilon}{K} - (1 - \epsilon) & i = y \end{cases} \end{align} ``` We can handle the $i = y$ case by simply adding $-(1-\epsilon)$ after computing all $i$. Reference: [Rethinking the Inception Architecture for Computer Vision](https://arxiv.org/abs/1512.00567) ## Testing Done Add a unit test for label smoothing. - Hardware Type: RTX-3080 - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence ```bash ❯ python3 -m pytest test/transformers/test_cross_entropy.py ============================================ test session starts ============================================= platform linux -- Python 3.10.12, pytest-8.3.2, pluggy-1.5.0 rootdir: /home/tcc/Liger-Kernel collected 94 items test/transformers/test_cross_entropy.py .............................................................. [ 65%] ...............................F [100%] ================================================== FAILURES ================================================== __________________________________ test_large_no_exception[8-16384-128256] ___________________________________ B = 8, T = 16384, V = 128256 @pytest.mark.parametrize( "B, T, V", [ ( 8, 8192, 128256, ), # _input = 16GB, total = ~32GB, 8405385216 > 2,147,483,647, so we need int64 (8, 16384, 128256), # _input = 32GB, total = ~64GB ], ) # @pytest.mark.skipif( # torch.cuda.get_device_properties(0).total_memory < 64 * 1000 * 1000 * 1000, # reason="Needs 64GB+ GPU memory.", # ) def test_large_no_exception(B, T, V): # The large inputs were hitting cuda illegal memory access because of # https://github.com/triton-lang/triton/issues/1058 > _full_pass_once(B, T, V) test/transformers/test_cross_entropy.py:401: _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ B = 8, T = 16384, V = 128256 def _full_pass_once(B, T, V): torch.manual_seed(0) liger_ce = LigerCrossEntropyLoss() > _input = torch.randn( B * T, V, requires_grad=True, device="cuda", dtype=torch.bfloat16 ) E torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 31.31 GiB. GPU 0 has a total capacity of 10.00 GiB of which 8.84 GiB is free. Including non-PyTorch memory, this process has 17179869184.00 GiB memory in use. Of the allocated memory 0 bytes is allocated by PyTorch, and 0 bytes is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables) test/transformers/test_cross_entropy.py:374: OutOfMemoryError ========================================== short test summary info =========================================== FAILED test/transformers/test_cross_entropy.py::test_large_no_exception[8-16384-128256] - torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 31.31 GiB. GPU 0 has a total capacity of 10... ================================== 1 failed, 93 passed in 130.88s (0:02:10) ================================== ``` ```bash ❯ make test python -m pytest --disable-warnings test/ --ignore=test/convergence ============================================ test session starts ============================================= platform linux -- Python 3.10.12, pytest-8.3.2, pluggy-1.5.0 rootdir: /home/tcc/Liger-Kernel collected 256 items test/transformers/test_auto_model.py . [ 0%] test/transformers/test_cross_entropy.py ssssssssssssssssssssssss............ssssssssssssssssssssssssss [ 24%] ssssssssssssssssssssssssssssssss [ 37%] test/transformers/test_embedding.py ........... [ 41%] test/transformers/test_fused_linear_cross_entropy.py ................ [ 47%] test/transformers/test_geglu.py ............ [ 52%] test/transformers/test_layer_norm.py ................ [ 58%] test/transformers/test_monkey_patch.py ..... [ 60%] test/transformers/test_rms_norm.py ............................................................ [ 83%] test/transformers/test_rope.py .................. [ 91%] test/transformers/test_swiglu.py .................... [ 98%] test/transformers/test_trainer_integration.py . [ 99%] test/triton/test_triton_monkey_patch.py .. [100%] ================================ 174 passed, 82 skipped in 123.06s (0:02:03) ================================= ``` ```bash ❯ make checkstyle flake8 .; flake8_status=$?; \ isort .; isort_status=$?; \ black .; black_status=$?; \ if [ $flake8_status -ne 0 ] || [ $isort_status -ne 0 ] || [ $black_status -ne 0 ]; then \ exit 1; \ fi Skipped 2 files All done! ✨ 🍰 ✨ 68 files left unchanged. ``` ```bash ❯ make test-convergence HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence ============================================ test session starts ============================================= platform linux -- Python 3.10.12, pytest-8.3.2, pluggy-1.5.0 rootdir: /home/tcc/Liger-Kernel collected 30 items test/convergence/test_mini_models.py .............. [ 46%] test/convergence/test_mini_models_no_logits.py ................ [100%] ======================================= 30 passed in 223.18s (0:03:43) ======================================= ``` commit 376fe0c2af65ff4d716dc36eb6fe5231662920a7 Author: Yanning Chen <momochenonline@gmail.com> Date: Fri Sep 6 13:10:02 2024 -0700 Reference Unsloth in header (#216) ## Summary Reference Unsloth in header section <!--- ## Details This is an optional section; is there anything specific that reviewers should be aware of? ---> ## Testing Done <!--- This is a required section; please describe how this change was tested. ---> <!-- Replace BLANK with your device type. For example, A100-80G-PCIe Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. --> - Hardware Type: <BLANK> - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence commit c844f787e9828e69cf18016f018bf793d1823ea3 Author: Byron Hsu <byronhsu1230@gmail.com> Date: Fri Sep 6 13:08:22 2024 -0700 Update README.md commit ec68ac0a0725d37d30d22596f1fedf7e67382367 Author: Byron Hsu <byronhsu1230@gmail.com> Date: Fri Sep 6 13:07:18 2024 -0700 Add license in ack section (#224) ## Summary <!--- This is a required section; please describe the main purpose of this proposed code change. ---> <!--- ## Details This is an optional section; is there anything specific that reviewers should be aware of? ---> ## Testing Done <!--- This is a required section; please describe how this change was tested. ---> <!-- Replace BLANK with your device type. For example, A100-80G-PCIe Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. --> - Hardware Type: <BLANK> - [ ] run `make test` to ensure correctness - [ ] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence commit ec6320096a823b3107c70a81babca1dff6589191 Author: Byron Hsu <byronhsu1230@gmail.com> Date: Fri Sep 6 12:58:33 2024 -0700 Elaborate ack section (#222) ## Summary <!--- This is a required section; please describe the main purpose of this proposed code change. ---> <!--- ## Details This is an optional section; is there anything specific that reviewers should be aware of? ---> ## Testing Done <!--- This is a required section; please describe how this change was tested. ---> <!-- Replace BLANK with your device type. For example, A100-80G-PCIe Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. --> - Hardware Type: <BLANK> - [ ] run `make test` to ensure correctness - [ ] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence
Summary
forward()
method (e.g.modeling_llama.LlamaForCausalLM.forward = lce_forward
)modeling_llama.apply_rotary_pos_emb = liger_rotary_pos_emb
)but not any modules that were already instantiated and set as instance variables on the model:
modeling_llama.LlamaRMSNorm = LigerRMSNorm
will not affect existing LlamaRMSNorm instancesAutoLigerKernelForCausalLM.from_pretrained
. However, both HF Trainer and SFTTrainer allow passing the model instance directly, in which case we would need a way of patching the model post-init.API
Required changes post-PR:
_apply_liger_kernel(model=model)
Testing Done
Llama
Mistral
Mixtral
OOM on testing setup, but confirmed patched post-model init loaded correctly and started training
Gemma
Gemma2
Qwen2
Qwen2 VL
Phi3
make test
to ensure correctnessmake checkstyle
to ensure code stylemake test-convergence
to ensure convergence