Skip to content

Commit

Permalink
Update swiglu and geglu forward: zeros_like -> empty_like (#217)
Browse files Browse the repository at this point in the history
## Summary
<!--- This is a required section; please describe the main purpose of
this proposed code change. --->
This PR improves the performance of swiglu and geglu forward by
replacing `zeros_like` with `empty_like`. The difference is that
`empty_like` doesn't require a separate kernel launch.

<!---
## Details
This is an optional section; is there anything specific that reviewers
should be aware of?
--->

## Testing Done
<!--- This is a required section; please describe how this change was
tested. --->
Testing is covered by existing `test_geglu.py` and `test_swiglu.py`.

<!-- 
Replace BLANK with your device type. For example, A100-80G-PCIe

Complete the following tasks before sending your PR, and replace `[ ]`
with
`[x]` to indicate you have done them. 
-->

- Hardware Type: A100-80G-PCIe
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence

---------

Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>
Co-authored-by: Shao Tang <tangshao28@gmail.com>
  • Loading branch information
3 people authored Sep 7, 2024
1 parent 6a75ddc commit 07804e4
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/liger_kernel/ops/geglu.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def geglu_forward(a, b):
n_cols = ori_shape[-1]
a = a.view(-1, n_cols)
b = b.view(-1, n_cols)
c = torch.zeros_like(a)
c = torch.empty_like(a)
n_rows = a.shape[0]

BLOCK_SIZE, num_warps = calculate_settings(n_cols)
Expand Down
2 changes: 1 addition & 1 deletion src/liger_kernel/ops/swiglu.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def swiglu_forward(a, b):
n_cols = ori_shape[-1]
a = a.view(-1, n_cols)
b = b.view(-1, n_cols)
c = torch.zeros_like(a)
c = torch.empty_like(a)
n_rows = a.shape[0]

BLOCK_SIZE, num_warps = calculate_settings(n_cols)
Expand Down

0 comments on commit 07804e4

Please sign in to comment.