Fp8 matmul support on AMD MI300 #9531

Closed

Conversation

wenchenvincent
Contributor

This PR enables support for fp8 matmul (with fp8e4m3fnuz and fp8e5m2fnuz) on MI300.
It is based on this previous PR, which fixes a build break on ROCm and is still open: #9367

Some of the GEMM pattern rewrites were changed to accommodate current limitations in hipblasLt (a sketch of the resulting rewrite follows this list):

  • hipblasLt does not currently support fp8 matmul with fp8 output. Such a matmul is therefore rewritten to a custom call of fp8 matmul with fp32 output, and HLO instructions are inserted to convert the fp32 output to fp8.
  • hipblasLt does not currently support fp8 matmul with bf16 output. Such a matmul is likewise rewritten to a custom call of fp8 matmul with fp32 output, and HLO instructions are inserted to convert the fp32 output to bf16.
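For readers following the rewrite, here is a minimal sketch of the approach, written against the usual XLA HLO APIs. It is an illustration only; the PR's actual TurnF8DotWithUnsupportedOutputTypeIntoF32 (quoted further down in the review) may differ in details.

// Sketch (includes and namespace as in gemm_rewriter.cc): the F8 dot keeps
// its fp8 operands but is re-created with an F32 result shape, and a convert
// back to the originally requested output type (fp8 or bf16) takes its place.
absl::StatusOr<HloInstruction *> TurnF8DotWithUnsupportedOutputTypeIntoF32(
    HloInstruction *instr) {
  Shape f32_shape = instr->shape();
  f32_shape.set_element_type(F32);
  // Same dot, same operands, but with an F32 output.
  HloInstruction *f32_dot = instr->parent()->AddInstruction(
      instr->CloneWithNewOperands(f32_shape, instr->operands()));
  // Convert the F32 result back to the type the rest of the HLO graph expects.
  HloInstruction *convert = instr->parent()->AddInstruction(
      HloInstruction::CreateConvert(instr->shape(), f32_dot));
  // The original dot becomes dead and is cleaned up by DCE.
  TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(convert));
  return f32_dot;
}

The gemm rewriter then lowers the F32-output dot to the hipblasLt fp8 custom call as usual.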

@wenchenvincent
Contributor Author

@ezhulenev @ddunl Could you please review this PR?

@ddunl
Member

ddunl commented Feb 21, 2024

I think I'm not the appropriate reviewer for this, @ezhulenev do you know who should review?

@wenchenvincent
Contributor Author

I think I'm not the appropriate reviewer for this, @ezhulenev do you know who should review?

I think @reedwm would be an appropriate reviewer for this PR.

Member
@reedwm reedwm left a comment

Sorry for the long delay in reviewing!

Comment on lines 536 to 538
if (instr->shape().element_type() == BF16 ||
instr->shape().element_type() == F8E4M3FNUZ ||
instr->shape().element_type() == F8E5M2FNUZ) {
Member

I think it's better to list the output types that are supported. E.g. something like

if (instr->shape().element_type() != F16 && instr->shape().element_type() != F32) {
   ...
}

Presumably there are many output types not supported like F64.

Also the documentation doesn't list FP8 as being supported. In the table at the bottom of the hipblasLtMatmul documentation, only FP16, BF16, and FP32 are listed as supported. Is the documentation out of date?

Contributor Author

Yes, the hipBlasLt documentation is not up to date... I just asked the hipBlasLt team to update it.

Contributor Author

Changed as suggested.

Comment on lines 965 to 966
case F8E4M3FNUZ:
case F8E5M2FNUZ:
Member

You already ensure the output is not one of these dtypes by calling TurnF8DotWithUnsupportedOutputTypeIntoF32 above. And you don't want to allow these dtypes for the CUDA case. So I think you can remove these two dtypes here.

Contributor Author

You're right. These two cases are not necessary here.

Contributor Author

Changed as suggested.

Comment on lines +388 to +393
case F8E4M3FNUZ:
case F8E5M2FNUZ:
Member

Similar to my comment above, I think these can be removed as you don't support matmuls with F8 outputs in ROCM.

Contributor Author

This case is actually different from the one in gemm_rewriter: here we are determining whether the matmul is supported by XLA at all, before the gemm rewrite runs. If I remove these cases, it results in an internal error: "Unexpected GEMM datatype:".
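To make the distinction concrete, here is a hedged sketch of this kind of pre-rewrite support check; the function name and surrounding code are assumptions for illustration, not the PR's exact diff.

// Decides whether XLA treats a dot as a supported GEMM before the gemm
// rewriter runs. The FNUZ cases must stay listed because at this point the
// F8 dot still has its original output type; dropping them triggers the
// "Unexpected GEMM datatype:" internal error mentioned above.
absl::Status CheckGemmOutputType(PrimitiveType type) {
  switch (type) {
    case F16:
    case BF16:
    case F32:
    case F64:
    case F8E4M3FNUZ:
    case F8E5M2FNUZ:
      return absl::OkStatus();
    default:
      return absl::InternalError(absl::StrCat("Unexpected GEMM datatype: ",
                                              PrimitiveType_Name(type)));
  }
}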

Contributor Author

@reedwm I have addressed most of the review comments. This one could not be removed as it would result in an error. Could you take a look at the changes and let me know if there are further concerns?

if (CudaOrRocmCheck(Switch::False, Switch::True)) {
GTEST_SKIP() << "F8 gemm rewrite is not yet supported on ROCm platform";
}
VLOG(-1) << "Running test " <<
Member

I've never seen VLOG(-1) before. Not sure what this does. Maybe make this VLOG(1) or remove it?

Contributor Author

Yeah, this is unnecessary. I removed it.

Comment on lines 4592 to 4597
static constexpr const char* kF8E4M3DatatypePlaceholder{
"<<F8E4M3_DATATYPE_PLACEHOLDER>>"};
static constexpr const char* kF8E5M2DatatypePlaceholder{
"<<F8E5M2_DATATYPE_PLACEHOLDER>>"};
static constexpr const char* kF8E4M3AmaxPlaceholder{
"<<F8E4M3_AMAX_PLACEHOLDER>>"};
Member

These three placeholder names are quite long. I would instead use <<F8E4M3>>, <<F8E5M2>>, and <<F8E4M3_AMAX>>.

Contributor Author

Changed as suggested.

Comment on lines 5369 to 5373
/*
#if TENSORFLOW_USE_ROCM
GTEST_SKIP() << "F8 gemm with Bias Gelu epilogue and bf16 output is not supported by hipBlasLt yet.";
#endif // TENSORFLOW_USE_ROCM
*/
Member

Remove this comment. And same in ScaledABUnscaledDApproxGeluActivationF8.

Contributor Author

Changed as suggested. Should have removed them earlier...

if ((a_dtype == HIP_R_8F_E4M3_FNUZ || a_dtype == HIP_R_8F_E5M2_FNUZ) &&
(b_dtype == HIP_R_8F_E4M3_FNUZ || b_dtype == HIP_R_8F_E5M2_FNUZ)) {
auto d_dtype = d_desc.type();
if (d_dtype == HIP_R_8F_E4M3_FNUZ || d_dtype == HIP_R_8F_E5M2_FNUZ) {
Member

I thought the output couldn't be FP8 in hipblasLt.

Contributor Author

I was trying to match the cublasLt bias types and be a bit "future-proof", since hipblasLt will support fp8 output in the future. But you're right, this condition is not necessary given hipblasLt's current capabilities. I removed the unnecessary conditions.

Comment on lines 431 to 429
if (c_scale != nullptr) {
LOG(WARNING) << "c_scale is not nullptr.";
}
if (d_scale != nullptr) {
LOG(WARNING) << "d_scale is not nullptr.";
}
Member

Return an internal error if c_scale or d_scale or nonnull. This should never happen and we do not handle it properly.

Contributor Author

There seems to be a typo in the comment. Do you mean we should return an internal error if c_scale or d_scale is nullptr?

Member

I had a typo, I meant: "Return an internal error if c_scale or d_scale are nonnull". hipBlasLt cannot handle c_scale or d_scale so we should return an error if they are present.

Contributor Author

Thanks for the clarification. I set the c_scale and d_scale to null on the ROCm side and will return an error if they are not null.
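A minimal sketch of that guard, with assumed variable names (the actual change may be structured differently):

// hipBlasLt cannot honor c_scale or d_scale for fp8 matmul, so fail loudly
// rather than silently ignoring them.
if (c_scale != nullptr || d_scale != nullptr) {
  return absl::InternalError(
      "hipblaslt does not support c_scale or d_scale for fp8 matmul");
}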

#if TENSORFLOW_USE_ROCM
// Turns an F8 dot with output type BF16 or F8 into an F8 dot with F32 output,
// and converts the F32 output back to BF16 or F8.
absl::StatusOr<HloInstruction *> TurnF8DotWithUnsupportedOutputTypeIntoF32(
Member

Add a ROCm-specific test that verifies the convert is inserted.

If you already added such a test, no need to do anything. I may have missed a test you modified or added.

Contributor Author

The FileCheck patterns in several tests were modified to account for the different rewrite behavior on ROCm (e.g.,
ScaledABUnscaledDVectorBiasThenApproxGeluActivationF8, ScaledABUnscaledDApproxGeluActivationF8, ScaledABScaledDF8, etc).

Comment on lines 4732 to 4733
; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = f32[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]], [[C1]], /*index=5*/[[C1]]),
; CHECK-PTX-NEXT: ROOT [[OUT:%[^ ]+]] = <<F8E4M3_DATATYPE_PLACEHOLDER>>[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]], [[C1]], /*index=5*/[[C1]]),
Member

@cheshire, @ezhulenev, any opinions on using CHECK-GCN-NEXT and CHECK-PTX-NEXT like this? I'm unsure whether this is a good idea. filecheck.cc is modified so that CHECK-GCN-NEXT only matches on ROCm and CHECK-PTX-NEXT only matches on CUDA, to account for slight differences in how gemm_rewriter operates between the two platforms.
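As an illustration of the mechanism (an assumption about the shape of the change, not the exact filecheck.cc diff), the platform-specific prefix can be enabled on top of the generic CHECK prefix roughly like this:

// Enable CHECK plus the platform-specific prefix when invoking FileCheck, so
// CHECK-GCN-* lines only apply on ROCm and CHECK-PTX-* only on CUDA.
std::string check_prefixes = "--check-prefixes=CHECK";
#if TENSORFLOW_USE_ROCM
check_prefixes += ",CHECK-GCN";
#else
check_prefixes += ",CHECK-PTX";
#endif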

copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Mar 11, 2024
Imported from GitHub PR openxla/xla#9531 (PR description as above).

Copybara import of the project:

--
942f93f5187d3706bbe74d78766914ab84f33bb9 by Wen Chen <Wen.Chen@amd.com>:

[ROCM] Initial support of fp8 Matmul via hipBlasLt.

--
d8d45593ec4e40f391421d3a44890c6ae2d6f2f5 by Wen Chen <Wen.Chen@amd.com>:

[ROCM] Code refactoring of initial hipBlasLt fp8 Matmul support

 - Clean up unnecessary code, particularly regarding output types of fp8

 - Override methods in ParameterizedFp8GemmRewriteTest to replace the
   patterns for CUDA and ROCm respectively for HLO checks.

 - Explicitly set c_scale and d_scale to nullptr as hipblasLt currently
   does not support them.

Merging this change closes #9531

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#9531 from ROCm:ci_fp8_gemm_support d8d45593ec4e40f391421d3a44890c6ae2d6f2f5
PiperOrigin-RevId: 610480934
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Mar 11, 2024 (same import message and PiperOrigin-RevId as above).
@wenchenvincent
Contributor Author

@ddunl I see there are some conflicts now. Is that blocking the merge? If so, I will rebase and resolve the conflicts.

@reedwm
Member

reedwm commented Mar 12, 2024

@wenchenvincent sorry for taking so long to merge. I will resolve conflicts internally and merge.

copybara-service bot pushed a commit that referenced this pull request Mar 13, 2024 (same import message as above, with PiperOrigin-RevId: 615197463).
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Mar 13, 2024 (same import message as above, with PiperOrigin-RevId: 615197463).
@wenchenvincent
Contributor Author

wenchenvincent commented Mar 13, 2024

@wenchenvincent sorry for taking so long to merge. I will resolve conflicts internally and merge.

@reedwm No worries. I had rebased and resolved the conflicts locally. But it seems there were changes in gemm_rewrite_test and the expected behavior of some tests has changed, so I was seeing some test failures. Let me resolve those failures on my side to make sure everything is good on the ROCm side after the rebase, and then I will push the updates.

[ROCM] Code refactoring of initial hipBlasLt fp8 Matmul support

 - Clean up unnecessary code, particularly regarding output types of fp8

 - Override methods in ParameterizedFp8GemmRewriteTest to replace the
   patterns for CUDA and ROCm respectively for HLO checks.

 - Explicitly set c_scale and d_scale to nullptr as hipblasLt currently
   does not support them.
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Mar 13, 2024 (same import message as above, now referencing the squashed commit a4423f9136072f090494baff91bd87d3ab4d4069, PiperOrigin-RevId: 610480934).
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Mar 13, 2024 (same as above).
@wenchenvincent
Contributor Author

@reedwm @ddunl I pushed the update with conflict resolution. Please let me know what's remaining to get this PR merged. Thanks!

copybara-service bot pushed three further commits to tensorflow/tensorflow that referenced this pull request Mar 13, 2024 (same import message as above).
copybara-service bot pushed several commits (to this repository and to tensorflow/tensorflow) that referenced this pull request Mar 14, 2024, each with the same import message (PiperOrigin-RevId: 615197463).
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Mar 14, 2024 (same import message as above, PiperOrigin-RevId: 615649558).
wenchenvincent added a commit to wenchenvincent/flax that referenced this pull request Jun 13, 2024
There are several different families of fp8 formats used by different
HW vendors. Two popular ones are
- OCP fp8, which is used natively on NVIDIA H100
- NANOO fp8, which is used natively on AMD MI300 and Graphcore HW.

These two fp8 families work very similarly. This PR enables support for
NANOO fp8, which is now also supported in JAX and XLA.

References:
- OCP fp8 paper: https://arxiv.org/abs/2209.05433
- NANOO fp8 paper: https://arxiv.org/abs/2206.02915
- JAX PR: jax-ml/jax#21376
- XLA PR: openxla/xla#9531
wenchenvincent added a commit to wenchenvincent/flax that referenced this pull request Jun 28, 2024 (same commit message as above).