add xpu device support for rms_norm
#379
Conversation
Cool! What would it take to fix all kernels and make all tests pass?
Then we will need to make all tests device-agnostic. Currently, the device is hard-coded in the tests themselves, e.g. https://github.com/linkedin/Liger-Kernel/blob/main/test/transformers/test_swiglu.py#L55. How about I only modify the tests related to rms_norm, e.g. https://github.com/linkedin/Liger-Kernel/blob/main/test/transformers/test_rms_norm.py?
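The device-agnostic change discussed here could be sketched roughly as below. This is an illustrative example only, not the actual Liger-Kernel test code; the function name and boolean flags are hypothetical, and in a real test the flags would come from `torch.cuda.is_available()` and `torch.xpu.is_available()`.

```python
# Hypothetical sketch: pick the first available backend instead of
# hard-coding device = "cuda" in each test. In practice the availability
# flags would come from torch.cuda.is_available() / torch.xpu.is_available().
def select_device(cuda_available: bool, xpu_available: bool) -> str:
    if cuda_available:
        return "cuda"
    if xpu_available:
        return "xpu"
    return "cpu"
```

Each test would then construct its tensors with `device=select_device(...)` instead of a literal `"cuda"` string.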
Feel free to modify all tests and kernels to make them work on XPU, as long as the original tests still pass on CUDA devices. We do want to support XPU, but I don't have a device handy. How about we tackle rms_norm first, then do the others in a follow-up PR? By the way, do you have WeChat? Mine is wxid_nn8pbmlh9ae712. Eager to collaborate on XPU support!
Sure, we will do this. BTW, it seems we cannot find you with the WeChat ID you shared; maybe you could find us instead. Mine is yaoweifeng813986, and @faaany can share hers. Thanks.
Hi @ByronHsu, coach_up is mine. And thanks for the proposal! I will update the rms_norm tests in this PR.
I have added both of you on WeChat. Please check!
Hi @ByronHsu, I enabled the rms_norm UT on XPU. Please see the UT results below:
=============================================== short test summary info ===============================================
PASSED test/transformers/test_rms_norm.py::test_correctness[True-LlamaRMSNorm-0.0-llama-dtype0-0.0001-1e-06-2-128-512]
PASSED test/transformers/test_rms_norm.py::test_correctness[True-LlamaRMSNorm-0.0-llama-dtype0-0.0001-1e-06-5-123-123]
PASSED test/transformers/test_rms_norm.py::test_correctness[True-LlamaRMSNorm-0.0-llama-dtype1-0.2-0.02-2-128-512]
PASSED test/transformers/test_rms_norm.py::test_correctness[True-LlamaRMSNorm-0.0-llama-dtype1-0.2-0.02-5-123-123]
PASSED test/transformers/test_rms_norm.py::test_correctness[True-GemmaRMSNorm-1.0-gemma-dtype0-0.0001-1e-06-2-128-512]
PASSED test/transformers/test_rms_norm.py::test_correctness[True-GemmaRMSNorm-1.0-gemma-dtype0-0.0001-1e-06-5-123-123]
PASSED test/transformers/test_rms_norm.py::test_correctness[True-GemmaRMSNorm-1.0-gemma-dtype1-0.2-0.02-2-128-512]
PASSED test/transformers/test_rms_norm.py::test_correctness[True-GemmaRMSNorm-1.0-gemma-dtype1-0.2-0.02-5-123-123]
PASSED test/transformers/test_rms_norm.py::test_correctness[False-LlamaRMSNorm-0.0-llama-dtype0-0.0001-1e-06-2-128-512]
PASSED test/transformers/test_rms_norm.py::test_correctness[False-LlamaRMSNorm-0.0-llama-dtype0-0.0001-1e-06-5-123-123]
PASSED test/transformers/test_rms_norm.py::test_correctness[False-LlamaRMSNorm-0.0-llama-dtype1-0.2-0.02-2-128-512]
PASSED test/transformers/test_rms_norm.py::test_correctness[False-LlamaRMSNorm-0.0-llama-dtype1-0.2-0.02-5-123-123]
PASSED test/transformers/test_rms_norm.py::test_correctness[False-GemmaRMSNorm-1.0-gemma-dtype0-0.0001-1e-06-2-128-512]
PASSED test/transformers/test_rms_norm.py::test_correctness[False-GemmaRMSNorm-1.0-gemma-dtype0-0.0001-1e-06-5-123-123]
PASSED test/transformers/test_rms_norm.py::test_correctness[False-GemmaRMSNorm-1.0-gemma-dtype1-0.2-0.02-2-128-512]
PASSED test/transformers/test_rms_norm.py::test_correctness[False-GemmaRMSNorm-1.0-gemma-dtype1-0.2-0.02-5-123-123]
PASSED test/transformers/test_rms_norm.py::test_correctness_functional[LlamaRMSNorm-0.0-llama-dtype0-0.0001-1e-06-2-2-8]
PASSED test/transformers/test_rms_norm.py::test_correctness_functional[LlamaRMSNorm-0.0-llama-dtype0-0.0001-1e-06-9-7-41]
PASSED test/transformers/test_rms_norm.py::test_correctness_functional[LlamaRMSNorm-0.0-llama-dtype1-0.2-0.02-2-2-8]
PASSED test/transformers/test_rms_norm.py::test_correctness_functional[LlamaRMSNorm-0.0-llama-dtype1-0.2-0.02-9-7-41]
PASSED test/transformers/test_rms_norm.py::test_correctness_functional[GemmaRMSNorm-1.0-gemma-dtype0-0.0001-1e-06-2-2-8]
PASSED test/transformers/test_rms_norm.py::test_correctness_functional[GemmaRMSNorm-1.0-gemma-dtype0-0.0001-1e-06-9-7-41]
PASSED test/transformers/test_rms_norm.py::test_correctness_functional[GemmaRMSNorm-1.0-gemma-dtype1-0.2-0.02-2-2-8]
PASSED test/transformers/test_rms_norm.py::test_correctness_functional[GemmaRMSNorm-1.0-gemma-dtype1-0.2-0.02-9-7-41]
=========================================== 24 passed, 2 warnings in 3.03s ============================================
Thanks! Looking forward to the XPU support!
Summary
I was running a trl unit test with liger support (link) and found that the cuda device is hard-coded in
rms_norm_backward
. This PR adds support for Intel GPU (XPU). After the fix, the test passes.
Testing Done
Many tests currently fail because the kernels only support CUDA devices; this PR makes the rms_norm tests pass on XPU (see the results above).
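The general fix pattern for a hard-coded device is to derive it from an input tensor instead. A minimal sketch of that pattern follows; the function name is hypothetical and this is not the actual `rms_norm_backward` code, just an illustration of allocating buffers on the incoming tensor's device so the wrapper works on cuda, xpu, and cpu alike.

```python
import torch

# Hypothetical sketch of the fix pattern, not the actual Liger-Kernel code:
# allocate intermediate buffers on the same device as the incoming tensor
# (x.device) rather than passing device="cuda" explicitly.
def alloc_like(x: torch.Tensor) -> torch.Tensor:
    return torch.zeros(x.shape, dtype=x.dtype, device=x.device)
```

With this pattern, the same backward wrapper runs unchanged on any backend PyTorch exposes as a device.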