diff --git a/xla/service/gpu/ir_emitter_unnested.cc b/xla/service/gpu/ir_emitter_unnested.cc
index b7da6271befc9..ea0577e05077f 100644
--- a/xla/service/gpu/ir_emitter_unnested.cc
+++ b/xla/service/gpu/ir_emitter_unnested.cc
@@ -740,8 +740,7 @@ absl::Status IrEmitterUnnested::EmitCublasLtMatmulThunk(
 
 absl::Status IrEmitterUnnested::EmitCublasLtMatmulThunkF8(
     const HloCustomCallInstruction* instr) {
-  TF_RET_CHECK(instr->operand_count() == 6 || instr->operand_count() == 7 ||
-               instr->operand_count() == 8);
+  TF_RET_CHECK(instr->operand_count() > 3 && instr->operand_count() < 8);
   TF_ASSIGN_OR_RETURN(const auto gpu_config,
                       instr->backend_config<xla::gpu::GpuBackendConfig>());
   const xla::gpu::GemmBackendConfig& config = gpu_config.gemm_backend_config();
@@ -777,22 +776,22 @@ absl::Status IrEmitterUnnested::EmitCublasLtMatmulThunkF8(
   TF_ASSIGN_OR_RETURN(
       BufferAllocation::Slice b_scale,
       GetAllocationSliceForHlo(instr->operand(a_scale_index + 1)));
+
+  // cublasLT requires c_scale/d_scale to be null when C/D is not FP8.
+  // Currently, C cannot be FP8.
+  BufferAllocation::Slice c_scale, d_scale;
 #if GOOGLE_CUDA
-  TF_ASSIGN_OR_RETURN(
-      BufferAllocation::Slice c_scale,
-      GetAllocationSliceForHlo(instr->operand(a_scale_index + 2)));
-  TF_ASSIGN_OR_RETURN(
-      BufferAllocation::Slice d_scale,
-      GetAllocationSliceForHlo(instr->operand(a_scale_index + 3)));
-#else  // TENSORFLOW_USE_ROCM
-  BufferAllocation::Slice c_scale;
-  BufferAllocation::Slice d_scale;
+  if (instr->shape().tuple_shapes(0).element_type() == F8E4M3FN ||
+      instr->shape().tuple_shapes(0).element_type() == F8E5M2) {
+    TF_ASSIGN_OR_RETURN(d_scale,
+                        GetAllocationSliceForHlo(instr->operands().back()));
+  }
 #endif
 
   BufferAllocation::Slice bias;
   if (has_vector_bias) {
     TF_ASSIGN_OR_RETURN(
-        bias, GetAllocationSliceForHlo(instr->operand(a_scale_index + 4)));
+        bias, GetAllocationSliceForHlo(instr->operand(a_scale_index + 2)));
   }
 
   BufferAllocation::Slice d_amax;
diff --git a/xla/service/gpu/matmul_utils.cc b/xla/service/gpu/matmul_utils.cc
index 49270de65ecd3..401832e2d17e0 100644
--- a/xla/service/gpu/matmul_utils.cc
+++ b/xla/service/gpu/matmul_utils.cc
@@ -484,8 +484,8 @@ bool IsTf32Allowed(PrecisionConfig::Algorithm algorithm,
   if (has_vector_bias) {
     int vector_bias_index = has_matrix_bias ? 3 : 2;
     if (primitive_util::IsF8Type(lhs_shape.element_type())) {
-      // FP8 gemms have 4 scales as inputs which come before the vector bias.
-      vector_bias_index += 4;
+      // FP8 gemms have 2 scales as inputs which come before the vector bias.
+      vector_bias_index += 2;
    }
     vector_bias_shape = gemm->operand(vector_bias_index)->shape();
   }
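Note on the new operand layout: before this change, every FP8 custom call carried four scale operands (a_scale, b_scale, c_scale, d_scale) ahead of the optional vector bias; after it, only a_scale and b_scale are unconditional, the vector bias moves up two slots, and d_scale is appended as the final operand only when the output is FP8. A minimal sketch of the resulting indexing follows; Fp8OperandIndices and ComputeFp8OperandIndices are illustrative names, not part of the XLA codebase.

  // A and B come first, then the optional matrix bias C, then the two
  // unconditional scales, then the optional vector bias; d_scale, when
  // present, is always the last operand.
  struct Fp8OperandIndices {
    int a_scale;
    int b_scale;
    int vector_bias;  // -1 when no vector bias is present
    int d_scale;      // -1 when D is not FP8
  };

  Fp8OperandIndices ComputeFp8OperandIndices(int operand_count,
                                             bool has_matrix_bias,
                                             bool has_vector_bias,
                                             bool d_is_fp8) {
    Fp8OperandIndices idx;
    idx.a_scale = has_matrix_bias ? 3 : 2;
    idx.b_scale = idx.a_scale + 1;
    idx.vector_bias = has_vector_bias ? idx.a_scale + 2 : -1;
    idx.d_scale = d_is_fp8 ? operand_count - 1 : -1;
    return idx;
  }

This is exactly what the two hunks above encode: the thunk emitter reads the vector bias from a_scale_index + 2 and the D scale from operands().back(), and matmul_utils.cc now skips two scale operands rather than four when locating the vector bias.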
diff --git a/xla/service/gpu/transforms/gemm_rewriter.cc b/xla/service/gpu/transforms/gemm_rewriter.cc
index dea1f704c5801..d7674efc15e94 100644
--- a/xla/service/gpu/transforms/gemm_rewriter.cc
+++ b/xla/service/gpu/transforms/gemm_rewriter.cc
@@ -1083,12 +1083,18 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
     // cuBLASLt FP8 GEMM kernels require the scaling factors to be in F32
     // format. Set the factors to one when no scaling factors were captured.
-    Literal one_literal = LiteralUtil::One(F32);
-    HloInstruction *one = instr->AddInstruction(
-        HloInstruction::CreateConstant(one_literal.Clone()));
     std::array<bool, 2> mult_scale{a.mult_scale, b.mult_scale};
     std::array<HloInstruction *, 2> scales{a.scale, b.scale}, inv_scales,
         scales_f32;
+    HloInstruction *one_constant = nullptr;
+    auto one = [&one_constant, instr]() -> HloInstruction * {
+      if (!one_constant) {
+        one_constant = instr->AddInstruction(
+            HloInstruction::CreateConstant(LiteralUtil::One(F32)));
+      }
+      return one_constant;
+    };
+
     for (int i = 0; i < scales.size(); ++i) {
       if (scales[i]) {
         if (!ShapeUtil::IsScalar(scales[i]->shape())) {
@@ -1099,7 +1105,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
         }
         if (!mult_scale[i]) {
           inv_scales[i] = instr->AddInstruction(HloInstruction::CreateBinary(
-              scales[i]->shape(), HloOpcode::kDivide, one, scales[i]));
+              scales[i]->shape(), HloOpcode::kDivide, one(), scales[i]));
         }
         scales_f32[i] = mult_scale[i] ? scales[i] : inv_scales[i];
         if (scales_f32[i]->shape().element_type() != F32) {
@@ -1107,7 +1113,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
               ShapeUtil::MakeScalarShape(F32), scales_f32[i]));
         }
       } else {
-        scales_f32[i] = one;
+        scales_f32[i] = one();
       }
     }
@@ -1249,7 +1255,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
         PadShapeToMultipleOf16(instr->shape(), out_batch_dims);
 
     std::vector<HloInstruction *> operands_list = {
-        a.fp8_input, b.fp8_input, scales_f32[0], scales_f32[1], one, one};
+        a.fp8_input, b.fp8_input, scales_f32[0], scales_f32[1]};
 
     HloInstruction *new_custom_call =
         instr->AddInstruction(HloInstruction::CreateCustomCall(
@@ -1415,13 +1421,16 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
       }
     }
 
-    // If necessary, invert the scaling factor of D and convert to F32.
+    // If necessary, invert the scaling factor of D and convert to F32. When no
+    // scaling factor was captured, set the factor to one.
     if (d_scale) {
       TF_ASSIGN_OR_RETURN(d_scale,
                           InvertAndConvertScalar(d_scale, !mult_scale));
-      TF_RETURN_IF_ERROR(existing_gemm->ReplaceOperandWith(
-          gemm_backend_config.beta() == 0.0 ? 5 : 6, d_scale));
+    } else {
+      d_scale = instr->AddInstruction(
+          HloInstruction::CreateConstant(LiteralUtil::One(F32)));
     }
+    existing_gemm->AppendOperand(d_scale);
 
     // If present, elide the calculation of the maximum of the absolute values
     // of the result of the GEMM.
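Note on the rewriter change: the f32 constant(1) is no longer created eagerly. The one() lambda materializes it on first use and then reuses it, so graphs in which both input scales were captured get no dead constant; and instead of passing two extra constant ones for c_scale and d_scale, the rewriter appends d_scale as the trailing operand, either the captured (and possibly inverted) scale or a lazily created one. The memoize-on-first-call pattern in isolation, as a standalone sketch rather than XLA code (MemoizeCreation is an illustrative name):

  #include <functional>

  // Wraps a factory so it runs at most once; later calls return the
  // cached pointer, mirroring the one()/one_constant lambda above.
  template <typename T>
  std::function<T*()> MemoizeCreation(std::function<T*()> create) {
    return [create, cached = static_cast<T*>(nullptr)]() mutable -> T* {
      if (cached == nullptr) {
        cached = create();
      }
      return cached;
    };
  }

Appending d_scale, rather than overwriting a fixed operand slot, is what lets the thunk emitter locate it unconditionally as operands().back() whenever the output type is FP8.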
diff --git a/xla/service/gpu/transforms/gemm_rewriter_test.cc b/xla/service/gpu/transforms/gemm_rewriter_test.cc
index 3df393a0c89d7..140787413d0f6 100644
--- a/xla/service/gpu/transforms/gemm_rewriter_test.cc
+++ b/xla/service/gpu/transforms/gemm_rewriter_test.cc
@@ -4950,11 +4950,11 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDF8) {
 )";
   if (IsRocm() && GetToolkitVersion() < se::SemanticVersion{6, 2, 0}) {
     checks.append(
-        R"(; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]], [[C1]], /*index=5*/[[C1]]),
+        R"(; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]]),
 )");
   } else {
     checks.append(
-        R"(; CHECK-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]], [[C1]], /*index=5*/[[C1]]),
+        R"(; CHECK-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]]),
 )");
   }
   checks.append(
@@ -5009,7 +5009,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDMatrixBiasF8) {
 ; CHECK-NEXT: [[P1:%[^ ]+]] = <>[32,16]{1,0} parameter(1)
 ; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
 ; CHECK-NEXT: [[C1:[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT: [[DOT_TUPLE:%[^ ]+]] = (<>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]], [[C1]], /*index=5*/[[C1]]),
+; CHECK-NEXT: [[DOT_TUPLE:%[^ ]+]] = (<>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]]),
 ; CHECK: custom_call_target="__cublas$lt$matmul$f8",
 ; CHECK: backend_config={
 ; CHECK-DAG: "alpha_real":1
@@ -5064,8 +5064,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8) {
 ; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
 ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2)
 ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3)
-; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]),
+; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]]),
 ; CHECK: custom_call_target="__cublas$lt$matmul$f8",
 ; CHECK: backend_config={
 ; CHECK-DAG: "alpha_real":1
@@ -5121,8 +5120,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDPaddedF8) {
 ; CHECK-NEXT: [[P1_TRANSPOSE_PADDED:%[^ ]+]] = <>[32,32]{1,0} pad([[P1_TRANSPOSE]], [[C1]])
 ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2)
 ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3)
-; CHECK-NEXT: [[C4:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT: [[DOT_TUPLE:%[^ ]+]] = (f32[16,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_PADDED]], [[P1_TRANSPOSE_PADDED]], [[P2]], [[P3]], [[C4]], /*index=5*/[[C4]]),
+; CHECK-NEXT: [[DOT_TUPLE:%[^ ]+]] = (f32[16,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_PADDED]], [[P1_TRANSPOSE_PADDED]], [[P2]], [[P3]]),
 ; CHECK: custom_call_target="__cublas$lt$matmul$f8",
 ; CHECK: backend_config={
 ; CHECK-DAG: "alpha_real":1
@@ -5205,7 +5203,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDWithConvertF8) {
 ; CHECK-NEXT: [[P1:%[^ ]+]] = <>[32,16]{1,0} parameter(1)
 ; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
 ; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]], [[C1]], /*index=5*/[[C1]]),
+; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]]),
 ; CHECK: custom_call_target="__cublas$lt$matmul$f8",
 ; CHECK: backend_config={
 ; CHECK-DAG: "alpha_real":1
@@ -5269,8 +5267,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDUnaryOpsF8) {
 ; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
 ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2)
 ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3)
-; CHECK-NEXT: [[C2:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_U3]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C2]], /*index=5*/[[C2]]),
+; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_U3]], [[P1_TRANSPOSE]], [[P2]], [[P3]]),
 ; CHECK: custom_call_target="__cublas$lt$matmul$f8",
 ; CHECK: backend_config={
 ; CHECK-DAG: "alpha_real":1
@@ -5328,7 +5325,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest,
 ; CHECK-NEXT: [[P1:%[^ ]+]] = <>[32,16]{1,0} parameter(1)
 ; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
 ; CHECK-NEXT: [[C2:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_U3]], [[P1_TRANSPOSE]], [[C2]], [[C2]], [[C2]], /*index=5*/[[C2]]),
+; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_U3]], [[P1_TRANSPOSE]], [[C2]], [[C2]]),
 ; CHECK: custom_call_target="__cublas$lt$matmul$f8",
 ; CHECK: backend_config={
 ; CHECK-DAG: "alpha_real":1
@@ -5389,8 +5386,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDDynamicSliceF8) {
 ; CHECK-NEXT: [[P1:%[^ ]+]] = <>[16,32]{1,0} parameter(1)
 ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2)
 ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3)
-; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[DYN_SLICE]], [[P1]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]),
+; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[DYN_SLICE]], [[P1]], [[P2]], [[P3]]),
 ; CHECK: custom_call_target="__cublas$lt$matmul$f8",
 ; CHECK: backend_config={
 ; CHECK-DAG: "alpha_real":1
@@ -5456,8 +5452,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDSelectF8) {
 ; CHECK-NEXT: [[SELECT:%[^ ]+]] = <>[16,32]{1,0} select([[P4]], [[P1]], [[C0_CONVERT]])
 ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2)
 ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3)
-; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[SELECT]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]),
+; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[SELECT]], [[P2]], [[P3]]),
 ; CHECK: custom_call_target="__cublas$lt$matmul$f8",
 ; CHECK: backend_config={
 ; CHECK-DAG: "alpha_real":1
@@ -5541,8 +5536,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, BatchedScaledABUnscaledDF8) {
 ; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[10,16,32]{2,1,0} transpose([[P1]]), dimensions={0,2,1}
 ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2)
 ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3)
-; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[10,16,16]{2,1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]),
+; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[10,16,16]{2,1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]]),
 ; CHECK: custom_call_target="__cublas$lt$matmul$f8",
 ; CHECK: backend_config={
 ; CHECK-DAG: "alpha_real":1
@@ -5598,8 +5592,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABAlphaDF8) {
 ; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
 ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2)
 ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3)
-; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]),
+; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]]),
 ; CHECK: custom_call_target="__cublas$lt$matmul$f8",
 ; CHECK: backend_config={
 ; CHECK-DAG: "alpha_real":3
@@ -5655,8 +5648,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDReluActivationF8) {
 ; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
 ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2)
 ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3)
-; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]),
+; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]]),
 ; CHECK: custom_call_target="__cublas$lt$matmul$f8",
 ; CHECK: backend_config={
 ; CHECK-DAG: "alpha_real":1
@@ -5731,15 +5723,14 @@ TEST_P(ParameterizedFp8GemmRewriteTest,
 ; CHECK-NEXT: [[XS:%[^ ]+]] = f32[] convert([[P2]])
 ; CHECK-NEXT: [[P3:%[^ ]+]] = bf16[] parameter(3)
 ; CHECK-NEXT: [[XS1:%[^ ]+]] = f32[] convert([[P3]])
-; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
 )";
   if (IsRocm() && GetToolkitVersion() < se::SemanticVersion{6, 2, 0}) {
     checks +=
-        R"(; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]], [[C1]], /*index=5*/[[C1]]),
+        R"(; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]]),
 )";
   } else {
     checks += R"(; CHECK-NEXT: [[B:%[^ ]+]] = bf16[16]{0} parameter(4)
-; CHECK-NEXT: [[OUT:%[^ ]+]] = (bf16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]], [[C1]], /*index=5*/[[C1]], [[B]]),
+; CHECK-NEXT: [[OUT:%[^ ]+]] = (bf16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]], [[B]]),
 )";
   }
   checks += R"(; CHECK: custom_call_target="__cublas$lt$matmul$f8",
@@ -5831,15 +5822,14 @@ TEST_P(ParameterizedFp8GemmRewriteTest,
 ; CHECK-NEXT: [[XS:%[^ ]+]] = f32[] convert([[P2]])
 ; CHECK-NEXT: [[P3:%[^ ]+]] = bf16[] parameter(3)
 ; CHECK-NEXT: [[XS1:%[^ ]+]] = f32[] convert([[P3]])
-; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
 )";
   if (IsRocm() && GetToolkitVersion() < se::SemanticVersion{6, 2, 0}) {
     checks +=
-        R"(; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]], [[C1]], /*index=5*/[[C1]]),
+        R"(; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]]),
 )";
   } else {
     checks +=
-        R"(; CHECK-NEXT: [[OUT:%[^ ]+]] = (bf16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]], [[C1]], /*index=5*/[[C1]]),
+        R"(; CHECK-NEXT: [[OUT:%[^ ]+]] = (bf16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]]),
 )";
   }
   checks += R"(; CHECK: custom_call_target="__cublas$lt$matmul$f8",
@@ -5943,8 +5933,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDMatrixBiasF8) {
 ; CHECK: [[C0:%[^ ]+]] = f32[16,16]{1,0} add({{.*}})
 ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(3)
 ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(4)
-; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C0]], [[P2]], [[P3]], /*index=5*/[[C1]], [[C1]]),
+; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C0]], [[P2]], [[P3]]),
 ; CHECK: custom_call_target="__cublas$lt$matmul$f8",
 ; CHECK: output_to_operand_aliasing={
 ; CHECK-SAME: {0}: (2, {})
@@ -6009,8 +5998,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDMatrixBiasPaddedF8) {
 ; CHECK-NEXT: [[P2_PADDED:%[^ ]+]] = f32[16,16]{1,0} pad([[P2]], [[C2]]), padding=0_2x0_2
 ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3)
 ; CHECK-NEXT: [[P4:%[^ ]+]] = f32[] parameter(4)
-; CHECK-NEXT: [[C3:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT: [[DOT_TUPLE:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_PADDED]], [[P1_TRANSPOSE_PADDED]], [[P2_PADDED]], [[P3]], [[P4]], /*index=5*/[[C3]], [[C3]]),
+; CHECK-NEXT: [[DOT_TUPLE:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_PADDED]], [[P1_TRANSPOSE_PADDED]], [[P2_PADDED]], [[P3]], [[P4]]),
 ; CHECK: custom_call_target="__cublas$lt$matmul$f8",
 ; CHECK: backend_config={
 ; CHECK-DAG: "alpha_real":1
@@ -6067,8 +6055,9 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABScaledDF8) {
 ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2)
 ; CHECK-NEXT: [[P2_INV:%[^ ]+]] = f32[] divide([[C0]], [[P2]])
 ; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
-; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_INV]], [[C1]], [[C1]], /*index=5*/[[C1]]),
-; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_INV]], [[C1]], [[C1]], /*index=5*/[[C1]]),
+; CHECK-NEXT: [[C2:%[^ ]+]] = f32[] constant(1)
+; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_INV]], [[C1]], [[C2]]),
+; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_INV]], [[C1]], [[C2]]),
 ; CHECK: custom_call_target="__cublas$lt$matmul$f8",
 ; CHECK: backend_config={
 ; CHECK-DAG: "alpha_real":1
@@ -6117,7 +6106,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABScaledF32DF8) {
 ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2)
 ; CHECK-NEXT: [[P2_INV:%[^ ]+]] = f32[] divide([[C0]], [[P2]])
 ; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_INV]], [[C1]], [[C1]], /*index=5*/[[C1]]),
+; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_INV]], [[C1]]),
 ; CHECK: custom_call_target="__cublas$lt$matmul$f8",
 ; CHECK: backend_config={
 ; CHECK-DAG: "alpha_real":1
@@ -6164,7 +6153,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABInvScaledF32DF8) {
 ; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
 ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2)
 ; CHECK-NEXT: [[C0:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[C0]], [[C0]], /*index=5*/[[C0]]),
+; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[C0]]),
 ; CHECK: custom_call_target="__cublas$lt$matmul$f8",
 ; CHECK: backend_config={
 ; CHECK-DAG: "alpha_real":1
@@ -6215,7 +6204,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABScaledF32DMatrixBiasF8) {
 ; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
 ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[16,16]{1,0} parameter(2)
 ; CHECK-NEXT: [[C0:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT: [[GEMM_TUPLE:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[C0]], [[C0]], /*index=5*/[[C0]], [[C0]]),
+; CHECK-NEXT: [[GEMM_TUPLE:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[C0]], [[C0]]),
 ; CHECK: custom_call_target="__cublas$lt$matmul$f8",
 ; CHECK: backend_config={
 ; CHECK-DAG: "alpha_real":1
@@ -6280,12 +6269,12 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDF8) {
 ; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
 ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2)
 ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3)
-; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
 ; CHECK-PTX-NEXT: [[C2:%[^ ]+]] = f32[] constant(1)
 ; CHECK-PTX-NEXT: [[P4:%[^ ]+]] = f32[] parameter(4)
 ; CHECK-PTX-NEXT: [[P4_INV:%[^ ]+]] = f32[] divide([[C2]], [[P4]])
-; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[P4_INV]]),
-; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]),
+; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[P4_INV]]),
+; CHECK-GCN-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
+; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]]),
 ; CHECK: custom_call_target="__cublas$lt$matmul$f8",
 ; CHECK: backend_config={
 ; CHECK-DAG: "alpha_real":1
@@ -6390,12 +6379,12 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDReluActivationF8) {
 ; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
 ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2)
 ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3)
-; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
 ; CHECK-PTX-NEXT: [[C2:%[^ ]+]] = f32[] constant(1)
 ; CHECK-PTX-NEXT: [[P4:%[^ ]+]] = f32[] parameter(4)
 ; CHECK-PTX-NEXT: [[P4_INV:%[^ ]+]] = f32[] divide([[C2]], [[P4]])
-; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[P4_INV]]),
-; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]),
+; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[P4_INV]]),
+; CHECK-GCN-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
+; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]]),
 ; CHECK: custom_call_target="__cublas$lt$matmul$f8",
 ; CHECK: backend_config={
 ; CHECK-DAG: "alpha_real":1
@@ -6472,11 +6461,10 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDMatrixBiasWithDAmaxF8) {
 ; CHECK: [[C0:%[^ ]+]] = f16[16,16]{1,0} add({{.*}})
 ; CHECK-NEXT: [[P2:%[^ ]+]] = f16[] parameter(3)
 ; CHECK: [[P3:%[^ ]+]] = f16[] parameter(4)
-; CHECK: [[C1:%[^ ]+]] = f32[] constant(1)
 ; CHECK-PTX: [[P4:%[^ ]+]] = f16[] parameter(5)
-; CHECK-PTX: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, f32[], s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C0]], [[DUMMY0:%[^ ]+]], [[DUMMY1:%[^ ]+]], /*index=5*/[[C1]], [[DUMMY2:%[^ ]+]]),
+; CHECK-PTX: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, f32[], s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C0]], [[DUMMY0:%[^ ]+]], [[DUMMY1:%[^ ]+]], /*index=5*/[[DUMMY2:%[^ ]+]]),
 ; CHECK-NOT: output_to_operand_aliasing
-; CHECK-GCN: [[OUT:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C0]], [[DUMMY0:%[^ ]+]], [[DUMMY1:%[^ ]+]], /*index=5*/[[C1]], [[DUMMY2:%[^ ]+]]),
+; CHECK-GCN: [[OUT:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C0]], [[DUMMY0:%[^ ]+]], [[DUMMY1:%[^ ]+]], /*index=5*/[[DUMMY2:%[^ ]+]]),
 ; CHECK: custom_call_target="__cublas$lt$matmul$f8",
 ; CHECK: backend_config={
 ; CHECK-DAG: "alpha_real":1
@@ -6543,14 +6531,14 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDVectorBiasF8) {
 ; CHECK-NEXT: [[CV:%[^ ]+]] = f32[] convert([[P2]])
 ; CHECK-NEXT: [[P3:%[^ ]+]] = f16[] parameter(4)
 ; CHECK-NEXT: [[CV1:%[^ ]+]] = f32[] convert([[P3]])
-; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1)
+; CHECK-NEXT: [[VB:%[^ ]+]] = f16[16]{0} parameter(2)
 ; CHECK-PTX-NEXT: [[C2:%[^ ]+]] = f16[] constant(1)
 ; CHECK-PTX-NEXT: [[P4:%[^ ]+]] = f16[] parameter(5)
 ; CHECK-PTX-NEXT: [[DV:%[^ ]+]] = f16[] divide([[C2]], [[P4]])
 ; CHECK-PTX-NEXT: [[CV2:%[^ ]+]] = f32[] convert([[DV]])
-; CHECK-NEXT: [[VB:%[^ ]+]] = f16[16]{0} parameter(2)
-; CHECK-PTX: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[CV]], [[CV1]], [[C]], /*index=5*/[[CV2]], [[VB]]),
-; CHECK-GCN: [[OUT:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[CV]], [[CV1]], [[C]], /*index=5*/[[C]], [[VB]]),
+; CHECK-PTX: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[CV]], [[CV1]], [[VB]], /*index=5*/[[CV2]]),
+; CHECK-GCN: [[C:%[^ ]+]] = f32[] constant(1)
+; CHECK-GCN: [[OUT:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[CV]], [[CV1]], [[C]], /*index=5*/[[C]], [[VB]]),
 ; CHECK: custom_call_target="__cublas$lt$matmul$f8",
 ; CHECK: backend_config={
 ; CHECK-DAG: "alpha_real":1
@@ -6607,10 +6595,9 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF32VectorBiasF8) {
 ; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
 ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(3)
 ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(4)
-; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1)
 ; CHECK-NEXT: [[VB:%[^ ]+]] = f32[16]{0} parameter(2)
 ; CHECK-NEXT: [[VBC:%[^ ]+]] = bf16[16]{0} convert([[VB]])
-; CHECK: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C]], /*index=5*/[[C]], [[VBC]]),
+; CHECK: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[VBC]]),
 ; CHECK: custom_call_target="__cublas$lt$matmul$f8",
 ; CHECK: backend_config={
 ; CHECK-DAG: "alpha_real":1
@@ -6670,9 +6657,8 @@ TEST_P(ParameterizedFp8GemmRewriteTest,
 ; CHECK-NEXT: [[CV:%[^ ]+]] = f32[] convert([[P2]])
 ; CHECK-NEXT: [[P3:%[^ ]+]] = f16[] parameter(4)
 ; CHECK-NEXT: [[CV1:%[^ ]+]] = f32[] convert([[P3]])
-; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1)
 ; CHECK-NEXT: [[VB:%[^ ]+]] = f16[16]{0} parameter(2)
-; CHECK : ROOT [[OUT:%[^ ]+]] = f16[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[CV]], [[CV1]], [[C]], /*index=5*/[[C]], [[VB]]),
+; CHECK : ROOT [[OUT:%[^ ]+]] = f16[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[CV]], [[CV1]], [[VB]]),
 ; CHECK: custom_call_target="__cublas$lt$matmul$f8",
 ; CHECK: backend_config={
 ; CHECK-DAG: "alpha_real":1
@@ -6744,10 +6730,9 @@ TEST_P(ParameterizedFp8GemmRewriteTest, Rank3ScaledABUnscaledDVectorBiasF8) {
 ; CHECK-NEXT: [[P2_CV:%[^ ]+]] = f32[] convert([[P2]])
 ; CHECK-NEXT: [[P3:%[^ ]+]] = f16[] parameter(4)
 ; CHECK-NEXT: [[P3_CV:%[^ ]+]] = f32[] convert([[P3]])
-; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1)
 ; CHECK-NEXT: [[B:%[^ ]+]] = f32[32]{0} parameter(2)
 ; CHECK-NEXT: [[B_F16:%[^ ]+]] = f16[32]{0} convert([[B]])
-; CHECK-NEXT: [[GEMM_TUPLE:%[^ ]+]] = (f16[64,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_BITCAST]], [[P1_TRANSPOSE]], [[P2_CV]], [[P3_CV]], [[C]], /*index=5*/[[C]], [[B_F16]]),
+; CHECK-NEXT: [[GEMM_TUPLE:%[^ ]+]] = (f16[64,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_BITCAST]], [[P1_TRANSPOSE]], [[P2_CV]], [[P3_CV]], [[B_F16]]),
 ; CHECK: custom_call_target="__cublas$lt$matmul$f8",
 ; CHECK: backend_config={
 ; CHECK-DAG: "alpha_real":1
@@ -6828,12 +6813,11 @@ TEST_P(ParameterizedFp8GemmRewriteTest,
 ; CHECK-NEXT: [[P2_CV:%[^ ]+]] = f32[] convert([[P2]])
 ; CHECK-NEXT: [[P3:%[^ ]+]] = f16[] parameter(4)
 ; CHECK-NEXT: [[P3_CV:%[^ ]+]] = f32[] convert([[P3]])
-; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1)
 ; CHECK-NEXT: [[B:%[^ ]+]] = f32[31]{0} parameter(2)
 ; CHECK-NEXT: [[B_F16:%[^ ]+]] = f16[31]{0} convert([[B]])
 ; CHECK-NEXT: [[C3:%[^ ]+]] = f16[] constant(0)
 ; CHECK-NEXT: [[P2_PAD:%[^ ]+]] = f16[32]{0} pad([[B_F16]], [[C3]]), padding=0_1
-; CHECK-NEXT: [[GEMM_TUPLE:%[^ ]+]] = (f16[64,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_PAD]], [[P1_PAD]], [[P2_CV]], [[P3_CV]], [[C]], /*index=5*/[[C]], [[P2_PAD]]),
+; CHECK-NEXT: [[GEMM_TUPLE:%[^ ]+]] = (f16[64,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_PAD]], [[P1_PAD]], [[P2_CV]], [[P3_CV]], [[P2_PAD]]),
 ; CHECK: custom_call_target="__cublas$lt$matmul$f8",
 ; CHECK: backend_config={
 ; CHECK-DAG: "alpha_real":1
@@ -6906,8 +6890,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, Rank3ScaledABUnscaledDMatrixBiasF8) {
 ; CHECK-NEXT: [[B_BITCAST:%[^ ]+]] = f32[64,32]{1,0} bitcast([[B]])
 ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(3)
 ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(4)
-; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT: [[GEMM_TUPLE:%[^ ]+]] = (f32[64,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_BITCAST]], [[P1_TRANSPOSE]], [[B_BITCAST]], [[P2]], [[P3]], /*index=5*/[[C]], [[C]]),
+; CHECK-NEXT: [[GEMM_TUPLE:%[^ ]+]] = (f32[64,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_BITCAST]], [[P1_TRANSPOSE]], [[B_BITCAST]], [[P2]], [[P3]]),
 ; CHECK: custom_call_target="__cublas$lt$matmul$f8",
 ; CHECK: backend_config={
 ; CHECK-DAG: "alpha_real":1
@@ -6988,8 +6971,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest,
 ; CHECK-NEXT: [[P2_PADDED:%[^ ]+]] = f32[48,32]{1,0} pad([[B_BITCAST]], [[C3]]), padding=0_3x0_1
 ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(3)
 ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(4)
-; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT: [[GEMM_TUPLE:%[^ ]+]] = (f32[48,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_PADDED]], [[P1_PADDED]], [[P2_PADDED]], [[P2]], [[P3]], /*index=5*/[[C]], [[C]]),
+; CHECK-NEXT: [[GEMM_TUPLE:%[^ ]+]] = (f32[48,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_PADDED]], [[P1_PADDED]], [[P2_PADDED]], [[P2]], [[P3]]),
 ; CHECK: custom_call_target="__cublas$lt$matmul$f8",
 ; CHECK: backend_config={
 ; CHECK-DAG: "alpha_real":1
@@ -7054,8 +7036,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest,
 ; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[32,16]{1,0} transpose([[P1]]), dimensions={1,0}
 ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(3)
 ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(4)
-; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT: [[GEMM_TUPLE:%[^ ]+]] = (f32[48,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C]], /*index=5*/[[C]]),
+; CHECK-NEXT: [[GEMM_TUPLE:%[^ ]+]] = (f32[48,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]]),
 ; CHECK: custom_call_target="__cublas$lt$matmul$f8",
 ; CHECK: backend_config={
 ; CHECK-DAG: "alpha_real":1
@@ -7117,8 +7098,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDWithAllGatherF8) {
 ; CHECK: [[P1_TRANSPOSE:%[^ ]+]] = <>[32,64]{1,0} transpose([[AG1]]), dimensions={1,0}
 ; CHECK: [[P2:%[^ ]+]] = f32[] parameter(2)
 ; CHECK: [[P3:%[^ ]+]] = f32[] parameter(3)
-; CHECK: [[C:%[^ ]+]] = f32[] constant(1)
-; CHECK: [[GEMM_TUPLE:%[^ ]+]] = (f32[16,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[AG]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C]], /*index=5*/[[C]]),
+; CHECK: [[GEMM_TUPLE:%[^ ]+]] = (f32[16,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[AG]], [[P1_TRANSPOSE]], [[P2]], [[P3]]),
 ; CHECK: custom_call_target="__cublas$lt$matmul$f8",
 ; CHECK: backend_config={
 ; CHECK-DAG: "alpha_real":1
@@ -7175,8 +7155,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDWithAllToAllF8) {
 ; CHECK: [[P1:%[^ ]+]] = <>[16,32]{1,0} parameter(1)
 ; CHECK: [[P2:%[^ ]+]] = f32[] parameter(2)
 ; CHECK: [[P3:%[^ ]+]] = f32[] parameter(3)
-; CHECK: [[C:%[^ ]+]] = f32[] constant(1)
-; CHECK: [[GEMM:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[AA]], [[P1]], [[P2]], [[P3]], [[C]], /*index=5*/[[C]]),
+; CHECK: [[GEMM:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[AA]], [[P1]], [[P2]], [[P3]]),
 ; CHECK: custom_call_target="__cublas$lt$matmul$f8",
 ; CHECK: backend_config={
 ; CHECK-DAG: "alpha_real":1
@@ -7233,8 +7212,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest,
 ; CHECK: [[P1:%[^ ]+]] = <>[16,32]{1,0} parameter(1)
 ; CHECK: [[P2:%[^ ]+]] = f32[] parameter(2)
 ; CHECK: [[P3:%[^ ]+]] = f32[] parameter(3)
-; CHECK: [[C:%[^ ]+]] = f32[] constant(1)
-; CHECK: [[GEMM:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[AA]], [[P1]], [[P2]], [[P3]], [[C]], /*index=5*/[[C]]),
+; CHECK: [[GEMM:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[AA]], [[P1]], [[P2]], [[P3]]),
 ; CHECK: custom_call_target="__cublas$lt$matmul$f8",
 ; CHECK: backend_config={
 ; CHECK-DAG: "alpha_real":1
@@ -7296,8 +7274,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest,
 ; CHECK-NEXT: [[CV0:%[^ ]+]] = f32[] convert([[P2]])
 ; CHECK-NEXT: [[P3:%[^ ]+]] = f16[] parameter(5)
 ; CHECK-NEXT: [[CV1:%[^ ]+]] = f32[] convert([[P3]])
-; CHECK: [[C1:%[^ ]+]] = f32[] constant(1)
-; CHECK: [[GEMMOUT_TUPLE:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[MB]], [[CV0]], [[CV1]], /*index=5*/[[C1]], [[C1]]),
+; CHECK: [[GEMMOUT_TUPLE:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[MB]], [[CV0]], [[CV1]]),
 ; CHECK: custom_call_target="__cublas$lt$matmul$f8",
 ; CHECK: backend_config={
 ; CHECK-DAG: "alpha_real":1
@@ -7372,12 +7349,12 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDWithDAmaxF8) {
 ; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]])
 ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2)
 ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3)
-; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
 ; CHECK-PTX-NEXT: [[C2:%[^ ]+]] = f32[] constant(1)
 ; CHECK-PTX-NEXT: [[P4:%[^ ]+]] = f32[] parameter(4)
 ; CHECK-PTX-NEXT: [[P4_INV:%[^ ]+]] = f32[] divide([[C2]], [[P4]])
-; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, f32[], s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[P4_INV]]),
-; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]),
+; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, f32[], s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[P4_INV]]),
+; CHECK-GCN-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
+; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]]),
 ; CHECK: custom_call_target="__cublas$lt$matmul$f8",
 ; CHECK: backend_config={
 ; CHECK-DAG: "alpha_real":1
@@ -7453,13 +7430,13 @@ TEST_P(ParameterizedFp8GemmRewriteTest,
 ; CHECK-NEXT: [[P2_CONVERT:%[^ ]+]] = f32[] convert([[P2]])
 ; CHECK-NEXT: [[P3:%[^ ]+]] = f16[] parameter(3)
 ; CHECK-NEXT: [[P3_CONVERT:%[^ ]+]] = f32[] convert([[P3]])
-; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
 ; CHECK-PTX-NEXT: [[C2:%[^ ]+]] = f16[] constant(1)
 ; CHECK-PTX-NEXT: [[P4:%[^ ]+]] = f16[] parameter(4)
 ; CHECK-PTX-NEXT: [[P4_INV:%[^ ]+]] = f16[] divide([[C2]], [[P4]])
 ; CHECK-PTX-NEXT: [[P4_INV_CONVERT:%[^ ]+]] = f32[] convert([[P4_INV]])
-; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, f32[], s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_CONVERT]], [[P3_CONVERT]], [[C1]], /*index=5*/[[P4_INV_CONVERT]]),
-; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_CONVERT]], [[P3_CONVERT]], [[C1]], /*index=5*/[[C1]]),
+; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, f32[], s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_CONVERT]], [[P3_CONVERT]], [[P4_INV_CONVERT]]),
+; CHECK-GCN-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
+; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_CONVERT]], [[P3_CONVERT]], [[C1]]),
 ; CHECK: custom_call_target="__cublas$lt$matmul$f8",
 ; CHECK: backend_config={
 ; CHECK-DAG: "alpha_real":1
@@ -7533,12 +7510,12 @@ TEST_P(ParameterizedFp8GemmRewriteTest,
 ; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]])
 ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2)
 ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3)
-; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
 ; CHECK-PTX-NEXT: [[C2:%[^ ]+]] = f32[] constant(1)
 ; CHECK-PTX-NEXT: [[P4:%[^ ]+]] = f32[] parameter(4)
 ; CHECK-PTX-NEXT: [[P4_INV:%[^ ]+]] = f32[] divide([[C2]], [[P4]])
-; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, f32[], s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[P4_INV]]),
-; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]),
+; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, f32[], s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[P4_INV]]),
+; CHECK-GCN-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
+; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]]),
 ; CHECK: custom_call_target="__cublas$lt$matmul$f8",
 ; CHECK: backend_config={
 ; CHECK-DAG: "alpha_real":1
diff --git a/xla/stream_executor/cuda/cuda_blas_lt.cc b/xla/stream_executor/cuda/cuda_blas_lt.cc
index fbbcfad52fb3a..c1ba88f3d61bf 100644
--- a/xla/stream_executor/cuda/cuda_blas_lt.cc
+++ b/xla/stream_executor/cuda/cuda_blas_lt.cc
@@ -449,15 +449,12 @@ absl::Status BlasLt::MatmulPlan::DoMatmul(
                                CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                                b_scale.opaque()));
   }
-  auto isF8Input = [](const auto& desc) {
-    return desc.type() == CUDA_R_8F_E4M3 || desc.type() == CUDA_R_8F_E5M2;
-  };
-  if (c_scale != nullptr && isF8Input(c_desc_)) {
+  if (c_scale != nullptr) {
     TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(),
                                CUBLASLT_MATMUL_DESC_C_SCALE_POINTER,
                                c_scale.opaque()));
   }
-  if (d_scale != nullptr && isF8Input(d_desc_)) {
+  if (d_scale != nullptr) {
     TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(),
                                CUBLASLT_MATMUL_DESC_D_SCALE_POINTER,
                                d_scale.opaque()));
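Note on the cuBLASLt change: with the rewriter and thunk emitter now guaranteeing that c_scale/d_scale are non-null only when C/D is FP8, DoMatmul no longer needs the isF8Input guard and simply forwards any non-null pointer. A minimal sketch of the underlying cuBLASLt rule at the CUDA API level, under the assumption stated in the patch comment that scale pointers must stay unset for non-FP8 matrices (SetDScaleIfPresent is an illustrative helper, not part of this codebase):

  #include <cublasLt.h>

  // Sets CUBLASLT_MATMUL_DESC_D_SCALE_POINTER only when a scale was
  // provided; when D is not FP8 the caller passes nullptr and the
  // attribute is left unset, which is what cuBLASLt requires.
  cublasStatus_t SetDScaleIfPresent(cublasLtMatmulDesc_t matmul_desc,
                                    const void* d_scale_dev_ptr) {
    if (d_scale_dev_ptr == nullptr) {
      return CUBLAS_STATUS_SUCCESS;  // Leave the attribute unset.
    }
    return cublasLtMatmulDescSetAttribute(
        matmul_desc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &d_scale_dev_ptr,
        sizeof(d_scale_dev_ptr));
  }

Moving the type check out of DoMatmul keeps the invariant in one place, the thunk emitter, instead of re-deriving the matrix element types at execution time.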