
add xpu device support for rms_norm #379

Merged: 7 commits into linkedin:main on Nov 19, 2024
Conversation

@faaany (Collaborator) commented Nov 14, 2024

Summary

I was running a trl unit test with liger support (link) and found that the CUDA device is hard-coded in rms_norm_backward.

This PR adds support for Intel GPU. After the fix, the test passes:

====================================================== short test summary info ======================================================
PASSED tests/slow/test_sft_slow.py::SFTTrainerSlowTester::test_sft_trainer_with_liger_0_trl_internal_testing_tiny_random_LlamaForCausalLM
================================================== 1 passed, 8 warnings in 14.47s ===================================================
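The kind of change involved can be sketched as follows. This is a minimal illustration, not Liger-Kernel's actual code, and the function name is hypothetical: the idea is to derive the device from the incoming tensor instead of assuming CUDA, so the same wrapper works on Intel XPU.

```python
import torch

def rms_norm_backward_sketch(dY: torch.Tensor) -> torch.Tensor:
    # Allocate scratch space on whatever device dY already lives on,
    # rather than hard-coding device="cuda". On an XPU tensor this
    # allocates on XPU; on CPU it allocates on CPU.
    dX = torch.empty_like(dY, device=dY.device)
    dX.copy_(dY)  # placeholder for the real backward computation
    return dX

dY = torch.randn(2, 4)  # CPU here; the same code path runs on cuda/xpu
dX = rms_norm_backward_sketch(dY)
```

The key point is that `dY.device` makes the allocation follow the input, which is the device-agnostic pattern the PR moves toward.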

Testing Done

Many other tests still fail because they only support CUDA devices.

@ByronHsu (Collaborator)

Cool! What would it take to fix all kernels and make all tests pass?

@faaany (Collaborator, Author) commented Nov 14, 2024

> Cool! What would it take to fix all kernels and make all tests pass?

Then we would need to make all tests device-agnostic. Currently, the device is hard-coded in the tests themselves, e.g. https://github.com/linkedin/Liger-Kernel/blob/main/test/transformers/test_swiglu.py#L55.

How about I only modify the tests related to rms_norm, e.g. https://github.com/linkedin/Liger-Kernel/blob/main/test/transformers/test_rms_norm.py?
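A device-agnostic test setup could look roughly like this. This is a sketch under assumptions: `infer_device` is a hypothetical helper, not necessarily the project's actual API, and the fallback order is illustrative.

```python
import torch

def infer_device() -> str:
    # Hypothetical helper: pick an available accelerator instead of
    # hard-coding "cuda" in each test, falling back to CPU.
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    return "cpu"

device = infer_device()
# A test would then build its inputs on the detected device:
x = torch.randn(2, 128, 512, device=device)
```

With a helper like this, the same parametrized test cases run unchanged on CUDA, XPU, or CPU.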

@ByronHsu (Collaborator) commented Nov 14, 2024

Feel free to modify all tests and kernels to make all of them work on XPU, as long as the original tests still pass on CUDA devices.

We do want to support XPU, but I don't have a device handy. How about we tackle rms_norm first, then do the others in follow-up PRs? By the way, do you have WeChat? Mine is wxid_nn8pbmlh9ae712. Eager to collaborate on XPU support!

@yao-matrix

> Feel free to modify all tests and kernels to make all of them work on XPU, as long as the original tests still pass on CUDA devices.
>
> We do want to support XPU, but I don't have a device handy. How about we tackle rms_norm first, then do the others in follow-up PRs? By the way, do you have WeChat? Mine is wxid_nn8pbmlh9ae712. Eager to collaborate on XPU support!

Sure, we will do this. BTW, it seems we cannot find you with the WeChat ID you shared; maybe you could find us instead. Mine is yaoweifeng813986, and @faaany can share hers. Thanks.

@faaany (Collaborator, Author) commented Nov 15, 2024

Hi @ByronHsu, mine is coach_up. And thanks for the proposal! I will update the rms_norm tests in this PR.

@ByronHsu (Collaborator)

I have added both of you on wechat. Please check!

@faaany (Collaborator, Author) commented Nov 18, 2024

Hi @ByronHsu, I have enabled the rms_norm unit tests on XPU. Please see the results below:

=============================================== short test summary info ===============================================
PASSED test/transformers/test_rms_norm.py::test_correctness[True-LlamaRMSNorm-0.0-llama-dtype0-0.0001-1e-06-2-128-512]
PASSED test/transformers/test_rms_norm.py::test_correctness[True-LlamaRMSNorm-0.0-llama-dtype0-0.0001-1e-06-5-123-123]
PASSED test/transformers/test_rms_norm.py::test_correctness[True-LlamaRMSNorm-0.0-llama-dtype1-0.2-0.02-2-128-512]
PASSED test/transformers/test_rms_norm.py::test_correctness[True-LlamaRMSNorm-0.0-llama-dtype1-0.2-0.02-5-123-123]
PASSED test/transformers/test_rms_norm.py::test_correctness[True-GemmaRMSNorm-1.0-gemma-dtype0-0.0001-1e-06-2-128-512]
PASSED test/transformers/test_rms_norm.py::test_correctness[True-GemmaRMSNorm-1.0-gemma-dtype0-0.0001-1e-06-5-123-123]
PASSED test/transformers/test_rms_norm.py::test_correctness[True-GemmaRMSNorm-1.0-gemma-dtype1-0.2-0.02-2-128-512]
PASSED test/transformers/test_rms_norm.py::test_correctness[True-GemmaRMSNorm-1.0-gemma-dtype1-0.2-0.02-5-123-123]
PASSED test/transformers/test_rms_norm.py::test_correctness[False-LlamaRMSNorm-0.0-llama-dtype0-0.0001-1e-06-2-128-512]
PASSED test/transformers/test_rms_norm.py::test_correctness[False-LlamaRMSNorm-0.0-llama-dtype0-0.0001-1e-06-5-123-123]
PASSED test/transformers/test_rms_norm.py::test_correctness[False-LlamaRMSNorm-0.0-llama-dtype1-0.2-0.02-2-128-512]
PASSED test/transformers/test_rms_norm.py::test_correctness[False-LlamaRMSNorm-0.0-llama-dtype1-0.2-0.02-5-123-123]
PASSED test/transformers/test_rms_norm.py::test_correctness[False-GemmaRMSNorm-1.0-gemma-dtype0-0.0001-1e-06-2-128-512]
PASSED test/transformers/test_rms_norm.py::test_correctness[False-GemmaRMSNorm-1.0-gemma-dtype0-0.0001-1e-06-5-123-123]
PASSED test/transformers/test_rms_norm.py::test_correctness[False-GemmaRMSNorm-1.0-gemma-dtype1-0.2-0.02-2-128-512]
PASSED test/transformers/test_rms_norm.py::test_correctness[False-GemmaRMSNorm-1.0-gemma-dtype1-0.2-0.02-5-123-123]
PASSED test/transformers/test_rms_norm.py::test_correctness_functional[LlamaRMSNorm-0.0-llama-dtype0-0.0001-1e-06-2-2-8]
PASSED test/transformers/test_rms_norm.py::test_correctness_functional[LlamaRMSNorm-0.0-llama-dtype0-0.0001-1e-06-9-7-41]
PASSED test/transformers/test_rms_norm.py::test_correctness_functional[LlamaRMSNorm-0.0-llama-dtype1-0.2-0.02-2-2-8]
PASSED test/transformers/test_rms_norm.py::test_correctness_functional[LlamaRMSNorm-0.0-llama-dtype1-0.2-0.02-9-7-41]
PASSED test/transformers/test_rms_norm.py::test_correctness_functional[GemmaRMSNorm-1.0-gemma-dtype0-0.0001-1e-06-2-2-8]
PASSED test/transformers/test_rms_norm.py::test_correctness_functional[GemmaRMSNorm-1.0-gemma-dtype0-0.0001-1e-06-9-7-41]
PASSED test/transformers/test_rms_norm.py::test_correctness_functional[GemmaRMSNorm-1.0-gemma-dtype1-0.2-0.02-2-2-8]
PASSED test/transformers/test_rms_norm.py::test_correctness_functional[GemmaRMSNorm-1.0-gemma-dtype1-0.2-0.02-9-7-41]
=========================================== 24 passed, 2 warnings in 3.03s ============================================

@ByronHsu (Collaborator) left a comment
Thanks! Looking forward to the XPU support!

@ByronHsu ByronHsu merged commit 8e72763 into linkedin:main Nov 19, 2024