
Torch compiled FLCE is 2x faster than the current FLCE #227

Open
ByronHsu opened this issue Sep 7, 2024 · 12 comments

Comments

@ByronHsu
Collaborator

ByronHsu commented Sep 7, 2024

🚀 The feature, motivation and pitch

We can leverage `torch.compile` to fuse operations that we currently cannot fuse, such as upcasting and `.contiguous()` calls.

[Image not preserved]

Sample code: https://gist.github.com/Chillee/22cd93e11b887db1f596ab754d60a899#file-lce_benchmark-py
Provided by the brilliant @Chillee
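To make the chunking idea concrete, here is a minimal NumPy sketch of a chunked linear cross-entropy (forward pass only). It is not Chillee's gist or Liger's Triton kernel, and the function names are hypothetical. It shows the structure: the LM-head matmul, the float32 upcast and the log-softmax run one chunk of rows at a time, so the full (B*T, V) logits tensor is never materialized. In a PyTorch version, this per-chunk work is the kind of code `torch.compile` could fuse; a real FLCE also computes gradients chunk by chunk, which this sketch omits.

```python
import numpy as np


def log_softmax(z: np.ndarray) -> np.ndarray:
    # Numerically stable log-softmax over the last (vocabulary) axis
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))


def chunked_linear_cross_entropy(x, weight, target, chunk_size=1024):
    """Mean cross-entropy of log_softmax(x @ weight.T) against target (forward only).

    x: (BT, H) hidden states, weight: (V, H) LM-head weight, target: (BT,) token ids.
    Only one (chunk_size, V) block of logits exists at a time, so peak memory
    grows with chunk_size instead of with BT.
    """
    bt = x.shape[0]
    total = 0.0
    for start in range(0, bt, chunk_size):
        xs = x[start:start + chunk_size]
        ts = target[start:start + chunk_size]
        # Matmul in the input precision, then upcast the chunk's logits to float32
        logits = (xs @ weight.T).astype(np.float32)
        logp = log_softmax(logits)
        # Sum of negative log-likelihoods of each row's target token
        total += float(-logp[np.arange(len(ts)), ts].sum())
    return total / bt


rng = np.random.default_rng(0)
x = rng.standard_normal((10, 8)).astype(np.float16)  # stand-in for low-precision activations
w = rng.standard_normal((16, 8)).astype(np.float16)  # V=16, H=8
t = rng.integers(0, 16, size=10)
print(chunked_linear_cross_entropy(x, w, t, chunk_size=3))
```

The chunk size only changes how much work and memory each step uses, not the result: any chunk size gives the same loss up to floating-point rounding.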

Alternatives

No response

Additional context

No response

@wizyoung
Contributor

Actually, if we align the CHUNK_SIZE of the torch-compiled FLCE with the strategy used in Liger's FLCE, the compiled version is only slightly faster than the Liger version, and it also requires a bit more memory. The advantage of the torch-compiled version is its flexibility: implementing Gemma 2's logit soft-capping is very straightforward, whereas I struggled for some time to get consistent accuracy with it in Liger.
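Why soft-capping is easy in one setting and fiddly in the other can be seen from the math. Assuming the tanh form `cap * tanh(z / cap)` that Gemma 2 uses for its final logits, with autograd it is one extra elementwise line per logits chunk, while a hand-written kernel that computes gradients itself must also chain in the derivative `1 - tanh(z / cap)**2` by hand. A small NumPy sketch with hypothetical helper names; the cap of 30.0 below is only an example value:

```python
import numpy as np


def softcap(logits: np.ndarray, cap: float) -> np.ndarray:
    # Squash logits smoothly into (-cap, cap): close to identity near 0, saturating for large |z|
    return cap * np.tanh(logits / cap)


def softcap_grad(logits: np.ndarray, cap: float) -> np.ndarray:
    # Elementwise derivative: d/dz [cap * tanh(z / cap)] = 1 - tanh(z / cap) ** 2
    t = np.tanh(logits / cap)
    return 1.0 - t * t


z = np.linspace(-100.0, 100.0, 11)
cap = 30.0  # example value for illustration
print(softcap(z, cap))
print(softcap_grad(z, cap))
```

In a fused kernel, this factor multiplies the softmax gradient before it flows back into the matmul, so it has to be applied in exactly the right place and precision.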

@wizyoung
Contributor

After setting CHUNK_SIZE and running the benchmark script:
[Screenshots of benchmark results not preserved]

@Chillee

Chillee commented Sep 10, 2024

@wizyoung how are you setting the chunk size? I wasn't able to get the liger kernel to perform much better even when changing the chunk size.

@wizyoung
Contributor

wizyoung commented Sep 10, 2024

@Chillee I used https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/fused_linear_cross_entropy.py#L23 as a reference. To be clear, I mean changing the chunk size in the torch-compiled FLCE: your default chunk size is 1024, and I changed it to 256. Then I got:
[Screenshot not preserved]
Keeping only the Liger and compiled chunked-CE benchmarks:
[Screenshots not preserved]

Environment: torch 2.3.1, Triton 2.3.1, A100 80GB, CUDA 12.3
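The linked line concerns how Liger picks its chunk size. As we read that file (please check it against the current source, since this is a transcription, not the library code), the heuristic sizes each chunk so the materialized (chunk, V) logits are roughly as large as the (BT, H) input, rounded up to a power of two. A plain-Python version of that reading, with hypothetical helper names:

```python
def cdiv(a: int, b: int) -> int:
    # Ceiling division for positive integers
    return (a + b - 1) // b


def next_power_of_2(n: int) -> int:
    # Smallest power of two >= n (for n >= 1)
    return 1 << (n - 1).bit_length()


def liger_style_chunk_size(bt: int, h: int, v: int) -> int:
    # inc_factor ~ how many times larger the logits (BT x V) are than the input (BT x H)
    inc_factor = cdiv(v, h)
    # Split BT into inc_factor pieces and round the piece size up to a power of two
    return next_power_of_2(cdiv(bt, inc_factor))


print(liger_style_chunk_size(4 * 4096, 4096, 32000))  # 2048
print(liger_style_chunk_size(8192, 4096, 128256))     # 256
```

For example, with H=4096 and V=128256 (the sizes used later in this thread), this heuristic returns 256 for any BT from 4097 to 8192.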

@wizyoung
Contributor

I ran some quick tests with different B, T, D and V to mimic my training conditions (Llama 3 and Gemma 2) in my environment. My conclusion is that the torch-compiled FLCE is indeed faster, but it manages memory less efficiently.

@wizyoung
Contributor

https://gist.github.com/wizyoung/5330ad501e73a97dfe2f0088decdb1ca
I have implemented a `torch.compile` version of chunked_lce that supports soft caps and passes all numerical accuracy tests in benchmark_fused_linear_cross_entropy.py (modified from this repo). My main concern is that the input shape changes frequently, which results in varying chunk sizes. To mitigate this overhead, I used `torch.compile(dynamic=True, options={"shape_padding": True})`. However, I am still unsure how effective this is and will look into it during actual training.

@Chillee

Chillee commented Sep 10, 2024

@wizyoung I agree there's some additional memory overhead (in particular, I think we don't do the addmm in place), but the additional memory is generally pretty negligible here, no?

For example, if I change the chunk size from 256 to 512, the torch.compile runtime drops from 186 ms to 153 ms, while memory only increases from 1.48 GB to 1.54 GB.

If I try increasing Liger's chunk size, it doesn't seem to improve performance as much as it does for the torch.compile version.

@ekojsalim

I'm curious how this compares with JonasGeiping/linear_cross_entropy_loss, though torch.compile seems good enough.

@Chillee

Chillee commented Sep 11, 2024

@ekojsalim In my brief testing, it seems like it's both faster and uses less memory.

@wizyoung
Copy link
Contributor

> @wizyoung I agree there's some additional memory overhead (in particular, I think we don't inplace the addmm), but the additional memory is generally pretty negligible here, no?
>
> For example, if I change the chunk size from 256 to 512, torch.compile performance improves from 186ms down to 153, while memory only increases from 1.48 GB to 1.54 GB.
>
> If I try increasing the chunk size of Liger, it doesn't seem to increase the performance as much as the torch.compile version

Yes, the increase in memory usage is generally negligible. My primary concern is the runtime overhead: B*T varies significantly and is not a multiple of the chunk size, which leads to frequent recompilations (I set `TORCH_LOGS="recompiles"` to find this). Therefore, I use `torch.compile(dynamic=True, options={"shape_padding": True})` as documented, but I am unsure how effective it actually is.
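One hypothetical alternative to `dynamic=True` (not something tested in this thread, and not a `torch.compile` feature) is to pad BT up to a multiple of the chunk size on the caller's side, marking the padded rows with the ignore index so that every chunk handed to a compiled per-chunk function has the same shape. A NumPy sketch of just the padding step, with a hypothetical helper name:

```python
import numpy as np

IGNORE_INDEX = -100  # default ignore_index of torch.nn.functional.cross_entropy


def pad_to_chunk_multiple(x: np.ndarray, target: np.ndarray, chunk_size: int):
    """Pad BT up to a multiple of chunk_size (hypothetical helper).

    Padded rows get zero hidden states and IGNORE_INDEX targets, so every chunk
    has shape (chunk_size, H) and the padded rows contribute no loss.
    """
    bt, h = x.shape
    padded_bt = -(-bt // chunk_size) * chunk_size  # ceil(bt / chunk_size) * chunk_size
    pad = padded_bt - bt
    x_p = np.concatenate([x, np.zeros((pad, h), dtype=x.dtype)])
    t_p = np.concatenate([target, np.full(pad, IGNORE_INDEX, dtype=target.dtype)])
    return x_p, t_p


x = np.ones((5000, 4), dtype=np.float16)
t = np.zeros(5000, dtype=np.int64)
x_p, t_p = pad_to_chunk_multiple(x, t, chunk_size=1024)
print(x_p.shape, int((t_p != IGNORE_INDEX).sum()))  # (5120, 4) 5000
```

The trade-off is up to chunk_size - 1 rows of wasted compute per call, and a mean loss must be divided by the number of non-ignored tokens rather than by the padded length. The number of chunks still varies with BT, so this only removes shape changes from the per-chunk computation, not from any code compiled around the chunk loop.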
I ran an extensive benchmark with this repo's script, setting `BT = [2**12] + (np.random.randint(2**12, 2**15 + 1, 80) + np.random.randint(0, 1001, 80)).tolist() + [2**15]`, keeping H=4096 and V=128256, and using chunk_size=1024.
[Screenshots of benchmark results not preserved]
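The BT list above can be regenerated to see the kind of shape variation behind the recompiles. A short sketch using the same `np.random.randint` expression (no seed is given in the comment, so the values will differ from that run); it counts how many sizes leave a partial last chunk and how many distinct last-chunk lengths appear:

```python
import numpy as np

chunk_size = 1024
# Same expression as in the comment: endpoints 2**12 and 2**15 plus 80 random sizes
BT = [2**12] + (np.random.randint(2**12, 2**15 + 1, 80)
                + np.random.randint(0, 1001, 80)).tolist() + [2**15]

# Length of the final chunk for each BT (0 means BT is an exact multiple of chunk_size)
tails = [bt % chunk_size for bt in BT]
partial = sum(1 for tail in tails if tail != 0)
print(f"{len(BT)} sizes, range [{min(BT)}, {max(BT)}]")
print(f"{partial} leave a partial last chunk; {len(set(tails))} distinct last-chunk lengths")
```

Each sampled size lies between 4096 and 2**15 + 1000 = 33768, so the benchmark covers both short and long batches.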

@Chillee

Chillee commented Sep 11, 2024

@wizyoung Can you post your benchmark script?

@wizyoung
Contributor

@Chillee I have updated my scripts here: https://gist.github.com/wizyoung/5330ad501e73a97dfe2f0088decdb1ca

ByronHsu pushed a commit that referenced this issue Nov 14, 2024
## Summary
Adds chunked ORPO loss kernel 

## Testing Done
Benchmarks
![Speed ORPO](https://github.com/user-attachments/assets/ae9e6f67-14cd-4189-9d64-9a2f94a3b3c6)
![Mem ORPO](https://github.com/user-attachments/assets/47c289f4-2876-4530-949c-2c2825bc0f79)

References:
1. #227
2. https://gist.github.com/Chillee/22cd93e11b887db1f596ab754d60a899#file-lce_benchmark-py

- Hardware Type: <BLANK>
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence

---------

Co-authored-by: shisahni_LinkedIn <shisahni@linkedin.com>