Skip to content

Commit

Permalink
[ROCM] Initial support of fp8 Matmul via hipBlasLt.
Browse files Browse the repository at this point in the history
  • Loading branch information
wenchenvincent committed Feb 16, 2024
1 parent 2856cb2 commit cc0b48f
Show file tree
Hide file tree
Showing 12 changed files with 1,428 additions and 521 deletions.
122 changes: 117 additions & 5 deletions xla/service/gpu/gemm_rewriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,6 @@ bool IsSupportedF8Pattern(
}

std::reverse(subgraph.begin(), subgraph.end());

// When not operating directly on an FP8 operand, the second and
// third instructions in the subgraph must describe a dequantization, i.e. a
// convert instruction followed by a multiply/divide instruction.
Expand Down Expand Up @@ -533,6 +532,13 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
const_cast<HloInstruction *>(instr), b, b_scale,
b_mult_scale, b_ops);
})))) {
#if TENSORFLOW_USE_ROCM
if (instr->shape().element_type() == BF16 ||
instr->shape().element_type() == F8E4M3FNUZ ||
instr->shape().element_type() == F8E5M2FNUZ) {
TF_ASSIGN_OR_RETURN(instr, TurnF8DotWithUnsupportedOutputTypeIntoF32(instr));
}
#endif // TENSORFLOW_USE_ROCM
TF_ASSIGN_OR_RETURN(
bool created_call,
CreateF8CustomCall(instr, gpu_backend_config, a, b, a_scale, b_scale,
Expand Down Expand Up @@ -839,9 +845,9 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
HloInstruction *b_scale, bool a_mult_scale, bool b_mult_scale,
std::vector<std::pair<HloInstruction *, int>> a_ops,
std::vector<std::pair<HloInstruction *, int>> b_ops) {
#if GOOGLE_CUDA
GemmBackendConfig &gemm_backend_config =
*gpu_backend_config.mutable_gemm_backend_config();
#if GOOGLE_CUDA
auto cuda_compute_capability_ =
std::get<se::CudaComputeCapability>(gpu_version_);
// FP8 GEMM kernels are only available on Ada, Hopper, and later
Expand All @@ -851,17 +857,38 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
<< "FP8 Custom Calls require Ada, Hopper, or later architectures.";
return false;
}

#if CUDA_VERSION < 12000
// FP8 GEMM kernels are only available with CUDA 12.0 and above
VLOG(1) << "FP8 Custom Calls require CUDA 12.0 or newer.";
return false;
#endif // CUDA_VERSION < 12000

#endif // GOOGLE_CUDA


#if TENSORFLOW_USE_ROCM
auto isrocm = std::get_if<se::RocmComputeCapability>(&gpu_version_);
if (!isrocm->has_fp8_support()) {
VLOG(1)
<< "FP8 Custom Calls require MI300, or later architectures.";
return false;
}

#if TF_ROCM_VERSION < 60000
// FP8 GEMM kernels are only available with ROCm 6.0 and above
VLOG(1) << "FP8 Custom Calls require ROCm 6.0 or newer.";
return false;
#endif //TF_ROCM_VERSION < 60000

#endif // TENSORFLOW_USE_ROCM

PrimitiveType a_type = a->shape().element_type();
PrimitiveType b_type = b->shape().element_type();

// cuBLASLt FP8 GEMM kernels require one of the two operands to be in
// F8E4M3FN format.
#if GOOGLE_CUDA
if (a_type == F8E5M2 && b_type == F8E5M2) {
VLOG(1)
<< "Failed to rewrite " << instr->ToShortString()
Expand All @@ -878,6 +905,26 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
<< PrimitiveType_Name(b_type);
return false;
}
#endif // GOOGLE_CUDA

#if TENSORFLOW_USE_ROCM
if (a_type == F8E5M2FNUZ && b_type == F8E5M2FNUZ) {
VLOG(1)
<< "Failed to rewrite " << instr->ToShortString()
<< " into FP8 Custom Call. The element type of one of the operands "
"must be F8E4M3FNUZ.";
return false;
}
if ((a_type != F8E5M2FNUZ && a_type != F8E4M3FNUZ) ||
(b_type != F8E5M2FNUZ && b_type != F8E4M3FNUZ)) {
VLOG(1) << "Failed to rewrite " << instr->ToShortString()
<< " into FP8 Custom Call. The input types must be F8E5M2FNUZ or "
"F8E4M3FNUZ, but got "
<< PrimitiveType_Name(a_type) << " and "
<< PrimitiveType_Name(b_type);
return false;
}
#endif // TENSORFLOW_USE_ROCM

absl::Span<const int64_t> batch_dims =
gemm_backend_config.dot_dimension_numbers().rhs_batch_dimensions();
Expand Down Expand Up @@ -915,11 +962,14 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
switch (instr->shape().element_type()) {
case F8E4M3FN:
case F8E5M2:
case F8E4M3FNUZ:
case F8E5M2FNUZ:
case BF16:
case F16:
case F32:
break;
default:

VLOG(1) << "Failed to rewrite " << instr->ToShortString()
<< " into FP8 Custom Call. Output element type must be "
"F8E4M3FN, F8E5M2, BF16, F16 or F32. Actual element type is "
Expand Down Expand Up @@ -1068,9 +1118,6 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
ReplaceInstruction(instr, slice ? slice : new_custom_call));
VLOG(1) << instr->ToString() << " rewritten into FP8 Custom Call.";
return true;
#else // TENSORFLOW_USE_ROCM
return false;
#endif
}

absl::Status F8ConvertD(HloInstruction *instr, HloInstruction *existing_gemm,
Expand Down Expand Up @@ -1701,11 +1748,20 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
const PrimitiveType b_dtype = instr.operand(1)->shape().element_type();
const PrimitiveType output_type =
bias ? bias->shape().element_type() : instr.shape().element_type();
#if GOOGLE_CUDA
const std::array<PrimitiveType, 12> supported_type = {
PrimitiveType::F8E5M2, PrimitiveType::F8E4M3FN, PrimitiveType::S8,
PrimitiveType::F16, PrimitiveType::BF16, PrimitiveType::F32,
PrimitiveType::S32, PrimitiveType::F64, PrimitiveType::C64,
PrimitiveType::C128};
#endif // GOOGLE_CUDA
#if TENSORFLOW_USE_ROCM
const std::array<PrimitiveType, 12> supported_type = {
PrimitiveType::F8E5M2FNUZ, PrimitiveType::F8E4M3FNUZ, PrimitiveType::S8,
PrimitiveType::F16, PrimitiveType::BF16, PrimitiveType::F32,
PrimitiveType::S32, PrimitiveType::F64, PrimitiveType::C64,
PrimitiveType::C128};
#endif // TENSORFLOW_USE_ROCM
if (!absl::c_linear_search(supported_type, output_type)) return false;
// cublasLt has a defined set of combinations of types that it supports.
// Figure out the computeType and scaleType.
Expand Down Expand Up @@ -1764,6 +1820,45 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
{ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2,
PrimitiveType::F8E4M3FN, DataType::kFloat},
#endif // GOOGLE_CUDA
#if TENSORFLOW_USE_ROCM
// FP8 types:
/*
{ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
PrimitiveType::F8E4M3FNUZ, DataType::kBF16},
*/
{ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
PrimitiveType::F8E4M3FNUZ, DataType::kF8E4M3FNUZ},
{ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
PrimitiveType::F8E4M3FNUZ, DataType::kHalf},
{ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
PrimitiveType::F8E4M3FNUZ, DataType::kFloat},

/*
{ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
PrimitiveType::F8E5M2, DataType::kBF16},
*/
{ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
PrimitiveType::F8E5M2, DataType::kF8E4M3FNUZ},
{ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
PrimitiveType::F8E5M2, DataType::kF8E5M2},
{ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
PrimitiveType::F8E5M2, DataType::kHalf},
{ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
PrimitiveType::F8E5M2, DataType::kFloat},

/*
{ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2,
PrimitiveType::F8E4M3FNUZ, DataType::kBF16},
*/
{ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2,
PrimitiveType::F8E4M3FNUZ, DataType::kF8E4M3FNUZ},
{ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2,
PrimitiveType::F8E4M3FNUZ, DataType::kF8E5M2},
{ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2,
PrimitiveType::F8E4M3FNUZ, DataType::kHalf},
{ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2,
PrimitiveType::F8E4M3FNUZ, DataType::kFloat},
#endif // TENSORFLOW_USE_ROCM
// Other data types:
{ComputationType::kF16, DataType::kHalf, PrimitiveType::F16,
PrimitiveType::F16, DataType::kHalf},
Expand Down Expand Up @@ -1946,6 +2041,23 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
return lhs_non_contracting_dimension_size <= kMaxDimensionSize;
}


#if TENSORFLOW_USE_ROCM
// Turns an F8 dot with output type BF16 or F8 into an F8 dot with F32 output,
// and converting the F32 output to BF16 or F8.
absl::StatusOr<HloInstruction *> TurnF8DotWithUnsupportedOutputTypeIntoF32(
HloInstruction *instr) {
Shape output_f32_shape = instr->shape();
output_f32_shape.set_element_type(F32);
HloInstruction *f32_dot =
instr->AddInstruction(instr->CloneWithNewShape(output_f32_shape));
HloInstruction *convert = instr->AddInstruction(
HloInstruction::CreateConvert(instr->shape(), f32_dot));
TF_RETURN_IF_ERROR(ReplaceInstruction(instr, convert));
return f32_dot;
}
#endif // TENSORFLOW_USE_ROCM

// Turns an F8 dot into an F16 dot, converting operands to F16 and
// converting the output back to F8.
absl::StatusOr<HloInstruction *> TurnF8DotIntoF16Dot(HloInstruction *instr) {
Expand Down
4 changes: 2 additions & 2 deletions xla/service/gpu/gpu_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1236,8 +1236,8 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment(
const GpuFloatSupport f8e5m2_support(F8E5M2, F16);
const GpuFloatSupport f8e4m3fn_support(F8E4M3FN, F16);
const FloatSupport f8e4m3b11fnuz_support(F8E4M3B11FNUZ, F16);
const FloatSupport f8e5m2fnuz_support(F8E5M2FNUZ, F16);
const FloatSupport f8e4m3fnuz_support(F8E4M3FNUZ, F16);
const GpuFloatSupport f8e5m2fnuz_support(F8E5M2FNUZ, F16);
const GpuFloatSupport f8e4m3fnuz_support(F8E4M3FNUZ, F16);
auto add_float_normalization = [&](HloPassPipeline& pipeline) {
auto& sub_pipeline =
pipeline.AddPass<HloPassPipeline>("float_normalization");
Expand Down
1 change: 1 addition & 0 deletions xla/service/gpu/ir_emission_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ bool IsMatrixMultiplication(const HloInstruction& dot) {
PrimitiveType output_primitive_type = dot.shape().element_type();
bool type_is_allowed =
(output_primitive_type == F8E4M3FN || output_primitive_type == F8E5M2 ||
output_primitive_type == F8E4M3FNUZ || output_primitive_type == F8E5M2FNUZ ||
output_primitive_type == F16 || output_primitive_type == BF16 ||
output_primitive_type == F32 || output_primitive_type == F64 ||
output_primitive_type == C64 || output_primitive_type == C128) ||
Expand Down
Loading

0 comments on commit cc0b48f

Please sign in to comment.