Fp8 matmul support on AMD MI300 #9531
Conversation
@ezhulenev @ddunl Could you please review this PR?
I think I'm not the appropriate reviewer for this, @ezhulenev do you know who should review?
I think @reedwm would be an appropriate reviewer for this PR.
Sorry for the long delay in reviewing!
xla/service/gpu/gemm_rewriter.cc
Outdated
if (instr->shape().element_type() == BF16 ||
    instr->shape().element_type() == F8E4M3FNUZ ||
    instr->shape().element_type() == F8E5M2FNUZ) {
I think it's better to list the output types that are supported. E.g. something like
if (instr->shape().element_type() != F16 && instr->shape().element_type() != F32) {
  ...
}
Presumably there are many output types that are not supported, like F64.
Also the documentation doesn't list FP8 as being supported. In the table at the bottom of the hipblasLtMatmul
documentation, only FP16, BF16, and FP32 are listed as supported. Is the documentation out of date?
Yes, the hipBlasLt documentation is not up to date... I just asked the hipBlasLt team to update it.
Changed as suggested.
xla/service/gpu/gemm_rewriter.cc
Outdated
case F8E4M3FNUZ:
case F8E5M2FNUZ:
You already ensure the output is not one of these dtypes by calling TurnF8DotWithUnsupportedOutputTypeIntoF32 above. And you don't want to allow these dtypes for the CUDA case. So I think you can remove these two dtypes here.
You're right. These two cases are not necessary here.
Changed as suggested.
case F8E4M3FNUZ:
case F8E5M2FNUZ:
Similar to my comment above, I think these can be removed as you don't support matmuls with F8 outputs in ROCM.
This is actually different from the case in gemm_rewriter. It is determining whether the matmul is supported by XLA before the gemm rewrite. If I remove them here, it will result in an internal error: "Unexpected GEMM datatype:".
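For context, a rough sketch of the pre-rewrite support check being discussed (the helper name and structure are assumptions for illustration, not the actual XLA code): without the FNUZ cases, the default branch produces the internal error mentioned above.

// Hypothetical shape of the dtype dispatch that runs before the gemm rewriter.
absl::StatusOr<bool> IsSupportedGemmOutputType(PrimitiveType type) {
  switch (type) {
    case F8E4M3FNUZ:  // must stay: this runs before the rewriter turns
    case F8E5M2FNUZ:  // F8 dots into F32-output custom calls
    case F16:
    case BF16:
    case F32:
    case F64:
      return true;
    default:
      return absl::InternalError("Unexpected GEMM datatype:");
  }
}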
@reedwm I have addressed most of the review comments. This one could not be removed as it would result in an error. Could you take a look at the changes and let me know if there are further concerns?
if (CudaOrRocmCheck(Switch::False, Switch::True)) {
  GTEST_SKIP() << "F8 gemm rewrite is not yet supported on ROCm platform";
}
VLOG(-1) << "Running test " <<
I've never seen VLOG(-1) before. Not sure what this does. Maybe make this VLOG(1) or remove?
Yeah, this is unnecessary. I removed it.
static constexpr const char* kF8E4M3DatatypePlaceholder{
    "<<F8E4M3_DATATYPE_PLACEHOLDER>>"};
static constexpr const char* kF8E5M2DatatypePlaceholder{
    "<<F8E5M2_DATATYPE_PLACEHOLDER>>"};
static constexpr const char* kF8E4M3AmaxPlaceholder{
    "<<F8E4M3_AMAX_PLACEHOLDER>>"};
These three placeholder names are quite long. I would instead use <<F8E4M3>>, <<F8E5M2>>, and <<F8E4M3_AMAX>>.
Changed as suggested.
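For reference, a minimal sketch of the renamed constants with the shorter placeholders (only the string values change; the constant names stay as in the snippet above):

static constexpr const char* kF8E4M3DatatypePlaceholder{"<<F8E4M3>>"};
static constexpr const char* kF8E5M2DatatypePlaceholder{"<<F8E5M2>>"};
static constexpr const char* kF8E4M3AmaxPlaceholder{"<<F8E4M3_AMAX>>"};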
/*
#if TENSORFLOW_USE_ROCM
GTEST_SKIP() << "F8 gemm with Bias Gelu epilogue and bf16 output is not supported by hipBlasLt yet.";
#endif  // TENSORFLOW_USE_ROCM
*/
Remove this comment. And same in ScaledABUnscaledDApproxGeluActivationF8.
Changed as suggested. Should have removed them earlier...
if ((a_dtype == HIP_R_8F_E4M3_FNUZ || a_dtype == HIP_R_8F_E5M2_FNUZ) &&
    (b_dtype == HIP_R_8F_E4M3_FNUZ || b_dtype == HIP_R_8F_E5M2_FNUZ)) {
  auto d_dtype = d_desc.type();
  if (d_dtype == HIP_R_8F_E4M3_FNUZ || d_dtype == HIP_R_8F_E5M2_FNUZ) {
I thought the output couldn't be FP8 in hipblasLt.
I was trying to match the cublasLt bias types and be a bit "future-proof" here, since hipblasLt will support fp8 output in the future. But you're right, this condition is not necessary given hipblasLt's current capabilities. I removed the unnecessary conditions.
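A minimal sketch of what remains once the FP8-output branch is dropped (the helper name is hypothetical and the hipDataType parameter type is an assumption; it only captures the input-dtype check from the snippet above):

// Hypothetical helper: hipBlasLt currently only produces non-FP8 outputs for
// FP8 inputs, so only the input dtypes need an FNUZ check here.
bool IsFp8Fp8Matmul(hipDataType a_dtype, hipDataType b_dtype) {
  return (a_dtype == HIP_R_8F_E4M3_FNUZ || a_dtype == HIP_R_8F_E5M2_FNUZ) &&
         (b_dtype == HIP_R_8F_E4M3_FNUZ || b_dtype == HIP_R_8F_E5M2_FNUZ);
}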
if (c_scale != nullptr) {
  LOG(WARNING) << "c_scale is not nullptr.";
}
if (d_scale != nullptr) {
  LOG(WARNING) << "d_scale is not nullptr.";
}
Return an internal error if c_scale or d_scale or nonnull. This should never happen and we do not handle it properly.
There seems to be a typo in the comment. Do you mean we should return an internal error if c_scale or d_scale is nullptr?
I had a typo, I meant: "Return an internal error if c_scale or d_scale are nonnull". hipBlasLt cannot handle c_scale or d_scale so we should return an error if they are present.
Thanks for the clarification. I set the c_scale and d_scale to null on the ROCm side and will return an error if they are not null.
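For concreteness, a minimal sketch of that check, assuming an absl::Status-returning matmul path (the error message wording is an assumption):

// hipBlasLt cannot apply c_scale or d_scale today, so reject them outright
// rather than silently ignoring them.
if (c_scale != nullptr || d_scale != nullptr) {
  return absl::InternalError(
      "hipblaslt does not support c_scale or d_scale");
}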
#if TENSORFLOW_USE_ROCM
// Turns an F8 dot with output type BF16 or F8 into an F8 dot with F32 output,
// converting the F32 output back to BF16 or F8.
absl::StatusOr<HloInstruction *> TurnF8DotWithUnsupportedOutputTypeIntoF32(
Add a ROCm-specific test that checks that the convert is inserted.
If you already added such a test, no need to do anything; I may have missed a test you modified or added.
The FileCheck patterns in several tests were modified to account for the different rewrite behavior on ROCm (e.g.,
ScaledABUnscaledDVectorBiasThenApproxGeluActivationF8, ScaledABUnscaledDApproxGeluActivationF8, ScaledABScaledDF8, etc.).
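To illustrate the rewrite those patterns cover, here is a rough sketch of the idea behind TurnF8DotWithUnsupportedOutputTypeIntoF32 (the body is an assumption based on the description above, not the actual implementation): clone the dot with an F32 result shape, then convert back to the original output type.

absl::StatusOr<HloInstruction *> TurnF8DotWithUnsupportedOutputTypeIntoF32(
    HloInstruction *instr) {
  // Give the dot an F32 result shape so hipBlasLt can emit it directly...
  Shape f32_shape = instr->shape();
  f32_shape.set_element_type(F32);
  HloInstruction *f32_dot =
      instr->parent()->AddInstruction(instr->CloneWithNewShape(f32_shape));
  // ...then convert the F32 result back to the original BF16/F8 output type.
  return instr->parent()->AddInstruction(
      HloInstruction::CreateConvert(instr->shape(), f32_dot));
}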
; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = f32[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]], [[C1]], /*index=5*/[[C1]]),
; CHECK-PTX-NEXT: ROOT [[OUT:%[^ ]+]] = <<F8E4M3_DATATYPE_PLACEHOLDER>>[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]], [[C1]], /*index=5*/[[C1]]),
@cheshire, @ezhulenev, any opinions on using CHECK-GCN-NEXT and CHECK-PTX-NEXT like this? I'm unsure if this is a good idea or not. filecheck.cc is modified so that CHECK-GCN-NEXT only matches for ROCm and CHECK-PTX-NEXT only matches for CUDA, to account for slight differences in how gemm_rewriter operates between the two.
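In the same spirit, the datatype placeholders in the shared HLO patterns have to resolve differently per platform. A hedged sketch of how such a substitution could look (the helper name and wiring are assumptions, not the actual ParameterizedFp8GemmRewriteTest code):

#include <string>
#include "absl/strings/str_replace.h"
#include "absl/strings/string_view.h"

// Hypothetical helper: swap the platform's FP8 HLO type names into a shared
// FileCheck pattern (OCP types on CUDA, FNUZ types on ROCm).
std::string SubstituteF8Placeholders(absl::string_view pattern, bool is_rocm) {
  return absl::StrReplaceAll(
      pattern, {{"<<F8E4M3>>", is_rocm ? "f8e4m3fnuz" : "f8e4m3fn"},
                {"<<F8E5M2>>", is_rocm ? "f8e5m2fnuz" : "f8e5m2"}});
}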
Imported from GitHub PR openxla/xla#9531

This PR enables support of fp8 matmul (with fp8e4m3fnuz and fp8e5m2fnuz) on MI300. It is based off this previous PR to fix the build break on ROCm (which is still open): openxla/xla#9367

Some of the GEMM pattern rewrites were changed to accommodate current limitations in hipblasLt:
- hipblasLt currently does not support fp8 matmul with fp8 output. Therefore, it is rewritten to a custom call of fp8 matmul with fp32 output, and HLO instructions are used to convert the fp32 output to fp8.
- hipblasLt currently does not support fp8 matmul with bf16 output. Therefore, it is rewritten to a custom call of fp8 matmul with fp32 output, and HLO instructions are used to convert the fp32 output to bf16.

Copybara import of the project:

-- 942f93f5187d3706bbe74d78766914ab84f33bb9 by Wen Chen <Wen.Chen@amd.com>:
[ROCM] Initial support of fp8 Matmul via hipBlasLt.

-- d8d45593ec4e40f391421d3a44890c6ae2d6f2f5 by Wen Chen <Wen.Chen@amd.com>:
[ROCM] Code refactoring of initial hipBlasLt fp8 Matmul support
- Clean up unnecessary code, particularly regarding output types of fp8
- Override methods in ParameterizedFp8GemmRewriteTest to replace the patterns for CUDA and ROCm respectively for HLO checks.
- Explicitly set c_scale and d_scale to nullptr as hipblasLt currently does not support them.

Merging this change closes #9531

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#9531 from ROCm:ci_fp8_gemm_support d8d45593ec4e40f391421d3a44890c6ae2d6f2f5
PiperOrigin-RevId: 610480934
@ddunl I see there are some conflicts now. Is that blocking the merge? If so, I will rebase and resolve the conflicts.
@wenchenvincent sorry for taking so long to merge. I will resolve conflicts internally and merge.
@reedwm No worries. I had rebased and resolved the conflicts locally. But it seems that there were changes in gemm_rewrite_test and the expected behavior of some tests has changed, so I was seeing some test failures. Let me resolve the failures on my side to make sure that it is good on the ROCm side after the rebase, and then I will push the updates.
[ROCM] Code refactoring of initial hipBlasLt fp8 Matmul support
- Clean up unnecessary code, particularly regarding output types of fp8
- Override methods in ParameterizedFp8GemmRewriteTest to replace the patterns for CUDA and ROCm respectively for HLO checks.
- Explicitly set c_scale and d_scale to nullptr as hipblasLt currently does not support them.
There are several different genres of fp8 formats used by different HW vendors. Two popular genres include:
- OCP fp8, which is used natively on NVIDIA H100
- NANOO fp8, which is used natively on AMD MI300 and Graphcore HW

These two genres of fp8 formats work very similarly. This PR enables support of NANOO fp8, as it is also now supported in JAX and XLA.

References:
- OCP fp8 paper: https://arxiv.org/abs/2209.05433
- NANOO fp8 paper: https://arxiv.org/abs/2206.02915
- JAX PR: jax-ml/jax#21376
- XLA PR: openxla/xla#9531