PR #9531: Fp8 matmul support on AMD MI300
Imported from GitHub PR #9531

This PR enables support for fp8 matmul (with the fp8e4m3fnuz and fp8e5m2fnuz types) on MI300.
It is based on this previous PR, which fixes a build break on ROCm and is still open: #9367

Some of the GEMM pattern rewrites were changed to accommodate current limitations in hipblasLt (a sketch of the rewrite follows this list):
- hipblasLt currently does not support fp8 matmul with fp8 output. Such a matmul is therefore rewritten as a custom call to an fp8 matmul with fp32 output, and HLO instructions are added to convert the fp32 output to fp8.
- hipblasLt currently does not support fp8 matmul with bf16 output. Such a matmul is likewise rewritten as a custom call to an fp8 matmul with fp32 output, and HLO instructions are added to convert the fp32 output to bf16.
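
A minimal sketch of that rewrite, mirroring the TurnF8DotWithUnsupportedOutputTypeIntoF32 helper this change adds to gemm_rewriter.cc (simplified; it assumes the DfsHloRewriteVisitor context of the rewriter, and the function name here is illustrative only):

// Clone the f8 dot so that it produces f32, then convert the f32 result
// back to the originally requested output type (fp8 or bf16).
absl::StatusOr<HloInstruction *> RewriteF8DotToF32Output(HloInstruction *dot) {
  Shape f32_shape = dot->shape();
  f32_shape.set_element_type(F32);
  HloInstruction *f32_dot =
      dot->AddInstruction(dot->CloneWithNewShape(f32_shape));
  HloInstruction *convert = dot->AddInstruction(
      HloInstruction::CreateConvert(dot->shape(), f32_dot));
  TF_RETURN_IF_ERROR(ReplaceInstruction(dot, convert));
  return f32_dot;
}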

Copybara import of the project:

--
942f93f by Wen Chen <Wen.Chen@amd.com>:

[ROCM] Initial support of fp8 Matmul via hipBlasLt.

--
d8d4559 by Wen Chen <Wen.Chen@amd.com>:

[ROCM] Code refactoring of initial hipBlasLt fp8 Matmul support

 - Clean up unnecessary code, particularly around fp8 output types.

 - Override methods in ParameterizedFp8GemmRewriteTest so that the HLO
   checks use the appropriate patterns for CUDA and ROCm respectively.

 - Explicitly set c_scale and d_scale to nullptr, as hipblasLt currently
   does not support them (see the sketch after this list).
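
An illustrative sketch of that last point, in the spirit of the ROCm branch this change adds to ir_emitter_unnested.cc (abridged; the surrounding function and helpers are those of EmitCublasLtMatmulThunkF8 as shown in the diff below):

#if GOOGLE_CUDA
  // CUDA: c_scale and d_scale come from the FP8 custom call's operands.
  TF_ASSIGN_OR_RETURN(
      BufferAllocation::Slice c_scale,
      GetAllocationSliceForHlo(instr->operand(a_scale_index + 2)));
  TF_ASSIGN_OR_RETURN(
      BufferAllocation::Slice d_scale,
      GetAllocationSliceForHlo(instr->operand(a_scale_index + 3)));
#else  // TENSORFLOW_USE_ROCM
  // ROCm: hipblasLt does not support c_scale/d_scale yet, so empty
  // (default-constructed, i.e. null) buffer slices are passed instead.
  BufferAllocation::Slice c_scale;
  BufferAllocation::Slice d_scale;
#endif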

Merging this change closes #9531

FUTURE_COPYBARA_INTEGRATE_REVIEW=#9531 from ROCm:ci_fp8_gemm_support a4423f9
PiperOrigin-RevId: 615197463
wenchenvincent authored and copybara-github committed Mar 14, 2024
1 parent 85e5198 commit 6abe18a
Showing 14 changed files with 993 additions and 439 deletions.
112 changes: 103 additions & 9 deletions xla/service/gpu/gemm_rewriter.cc
@@ -228,7 +228,6 @@ bool IsSupportedF8Pattern(
}

std::reverse(subgraph.begin(), subgraph.end());

// When not operating directly on an FP8 operand, the second and
// third instructions in the subgraph must describe a dequantization, i.e. a
// convert instruction followed by a multiply/divide instruction.
@@ -569,6 +568,13 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
const_cast<HloInstruction *>(instr), b, b_scale,
b_mult_scale, b_ops);
})))) {
#if TENSORFLOW_USE_ROCM
if (instr->shape().element_type() != F16 &&
instr->shape().element_type() != F32) {
TF_ASSIGN_OR_RETURN(instr,
TurnF8DotWithUnsupportedOutputTypeIntoF32(instr));
}
#endif // TENSORFLOW_USE_ROCM
TF_ASSIGN_OR_RETURN(bool created_call,
CreateF8CustomCall(instr, gpu_backend_config, a, b,
a_scale, b_scale, a_mult_scale,
@@ -875,9 +881,9 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
HloInstruction *b_scale, bool a_mult_scale, bool b_mult_scale,
std::vector<std::pair<HloInstruction *, int>> a_ops,
std::vector<std::pair<HloInstruction *, int>> b_ops) {
#if GOOGLE_CUDA
GemmBackendConfig &gemm_backend_config =
*gpu_backend_config.mutable_gemm_backend_config();
#if GOOGLE_CUDA
auto cuda_compute_capability_ =
std::get<se::CudaComputeCapability>(gpu_version_);
// FP8 GEMM kernels are only available on Ada, Hopper, and later
@@ -887,17 +893,36 @@
<< "FP8 Custom Calls require Ada, Hopper, or later architectures.";
return false;
}

#if CUDA_VERSION < 12000
// FP8 GEMM kernels are only available with CUDA 12.0 and above
VLOG(1) << "FP8 Custom Calls require CUDA 12.0 or newer.";
return false;
#endif // CUDA_VERSION < 12000

#endif // GOOGLE_CUDA

#if TENSORFLOW_USE_ROCM
auto isrocm = std::get_if<se::RocmComputeCapability>(&gpu_version_);
if (!isrocm->has_fp8_support()) {
VLOG(1) << "FP8 Custom Calls require MI300, or later architectures.";
return false;
}

#if TF_ROCM_VERSION < 60000
// FP8 GEMM kernels are only available with ROCm 6.0 and above
VLOG(1) << "FP8 Custom Calls require ROCm 6.0 or newer.";
return false;
#endif // TF_ROCM_VERSION < 60000

#endif // TENSORFLOW_USE_ROCM

PrimitiveType a_type = a->shape().element_type();
PrimitiveType b_type = b->shape().element_type();

// cuBLASLt FP8 GEMM kernels require one of the two operands to be in
// F8E4M3FN format.
#if GOOGLE_CUDA
if (a_type == F8E5M2 && b_type == F8E5M2) {
VLOG(1)
<< "Failed to rewrite " << instr->ToShortString()
@@ -914,6 +939,26 @@
<< PrimitiveType_Name(b_type);
return false;
}
#endif // GOOGLE_CUDA

#if TENSORFLOW_USE_ROCM
if (a_type == F8E5M2FNUZ && b_type == F8E5M2FNUZ) {
VLOG(1)
<< "Failed to rewrite " << instr->ToShortString()
<< " into FP8 Custom Call. The element type of one of the operands "
"must be F8E4M3FNUZ.";
return false;
}
if ((a_type != F8E5M2FNUZ && a_type != F8E4M3FNUZ) ||
(b_type != F8E5M2FNUZ && b_type != F8E4M3FNUZ)) {
VLOG(1) << "Failed to rewrite " << instr->ToShortString()
<< " into FP8 Custom Call. The input types must be F8E5M2FNUZ or "
"F8E4M3FNUZ, but got "
<< PrimitiveType_Name(a_type) << " and "
<< PrimitiveType_Name(b_type);
return false;
}
#endif // TENSORFLOW_USE_ROCM

absl::Span<const int64_t> batch_dims =
gemm_backend_config.dot_dimension_numbers().rhs_batch_dimensions();
@@ -956,6 +1001,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
case F32:
break;
default:

VLOG(1) << "Failed to rewrite " << instr->ToShortString()
<< " into FP8 Custom Call. Output element type must be "
"F8E4M3FN, F8E5M2, BF16, F16 or F32. Actual element type is "
@@ -1104,9 +1150,6 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
ReplaceInstruction(instr, slice ? slice : new_custom_call));
VLOG(1) << instr->ToString() << " rewritten into FP8 Custom Call.";
return true;
#else // TENSORFLOW_USE_ROCM
return false;
#endif
}

absl::Status F8ConvertD(HloInstruction *instr, HloInstruction *existing_gemm,
@@ -1741,10 +1784,12 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
const PrimitiveType output_type =
bias ? bias->shape().element_type() : instr.shape().element_type();
const std::array<PrimitiveType, 12> supported_type = {
PrimitiveType::F8E5M2, PrimitiveType::F8E4M3FN, PrimitiveType::S8,
PrimitiveType::F16, PrimitiveType::BF16, PrimitiveType::F32,
PrimitiveType::S32, PrimitiveType::F64, PrimitiveType::C64,
PrimitiveType::C128};
PrimitiveType::F8E5M2FNUZ, PrimitiveType::F8E4M3FNUZ,
PrimitiveType::F8E5M2, PrimitiveType::F8E4M3FN,
PrimitiveType::S8, PrimitiveType::F16,
PrimitiveType::BF16, PrimitiveType::F32,
PrimitiveType::S32, PrimitiveType::F64,
PrimitiveType::C64, PrimitiveType::C128};
if (!absl::c_linear_search(supported_type, output_type)) return false;
// cublasLt has a defined set of combinations of types that it supports.
// Figure out the computeType and scaleType.
@@ -1807,6 +1852,39 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
{ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2,
PrimitiveType::F8E4M3FN, DataType::kFloat},
#endif // GOOGLE_CUDA
#if TENSORFLOW_USE_ROCM
// FP8 types:
{ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
PrimitiveType::F8E4M3FNUZ, DataType::kBF16},
{ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
PrimitiveType::F8E4M3FNUZ, DataType::kF8E4M3FNUZ},
{ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
PrimitiveType::F8E4M3FNUZ, DataType::kHalf},
{ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
PrimitiveType::F8E4M3FNUZ, DataType::kFloat},

{ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
PrimitiveType::F8E5M2FNUZ, DataType::kBF16},
{ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
PrimitiveType::F8E5M2FNUZ, DataType::kF8E4M3FNUZ},
{ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
PrimitiveType::F8E5M2FNUZ, DataType::kF8E5M2FNUZ},
{ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
PrimitiveType::F8E5M2FNUZ, DataType::kHalf},
{ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
PrimitiveType::F8E5M2FNUZ, DataType::kFloat},

{ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2FNUZ,
PrimitiveType::F8E4M3FNUZ, DataType::kBF16},
{ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2FNUZ,
PrimitiveType::F8E4M3FNUZ, DataType::kF8E4M3FNUZ},
{ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2FNUZ,
PrimitiveType::F8E4M3FNUZ, DataType::kF8E5M2FNUZ},
{ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2FNUZ,
PrimitiveType::F8E4M3FNUZ, DataType::kHalf},
{ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2FNUZ,
PrimitiveType::F8E4M3FNUZ, DataType::kFloat},
#endif // TENSORFLOW_USE_ROCM
// Other data types:
{ComputationType::kF16, DataType::kHalf, PrimitiveType::F16,
PrimitiveType::F16, DataType::kHalf},
@@ -1993,6 +2071,22 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
return lhs_non_contracting_dimension_size <= kMaxDimensionSize;
}

#if TENSORFLOW_USE_ROCM
// Turns an F8 dot with an unsupported output type into an F8 dot with F32
// output, and converts the F32 output back to the original output type.
absl::StatusOr<HloInstruction *> TurnF8DotWithUnsupportedOutputTypeIntoF32(
HloInstruction *instr) {
Shape output_f32_shape = instr->shape();
output_f32_shape.set_element_type(F32);
HloInstruction *f32_dot =
instr->AddInstruction(instr->CloneWithNewShape(output_f32_shape));
HloInstruction *convert = instr->AddInstruction(
HloInstruction::CreateConvert(instr->shape(), f32_dot));
TF_RETURN_IF_ERROR(ReplaceInstruction(instr, convert));
return f32_dot;
}
#endif // TENSORFLOW_USE_ROCM

// Turns an F8 dot into an F16 dot, converting operands to F16 and
// converting the output back to F8.
absl::StatusOr<HloInstruction *> TurnF8DotIntoF16Dot(HloInstruction *instr) {
4 changes: 2 additions & 2 deletions xla/service/gpu/gpu_compiler.cc
@@ -1333,8 +1333,8 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment(
const GpuFloatSupport f8e5m2_support(gpu_version, F8E5M2, F16);
const GpuFloatSupport f8e4m3fn_support(gpu_version, F8E4M3FN, F16);
const FloatSupport f8e4m3b11fnuz_support(F8E4M3B11FNUZ, F16);
const FloatSupport f8e5m2fnuz_support(F8E5M2FNUZ, F16);
const FloatSupport f8e4m3fnuz_support(F8E4M3FNUZ, F16);
const GpuFloatSupport f8e5m2fnuz_support(gpu_version, F8E5M2FNUZ, F16);
const GpuFloatSupport f8e4m3fnuz_support(gpu_version, F8E4M3FNUZ, F16);
auto add_float_normalization = [&](HloPassPipeline& pipeline) {
auto& sub_pipeline =
pipeline.AddPass<HloPassPipeline>("float_normalization");
8 changes: 5 additions & 3 deletions xla/service/gpu/ir_emission_utils.cc
@@ -123,9 +123,11 @@ bool IsMatrixMultiplication(const HloInstruction& dot) {
PrimitiveType output_primitive_type = dot.shape().element_type();
bool type_is_allowed =
(output_primitive_type == F8E4M3FN || output_primitive_type == F8E5M2 ||
output_primitive_type == F16 || output_primitive_type == BF16 ||
output_primitive_type == F32 || output_primitive_type == F64 ||
output_primitive_type == C64 || output_primitive_type == C128) ||
output_primitive_type == F8E4M3FNUZ ||
output_primitive_type == F8E5M2FNUZ || output_primitive_type == F16 ||
output_primitive_type == BF16 || output_primitive_type == F32 ||
output_primitive_type == F64 || output_primitive_type == C64 ||
output_primitive_type == C128) ||
(output_primitive_type == S32 && lhs_shape.element_type() == S8 &&
rhs_shape.element_type() == S8);
bool shapes_are_valid =
15 changes: 9 additions & 6 deletions xla/service/gpu/ir_emitter_unnested.cc
@@ -732,10 +732,6 @@ absl::Status IrEmitterUnnested::EmitCublasLtMatmulThunk(
return absl::OkStatus();
}

#endif // GOOGLE_CUDA || TF_HIPBLASLT

#if GOOGLE_CUDA

absl::Status IrEmitterUnnested::EmitCublasLtMatmulThunkF8(
const HloCustomCallInstruction* instr) {
TF_RET_CHECK(instr->operand_count() == 6 || instr->operand_count() == 7 ||
@@ -771,12 +767,17 @@ absl::Status IrEmitterUnnested::EmitCublasLtMatmulThunkF8(
TF_ASSIGN_OR_RETURN(
BufferAllocation::Slice b_scale,
GetAllocationSliceForHlo(instr->operand(a_scale_index + 1)));
#if GOOGLE_CUDA
TF_ASSIGN_OR_RETURN(
BufferAllocation::Slice c_scale,
GetAllocationSliceForHlo(instr->operand(a_scale_index + 2)));
TF_ASSIGN_OR_RETURN(
BufferAllocation::Slice d_scale,
GetAllocationSliceForHlo(instr->operand(a_scale_index + 3)));
#else // TENSORFLOW_USE_ROCM
BufferAllocation::Slice c_scale;
BufferAllocation::Slice d_scale;
#endif

BufferAllocation::Slice bias;
if (has_vector_bias) {
@@ -810,7 +811,9 @@ absl::Status IrEmitterUnnested::EmitCublasLtMatmulThunkF8(
AddThunkToThunkSequence(std::move(thunk));
return absl::OkStatus();
}
#endif // GOOGLE_CUDA || TF_HIPBLASLT

#if GOOGLE_CUDA
absl::Status IrEmitterUnnested::EmitConvolutionReorderThunk(
const HloCustomCallInstruction* instr) {
bool has_bias = instr->operand_count() > 1;
@@ -2902,11 +2905,11 @@ absl::Status IrEmitterUnnested::EmitHloInstruction(
if (IsCublasLtMatmul(*instr)) {
return EmitCublasLtMatmulThunk(custom_call);
}
#endif // GOOGLE_CUDA || TF_HIPBLASLT
#if GOOGLE_CUDA
if (IsCublasLtMatmulF8(*instr)) {
return EmitCublasLtMatmulThunkF8(custom_call);
}
#endif // GOOGLE_CUDA || TF_HIPBLASLT
#if GOOGLE_CUDA
if (IsCudnnConvolutionReorder(*instr)) {
return EmitConvolutionReorderThunk(custom_call);
}
2 changes: 1 addition & 1 deletion xla/service/gpu/ir_emitter_unnested.h
@@ -142,9 +142,9 @@ class IrEmitterUnnested : public IrEmitter {
absl::Status EmitGemmThunk(const HloCustomCallInstruction* instr);
#if GOOGLE_CUDA || TF_HIPBLASLT
absl::Status EmitCublasLtMatmulThunk(const HloCustomCallInstruction* instr);
absl::Status EmitCublasLtMatmulThunkF8(const HloCustomCallInstruction* instr);
#endif // GOOGLE_CUDA || TF_HIPBLASLT
#if GOOGLE_CUDA
absl::Status EmitCublasLtMatmulThunkF8(const HloCustomCallInstruction* instr);
absl::Status EmitConvolutionReorderThunk(
const HloCustomCallInstruction* instr);
absl::Status EmitNormThunk(const HloCustomCallInstruction* instr);
2 changes: 2 additions & 0 deletions xla/service/gpu/matmul_utils.cc
@@ -390,6 +390,8 @@ absl::StatusOr<bool> CanFoldTransposeOperandIntoDot(const HloInstruction& dot,
switch (output_shape.element_type()) {
case F8E4M3FN:
case F8E5M2:
case F8E4M3FNUZ:
case F8E5M2FNUZ:
case F16:
case BF16:
case F32: