
[ROCM] fixing build breaks and small refactoring #9367

Closed
pemeliya wants to merge 2 commits

Conversation

pemeliya
Contributor

@pemeliya pemeliya commented Feb 9, 2024

Here I fix the recent build breaks on ROCm (missing includes in redzone_allocator and topk_kernel) and also do some small refactoring by using the gpu_kernel_library macro, which can build either CUDA or ROCm GPU code.

@xla-rotation: could you please have a look?
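
For context, here is a rough sketch of the kind of BUILD target the gpu_kernel_library macro provides; the load path, target name, and attributes below are illustrative assumptions rather than the exact targets touched by this PR:

    # Hypothetical BUILD sketch: one gpu_kernel_library target whose sources are
    # compiled for whichever GPU backend (CUDA or ROCm) the build is configured for,
    # instead of maintaining separate per-platform kernel targets.
    load("//xla/stream_executor/gpu:build_defs.bzl", "gpu_kernel_library")  # assumed load path

    gpu_kernel_library(
        name = "topk_kernel_gpu",      # illustrative name
        srcs = ["topk_kernel.cu.cc"],  # built as CUDA or HIP depending on the backend
        deps = ["//xla/stream_executor/gpu:gpu_types_header"],  # illustrative dependency
    )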

@github-actions github-actions bot added the kokoro:force-run Forces CI to rerun label Feb 9, 2024
@kokoro-team kokoro-team removed the kokoro:force-run Forces CI to rerun label Feb 9, 2024
@kamaljeeti kamaljeeti requested review from sgerrard and ddunl February 9, 2024 14:26
@i-chaochen
Contributor

Hi @xla-rotation, could you please review this PR? The build is currently broken, which is blocking us.

@ddunl ddunl requested a review from ezhulenev February 12, 2024 18:47
@github-actions github-actions bot added the kokoro:force-run Forces CI to rerun label Feb 14, 2024
@kokoro-team kokoro-team removed the kokoro:force-run Forces CI to rerun label Feb 14, 2024
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Feb 15, 2024
Imported from GitHub PR openxla/xla#9367

Here I fix the recent build breaks on ROCm (missing includes in redzone_allocator and topk_kernel) and also do some small refactoring by using the gpu_kernel_library macro, which can build either CUDA or ROCm GPU code.

@xla-rotation: could you please have a look?
Copybara import of the project:

--
09570c3eec55f4e9b054ac4d7fb413d89843139a by Pavel Emeliyanenko <pavel.emeliyanenko@amd.com>:

fixing build breaks and small refactoring

--
8c7f3bceec99fe42d18f2bfb4a5b0d5630073a6a by Pavel Emeliyanenko <pavel.emeliyanenko@amd.com>:

make sure the code builds without TF_HIPBLASLT

Merging this change closes #9367

PiperOrigin-RevId: 607224470
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Mar 7, 2024
Imported from GitHub PR openxla/xla#9531

This PR enables support for fp8 matmul (with fp8e4m3fnuz and fp8e5m2fnuz) on MI300.
It is based on this previous PR, which fixes the build break on ROCm (and is still open): openxla/xla#9367

Some of the GEMM pattern rewrites were changed to accommodate current limitations in hipblasLt (a rough sketch follows the list):
- hipblasLt currently does not support fp8 matmul with fp8 output. Therefore, it is rewritten to a custom call of fp8 matmul with fp32 output, and HLO instructions are used to convert the fp32 output to fp8.
- hipblasLt currently does not support fp8 matmul with bf16 output. Therefore, it is rewritten to a custom call of fp8 matmul with fp32 output, and HLO instructions are used to convert the fp32 output to bf16.
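
A rough HLO sketch of the first case, for illustration only; the shapes, operand list, and the exact custom-call target below are assumptions, not taken from this PR:

    // Before the rewrite: an fp8 dot that is supposed to produce fp8 output.
    //   dot = f8e4m3fnuz[16,16] dot(f8e4m3fnuz[16,32] a, f8e4m3fnuz[32,16] b), ...
    // After the rewrite on ROCm: the hipblasLt GEMM custom call produces f32,
    // and a plain HLO convert narrows the result back to fp8.
    //   gemm = f32[16,16] custom-call(a, b, ...), custom_call_target="__cublas$lt$matmul$f8"
    //   dot  = f8e4m3fnuz[16,16] convert(gemm)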

Copybara import of the project:

--
942f93f5187d3706bbe74d78766914ab84f33bb9 by Wen Chen <Wen.Chen@amd.com>:

[ROCM] Initial support of fp8 Matmul via hipBlasLt.

--
d8d45593ec4e40f391421d3a44890c6ae2d6f2f5 by Wen Chen <Wen.Chen@amd.com>:

[ROCM] Code refactoring of initial hipBlasLt fp8 Matmul support

 - Clean up unnecessary code, particularly regarding output types of fp8

 - Override methods in ParameterizedFp8GemmRewriteTest to replace the
   patterns for CUDA and ROCm respectively for HLO checks.

 - Explicitly set c_scale and d_scale to nullptr as hipblasLt currently
   does not support them.

Merging this change closes #9531

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#9531 from ROCm:ci_fp8_gemm_support d8d45593ec4e40f391421d3a44890c6ae2d6f2f5
PiperOrigin-RevId: 610480934
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Mar 8, 2024
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Mar 8, 2024
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Mar 11, 2024
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Mar 11, 2024
jayfurmanek pushed a commit to ROCm/tensorflow-upstream that referenced this pull request Mar 11, 2024
copybara-service bot pushed a commit that referenced this pull request Mar 13, 2024
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Mar 13, 2024
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Mar 13, 2024
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Mar 13, 2024
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Mar 13, 2024
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Mar 13, 2024
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Mar 13, 2024
copybara-service bot pushed a commit that referenced this pull request Mar 14, 2024
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Mar 14, 2024
copybara-service bot pushed a commit that referenced this pull request Mar 14, 2024
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Mar 14, 2024
copybara-service bot pushed a commit that referenced this pull request Mar 14, 2024
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Mar 14, 2024