PR #16841: Delete FP8 Scaling Factors in GEMM Rewriter #17731

Merged · 1 commit · Sep 28, 2024
xla/service/gpu/ir_emitter_unnested.cc (23 changes: 11 additions & 12 deletions)

@@ -740,8 +740,7 @@ absl::Status IrEmitterUnnested::EmitCublasLtMatmulThunk(
 absl::Status IrEmitterUnnested::EmitCublasLtMatmulThunkF8(
     const HloCustomCallInstruction* instr) {
-  TF_RET_CHECK(instr->operand_count() == 6 || instr->operand_count() == 7 ||
-               instr->operand_count() == 8);
+  TF_RET_CHECK(instr->operand_count() > 3 && instr->operand_count() < 8);
   TF_ASSIGN_OR_RETURN(const auto gpu_config,
                       instr->backend_config<xla::gpu::GpuBackendConfig>());
   const xla::gpu::GemmBackendConfig& config = gpu_config.gemm_backend_config();

@@ -777,22 +776,22 @@ absl::Status IrEmitterUnnested::EmitCublasLtMatmulThunkF8(
   TF_ASSIGN_OR_RETURN(
       BufferAllocation::Slice b_scale,
       GetAllocationSliceForHlo(instr->operand(a_scale_index + 1)));
+
+  // cublasLT requires c_scale/d_scale to be null when C/D is not FP8.
+  // Currently, C cannot be FP8.
+  BufferAllocation::Slice c_scale, d_scale;
 #if GOOGLE_CUDA
-  TF_ASSIGN_OR_RETURN(
-      BufferAllocation::Slice c_scale,
-      GetAllocationSliceForHlo(instr->operand(a_scale_index + 2)));
-  TF_ASSIGN_OR_RETURN(
-      BufferAllocation::Slice d_scale,
-      GetAllocationSliceForHlo(instr->operand(a_scale_index + 3)));
-#else  // TENSORFLOW_USE_ROCM
-  BufferAllocation::Slice c_scale;
-  BufferAllocation::Slice d_scale;
+  if (instr->shape().tuple_shapes(0).element_type() == F8E4M3FN ||
+      instr->shape().tuple_shapes(0).element_type() == F8E5M2) {
+    TF_ASSIGN_OR_RETURN(d_scale,
+                        GetAllocationSliceForHlo(instr->operands().back()));
+  }
 #endif

   BufferAllocation::Slice bias;
   if (has_vector_bias) {
     TF_ASSIGN_OR_RETURN(
-        bias, GetAllocationSliceForHlo(instr->operand(a_scale_index + 4)));
+        bias, GetAllocationSliceForHlo(instr->operand(a_scale_index + 2)));
   }

   BufferAllocation::Slice d_amax;
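A note on why the emitter's operand-count check loosens from {6, 7, 8} to the range 4..7: the two unconditionally passed scaling factors (c_scale, d_scale) are gone, and d_scale reappears only as the final operand when D itself is FP8, which is why it is now read via instr->operands().back(). Below is a minimal sketch of the resulting layout; the struct and member names are invented for illustration and are not the XLA API:

// Sketch of the operand layout EmitCublasLtMatmulThunkF8 now expects:
// A, B, [C], a_scale, b_scale, [vector_bias], [d_scale].
struct F8MatmulOperandLayout {
  bool has_matrix_bias;  // beta != 0.0: C is passed at index 2
  bool has_vector_bias;  // epilogue carries a bias vector
  bool d_is_fp8;         // output D needs a d_scale

  int a_scale_index() const { return has_matrix_bias ? 3 : 2; }
  int vector_bias_index() const { return a_scale_index() + 2; }

  int operand_count() const {
    int n = 2 + (has_matrix_bias ? 1 : 0);  // A, B, and optionally C
    n += 2;                                 // a_scale and b_scale
    n += has_vector_bias ? 1 : 0;           // vector bias, when present
    n += d_is_fp8 ? 1 : 0;                  // d_scale is always last
    return n;                               // ranges over 4..7, matching TF_RET_CHECK
  }
};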
xla/service/gpu/matmul_utils.cc (4 changes: 2 additions & 2 deletions)

@@ -484,8 +484,8 @@ bool IsTf32Allowed(PrecisionConfig::Algorithm algorithm,
   if (has_vector_bias) {
     int vector_bias_index = has_matrix_bias ? 3 : 2;
     if (primitive_util::IsF8Type(lhs_shape.element_type())) {
-      // FP8 gemms have 4 scales as inputs which come before the vector bias.
-      vector_bias_index += 4;
+      // FP8 gemms have 2 scales as inputs which come before the vector bias.
+      vector_bias_index += 2;
     }
     vector_bias_shape = gemm->operand(vector_bias_index)->shape();
   }
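A quick worked check of the new indexing, using a hypothetical operand list rather than code from this PR:

#include <cassert>
#include <string>
#include <vector>

int main() {
  // FP8 GEMM with a matrix bias C and a vector bias, under the new layout
  // where only two scales (a_scale, b_scale) precede the vector bias.
  std::vector<std::string> operands = {"a",       "b",       "c",
                                       "a_scale", "b_scale", "vector_bias"};
  bool has_matrix_bias = true;
  int vector_bias_index = has_matrix_bias ? 3 : 2;
  vector_bias_index += 2;  // was += 4 when c_scale/d_scale were still passed
  assert(operands[vector_bias_index] == "vector_bias");  // index 5
  return 0;
}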
xla/service/gpu/transforms/gemm_rewriter.cc (27 changes: 18 additions & 9 deletions)

@@ -1083,12 +1083,18 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {

     // cuBLASLt FP8 GEMM kernels require the scaling factors to be in F32
     // format. Set the factors to one when no scaling factors were captured.
-    Literal one_literal = LiteralUtil::One(F32);
-    HloInstruction *one = instr->AddInstruction(
-        HloInstruction::CreateConstant(one_literal.Clone()));
     std::array<bool, 2> mult_scale{a.mult_scale, b.mult_scale};
     std::array<HloInstruction *, 2> scales{a.scale, b.scale}, inv_scales,
         scales_f32;
+    HloInstruction *one_constant = nullptr;
+    auto one = [&one_constant, instr]() -> HloInstruction * {
+      if (!one_constant) {
+        one_constant = instr->AddInstruction(
+            HloInstruction::CreateConstant(LiteralUtil::One(F32)));
+      }
+      return one_constant;
+    };
+
     for (int i = 0; i < scales.size(); ++i) {
       if (scales[i]) {
         if (!ShapeUtil::IsScalar(scales[i]->shape())) {
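The hunk above replaces an eagerly created 1.0 constant with a memoizing lambda, so the constant is materialized at most once and only if some scale actually needs it; the old code created it unconditionally, even when it ended up unused. The same pattern in isolation, as a standalone sketch with invented names rather than XLA code:

#include <cassert>

int main() {
  int storage = 0;        // stands in for the HLO graph's constant node
  int *cached = nullptr;  // plays the role of one_constant

  // Build the value on first call only; later calls reuse the cached result.
  auto one = [&cached, &storage]() -> int * {
    if (!cached) {
      storage = 1;  // stands in for CreateConstant(LiteralUtil::One(F32))
      cached = &storage;
    }
    return cached;
  };

  int *first = one();
  int *second = one();
  assert(*first == 1 && first == second);  // created once, shared afterwards
  return 0;
}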
@@ -1099,15 +1105,15 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
         }
         if (!mult_scale[i]) {
           inv_scales[i] = instr->AddInstruction(HloInstruction::CreateBinary(
-              scales[i]->shape(), HloOpcode::kDivide, one, scales[i]));
+              scales[i]->shape(), HloOpcode::kDivide, one(), scales[i]));
         }
         scales_f32[i] = mult_scale[i] ? scales[i] : inv_scales[i];
         if (scales_f32[i]->shape().element_type() != F32) {
           scales_f32[i] = instr->AddInstruction(HloInstruction::CreateConvert(
               ShapeUtil::MakeScalarShape(F32), scales_f32[i]));
         }
       } else {
-        scales_f32[i] = one;
+        scales_f32[i] = one();
       }
     }

@@ -1249,7 +1255,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
         PadShapeToMultipleOf16(instr->shape(), out_batch_dims);

     std::vector<HloInstruction *> operands_list = {
-        a.fp8_input, b.fp8_input, scales_f32[0], scales_f32[1], one, one};
+        a.fp8_input, b.fp8_input, scales_f32[0], scales_f32[1]};

     HloInstruction *new_custom_call =
         instr->AddInstruction(HloInstruction::CreateCustomCall(

@@ -1415,13 +1421,16 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
       }
     }

-    // If necessary, invert the scaling factor of D and convert to F32.
+    // If necessary, invert the scaling factor of D and convert to F32. When no
+    // scaling factor was captured, set the factor to one.
     if (d_scale) {
       TF_ASSIGN_OR_RETURN(d_scale,
                           InvertAndConvertScalar(d_scale, !mult_scale));
-      TF_RETURN_IF_ERROR(existing_gemm->ReplaceOperandWith(
-          gemm_backend_config.beta() == 0.0 ? 5 : 6, d_scale));
+    } else {
+      d_scale = instr->AddInstruction(
+          HloInstruction::CreateConstant(LiteralUtil::One(F32)));
     }
+    existing_gemm->AppendOperand(d_scale);

     // If present, elide the calculation of the maximum of the absolute values
     // of the result of the GEMM.
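The final hunk swaps position-based patching for appending: d_scale previously sat at a fixed index (5, or 6 when a matrix bias shifts the layout) and was overwritten in place, whereas now it is appended as the last operand after being defaulted to 1.0 when no factor was captured. A small sketch of the index arithmetic under both conventions, with invented operand names:

#include <cstdio>
#include <string>
#include <vector>

int main() {
  bool has_matrix_bias = false;  // beta != 0.0 would insert "c" at index 2

  // Old convention: c_scale/d_scale placeholders were always present.
  std::vector<std::string> old_ops = {"a",       "b",       "a_scale",
                                      "b_scale", "c_scale", "d_scale"};
  int old_index = has_matrix_bias ? 6 : 5;  // the ReplaceOperandWith target
  std::printf("old: d_scale at %d (%s)\n", old_index,
              old_ops[old_index].c_str());

  // New convention: append d_scale last, only for FP8-output GEMMs.
  std::vector<std::string> new_ops = {"a", "b", "a_scale", "b_scale"};
  new_ops.push_back("d_scale");  // mirrors existing_gemm->AppendOperand(d_scale)
  std::printf("new: d_scale at %zu\n", new_ops.size() - 1);
  return 0;
}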