diff --git a/xla/service/gpu/gemm_rewriter.cc b/xla/service/gpu/gemm_rewriter.cc
index 8665c5a3117614..d1ed3ae226294f 100644
--- a/xla/service/gpu/gemm_rewriter.cc
+++ b/xla/service/gpu/gemm_rewriter.cc
@@ -228,7 +228,6 @@ bool IsSupportedF8Pattern(
   }
   std::reverse(subgraph.begin(), subgraph.end());
-
   // When not operating directly on an FP8 operand, the second and
   // third instructions in the subgraph must describe a dequantization, i.e. a
   // convert instruction followed by a multiply/divide instruction.
@@ -569,6 +568,13 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
                       const_cast(instr), b, b_scale, b_mult_scale, b_ops);
                 })))) {
+#if TENSORFLOW_USE_ROCM
+      if (instr->shape().element_type() != F16 &&
+          instr->shape().element_type() != F32) {
+        TF_ASSIGN_OR_RETURN(instr,
+                            TurnF8DotWithUnsupportedOutputTypeIntoF32(instr));
+      }
+#endif // TENSORFLOW_USE_ROCM
       TF_ASSIGN_OR_RETURN(bool created_call,
                           CreateF8CustomCall(instr, gpu_backend_config, a, b,
                                              a_scale, b_scale, a_mult_scale,
@@ -875,9 +881,9 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
       HloInstruction *b_scale, bool a_mult_scale, bool b_mult_scale,
       std::vector> a_ops, std::vector> b_ops) {
-#if GOOGLE_CUDA
     GemmBackendConfig &gemm_backend_config =
         *gpu_backend_config.mutable_gemm_backend_config();
+#if GOOGLE_CUDA
    auto cuda_compute_capability_ =
        std::get(gpu_version_);
    // FP8 GEMM kernels are only available on Ada, Hopper, and later
@@ -887,17 +893,36 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
          << "FP8 Custom Calls require Ada, Hopper, or later architectures.";
      return false;
    }
+
 #if CUDA_VERSION < 12000
    // FP8 GEMM kernels are only available with CUDA 12.0 and above
    VLOG(1) << "FP8 Custom Calls require CUDA 12.0 or newer.";
    return false;
 #endif // CUDA_VERSION < 12000
+#endif // GOOGLE_CUDA
+
+#if TENSORFLOW_USE_ROCM
+    auto isrocm = std::get_if(&gpu_version_);
+    if (!isrocm->has_fp8_support()) {
+      VLOG(1) << "FP8 Custom Calls require MI300, or later architectures.";
+      return false;
+    }
+
+#if TF_ROCM_VERSION < 60000
+    // FP8 GEMM kernels are only available with ROCm 6.0 and above
+    VLOG(1) << "FP8 Custom Calls require ROCm 6.0 or newer.";
+    return false;
+#endif // TF_ROCM_VERSION < 60000
+
+#endif // TENSORFLOW_USE_ROCM
+
    PrimitiveType a_type = a->shape().element_type();
    PrimitiveType b_type = b->shape().element_type();

    // cuBLASLt FP8 GEMM kernels require one of the two operands to be in
    // F8E4M3FN format.
+#if GOOGLE_CUDA
    if (a_type == F8E5M2 && b_type == F8E5M2) {
      VLOG(1)
          << "Failed to rewrite " << instr->ToShortString()
@@ -914,6 +939,26 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
          << PrimitiveType_Name(b_type);
      return false;
    }
+#endif // GOOGLE_CUDA
+
+#if TENSORFLOW_USE_ROCM
+    if (a_type == F8E5M2FNUZ && b_type == F8E5M2FNUZ) {
+      VLOG(1)
+          << "Failed to rewrite " << instr->ToShortString()
+          << " into FP8 Custom Call. The element type of one of the operands "
+             "must be F8E4M3FNUZ.";
+      return false;
+    }
+    if ((a_type != F8E5M2FNUZ && a_type != F8E4M3FNUZ) ||
+        (b_type != F8E5M2FNUZ && b_type != F8E4M3FNUZ)) {
+      VLOG(1) << "Failed to rewrite " << instr->ToShortString()
+              << " into FP8 Custom Call. The input types must be F8E5M2FNUZ or "
+                 "F8E4M3FNUZ, but got "
+              << PrimitiveType_Name(a_type) << " and "
+              << PrimitiveType_Name(b_type);
+      return false;
+    }
+#endif // TENSORFLOW_USE_ROCM

    absl::Span batch_dims =
        gemm_backend_config.dot_dimension_numbers().rhs_batch_dimensions();
@@ -956,6 +1001,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
      case F32:
        break;
      default:
+
        VLOG(1) << "Failed to rewrite " << instr->ToShortString()
                << " into FP8 Custom Call. Output element type must be "
                   "F8E4M3FN, F8E5M2, BF16, F16 or F32. Actual element type is "
@@ -1104,9 +1150,6 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
        ReplaceInstruction(instr, slice ? slice : new_custom_call));
    VLOG(1) << instr->ToString() << " rewritten into FP8 Custom Call.";
    return true;
-#else // TENSORFLOW_USE_ROCM
-    return false;
-#endif
  }

  absl::Status F8ConvertD(HloInstruction *instr, HloInstruction *existing_gemm,
@@ -1741,10 +1784,12 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
    const PrimitiveType output_type =
        bias ? bias->shape().element_type() : instr.shape().element_type();
    const std::array supported_type = {
-        PrimitiveType::F8E5M2, PrimitiveType::F8E4M3FN, PrimitiveType::S8,
-        PrimitiveType::F16, PrimitiveType::BF16, PrimitiveType::F32,
-        PrimitiveType::S32, PrimitiveType::F64, PrimitiveType::C64,
-        PrimitiveType::C128};
+        PrimitiveType::F8E5M2FNUZ, PrimitiveType::F8E4M3FNUZ,
+        PrimitiveType::F8E5M2, PrimitiveType::F8E4M3FN,
+        PrimitiveType::S8, PrimitiveType::F16,
+        PrimitiveType::BF16, PrimitiveType::F32,
+        PrimitiveType::S32, PrimitiveType::F64,
+        PrimitiveType::C64, PrimitiveType::C128};
    if (!absl::c_linear_search(supported_type, output_type)) return false;
    // cublasLt has a defined set of combinations of types that it supports.
    // Figure out the computeType and scaleType.
@@ -1807,6 +1852,39 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2,
         PrimitiveType::F8E4M3FN, DataType::kFloat},
 #endif // GOOGLE_CUDA
+#if TENSORFLOW_USE_ROCM
+        // FP8 types:
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
+         PrimitiveType::F8E4M3FNUZ, DataType::kBF16},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
+         PrimitiveType::F8E4M3FNUZ, DataType::kF8E4M3FNUZ},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
+         PrimitiveType::F8E4M3FNUZ, DataType::kHalf},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
+         PrimitiveType::F8E4M3FNUZ, DataType::kFloat},
+
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
+         PrimitiveType::F8E5M2FNUZ, DataType::kBF16},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
+         PrimitiveType::F8E5M2FNUZ, DataType::kF8E4M3FNUZ},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
+         PrimitiveType::F8E5M2FNUZ, DataType::kF8E5M2FNUZ},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
+         PrimitiveType::F8E5M2FNUZ, DataType::kHalf},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
+         PrimitiveType::F8E5M2FNUZ, DataType::kFloat},
+
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2FNUZ,
+         PrimitiveType::F8E4M3FNUZ, DataType::kBF16},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2FNUZ,
+         PrimitiveType::F8E4M3FNUZ, DataType::kF8E4M3FNUZ},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2FNUZ,
+         PrimitiveType::F8E4M3FNUZ, DataType::kF8E5M2FNUZ},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2FNUZ,
+         PrimitiveType::F8E4M3FNUZ, DataType::kHalf},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2FNUZ,
+         PrimitiveType::F8E4M3FNUZ, DataType::kFloat},
+#endif // TENSORFLOW_USE_ROCM
        // Other data types:
        {ComputationType::kF16, DataType::kHalf, PrimitiveType::F16,
         PrimitiveType::F16, DataType::kHalf},
@@ -1993,6 +2071,22 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
    return lhs_non_contracting_dimension_size <= kMaxDimensionSize;
  }

+#if TENSORFLOW_USE_ROCM
+  // Turns an F8 dot with an unsupported output type into an F8 dot with F32
+  // output, then converts the F32 result back to the original output type.
+  absl::StatusOr TurnF8DotWithUnsupportedOutputTypeIntoF32(
+      HloInstruction *instr) {
+    Shape output_f32_shape = instr->shape();
+    output_f32_shape.set_element_type(F32);
+    HloInstruction *f32_dot =
+        instr->AddInstruction(instr->CloneWithNewShape(output_f32_shape));
+    HloInstruction *convert = instr->AddInstruction(
+        HloInstruction::CreateConvert(instr->shape(), f32_dot));
+    TF_RETURN_IF_ERROR(ReplaceInstruction(instr, convert));
+    return f32_dot;
+  }
+#endif // TENSORFLOW_USE_ROCM
+
  // Turns an F8 dot into an F16 dot, converting operands to F16 and
  // converting the output back to F8.
absl::StatusOr TurnF8DotIntoF16Dot(HloInstruction *instr) { diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc index 167ffc8d2d4b41..1e59be42f3b476 100644 --- a/xla/service/gpu/gpu_compiler.cc +++ b/xla/service/gpu/gpu_compiler.cc @@ -1333,8 +1333,8 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( const GpuFloatSupport f8e5m2_support(gpu_version, F8E5M2, F16); const GpuFloatSupport f8e4m3fn_support(gpu_version, F8E4M3FN, F16); const FloatSupport f8e4m3b11fnuz_support(F8E4M3B11FNUZ, F16); - const FloatSupport f8e5m2fnuz_support(F8E5M2FNUZ, F16); - const FloatSupport f8e4m3fnuz_support(F8E4M3FNUZ, F16); + const GpuFloatSupport f8e5m2fnuz_support(gpu_version, F8E5M2FNUZ, F16); + const GpuFloatSupport f8e4m3fnuz_support(gpu_version, F8E4M3FNUZ, F16); auto add_float_normalization = [&](HloPassPipeline& pipeline) { auto& sub_pipeline = pipeline.AddPass("float_normalization"); diff --git a/xla/service/gpu/ir_emission_utils.cc b/xla/service/gpu/ir_emission_utils.cc index dc409697b4c865..5c274acd80a906 100644 --- a/xla/service/gpu/ir_emission_utils.cc +++ b/xla/service/gpu/ir_emission_utils.cc @@ -123,9 +123,11 @@ bool IsMatrixMultiplication(const HloInstruction& dot) { PrimitiveType output_primitive_type = dot.shape().element_type(); bool type_is_allowed = (output_primitive_type == F8E4M3FN || output_primitive_type == F8E5M2 || - output_primitive_type == F16 || output_primitive_type == BF16 || - output_primitive_type == F32 || output_primitive_type == F64 || - output_primitive_type == C64 || output_primitive_type == C128) || + output_primitive_type == F8E4M3FNUZ || + output_primitive_type == F8E5M2FNUZ || output_primitive_type == F16 || + output_primitive_type == BF16 || output_primitive_type == F32 || + output_primitive_type == F64 || output_primitive_type == C64 || + output_primitive_type == C128) || (output_primitive_type == S32 && lhs_shape.element_type() == S8 && rhs_shape.element_type() == S8); bool shapes_are_valid = diff --git a/xla/service/gpu/ir_emitter_unnested.cc b/xla/service/gpu/ir_emitter_unnested.cc index be925725167418..b55f080effba4d 100644 --- a/xla/service/gpu/ir_emitter_unnested.cc +++ b/xla/service/gpu/ir_emitter_unnested.cc @@ -732,10 +732,6 @@ absl::Status IrEmitterUnnested::EmitCublasLtMatmulThunk( return absl::OkStatus(); } -#endif // GOOGLE_CUDA || TF_HIPBLASLT - -#if GOOGLE_CUDA - absl::Status IrEmitterUnnested::EmitCublasLtMatmulThunkF8( const HloCustomCallInstruction* instr) { TF_RET_CHECK(instr->operand_count() == 6 || instr->operand_count() == 7 || @@ -771,12 +767,17 @@ absl::Status IrEmitterUnnested::EmitCublasLtMatmulThunkF8( TF_ASSIGN_OR_RETURN( BufferAllocation::Slice b_scale, GetAllocationSliceForHlo(instr->operand(a_scale_index + 1))); +#if GOOGLE_CUDA TF_ASSIGN_OR_RETURN( BufferAllocation::Slice c_scale, GetAllocationSliceForHlo(instr->operand(a_scale_index + 2))); TF_ASSIGN_OR_RETURN( BufferAllocation::Slice d_scale, GetAllocationSliceForHlo(instr->operand(a_scale_index + 3))); +#else // TENSORFLOW_USE_ROCM + BufferAllocation::Slice c_scale; + BufferAllocation::Slice d_scale; +#endif BufferAllocation::Slice bias; if (has_vector_bias) { @@ -810,7 +811,9 @@ absl::Status IrEmitterUnnested::EmitCublasLtMatmulThunkF8( AddThunkToThunkSequence(std::move(thunk)); return absl::OkStatus(); } +#endif // GOOGLE_CUDA || TF_HIPBLASLT +#if GOOGLE_CUDA absl::Status IrEmitterUnnested::EmitConvolutionReorderThunk( const HloCustomCallInstruction* instr) { bool has_bias = instr->operand_count() > 1; @@ -2902,11 +2905,11 @@ absl::Status 
IrEmitterUnnested::EmitHloInstruction( if (IsCublasLtMatmul(*instr)) { return EmitCublasLtMatmulThunk(custom_call); } -#endif // GOOGLE_CUDA || TF_HIPBLASLT -#if GOOGLE_CUDA if (IsCublasLtMatmulF8(*instr)) { return EmitCublasLtMatmulThunkF8(custom_call); } +#endif // GOOGLE_CUDA || TF_HIPBLASLT +#if GOOGLE_CUDA if (IsCudnnConvolutionReorder(*instr)) { return EmitConvolutionReorderThunk(custom_call); } diff --git a/xla/service/gpu/ir_emitter_unnested.h b/xla/service/gpu/ir_emitter_unnested.h index a6e250a56e75b1..739ebb7e87fd58 100644 --- a/xla/service/gpu/ir_emitter_unnested.h +++ b/xla/service/gpu/ir_emitter_unnested.h @@ -142,9 +142,9 @@ class IrEmitterUnnested : public IrEmitter { absl::Status EmitGemmThunk(const HloCustomCallInstruction* instr); #if GOOGLE_CUDA || TF_HIPBLASLT absl::Status EmitCublasLtMatmulThunk(const HloCustomCallInstruction* instr); + absl::Status EmitCublasLtMatmulThunkF8(const HloCustomCallInstruction* instr); #endif // GOOGLE_CUDA || TF_HIPBLASLT #if GOOGLE_CUDA - absl::Status EmitCublasLtMatmulThunkF8(const HloCustomCallInstruction* instr); absl::Status EmitConvolutionReorderThunk( const HloCustomCallInstruction* instr); absl::Status EmitNormThunk(const HloCustomCallInstruction* instr); diff --git a/xla/service/gpu/matmul_utils.cc b/xla/service/gpu/matmul_utils.cc index d25e1883ac95d1..4c688980793bea 100644 --- a/xla/service/gpu/matmul_utils.cc +++ b/xla/service/gpu/matmul_utils.cc @@ -390,6 +390,8 @@ absl::StatusOr CanFoldTransposeOperandIntoDot(const HloInstruction& dot, switch (output_shape.element_type()) { case F8E4M3FN: case F8E5M2: + case F8E4M3FNUZ: + case F8E5M2FNUZ: case F16: case BF16: case F32: diff --git a/xla/service/gpu/tests/gemm_rewrite_test.cc b/xla/service/gpu/tests/gemm_rewrite_test.cc index e1fbee9b68be1c..dc4f7fe74909e2 100644 --- a/xla/service/gpu/tests/gemm_rewrite_test.cc +++ b/xla/service/gpu/tests/gemm_rewrite_test.cc @@ -70,15 +70,16 @@ class GemmRewriteTest : public GpuCodegenTest { const se::GpuComputeCapability& GpuComputeComp() { return device_desc().gpu_compute_capability(); } - se::GpuComputeCapability CudaHopperOrRocm() { + se::GpuComputeCapability CudaHopperOrRocmMI300() { return std::visit( VariantVisitor{[](const se::CudaComputeCapability&) { return se::GpuComputeCapability{ se::CudaComputeCapability{ se::CudaComputeCapability::HOPPER, 0}}; }, - [](const se::RocmComputeCapability& rocm) { - return se::GpuComputeCapability{rocm}; + [](const se::RocmComputeCapability&) { + return se::GpuComputeCapability{ + se::RocmComputeCapability{"gfx942"}}; }}, GpuComputeComp()); } @@ -87,46 +88,43 @@ class GemmRewriteTest : public GpuCodegenTest { False, // check always fails True, // check always succeeds }; - // switch based on architecture only + // Switch based on GPU platform only: true/false for both bool CudaOrRocmCheck(Switch cuda_set, Switch rocm_set) { - return std::visit( - VariantVisitor{[cuda_set](const se::CudaComputeCapability&) { - return cuda_set == Switch::True; - }, - [rocm_set](const se::RocmComputeCapability&) { - return rocm_set == Switch::True; - }}, - GpuComputeComp()); + return CudaOrRocmCheck( + [cuda_set](const se::CudaComputeCapability&) { + return cuda_set == Switch::True; + }, + [rocm_set](const se::RocmComputeCapability&) { + return rocm_set == Switch::True; + }); } - // major version check for CUDA and true/false for rocm + // Major version check for CUDA and true/false for ROCM bool CudaOrRocmCheck(int cuda_major, Switch rocm_set) { return CudaOrRocmCheck(cuda_major, 0, rocm_set); } - // full version check 
for CUDA and true/false for rocm + // Full version check for CUDA and true/false for ROCM bool CudaOrRocmCheck(int cuda_major, int cuda_minor, Switch rocm_set) { - return std::visit( - VariantVisitor{ - [cuda_major, cuda_minor](const se::CudaComputeCapability& cc) { - return cc.IsAtLeast(cuda_major, cuda_minor); - }, - [rocm_set](const se::RocmComputeCapability&) { - return rocm_set == Switch::True; - }, + return CudaOrRocmCheck(cuda_major, cuda_minor, + [rocm_set](const se::RocmComputeCapability&) { + return rocm_set == Switch::True; + }); + } + // Full version check for CUDA and generic version for ROCM + bool CudaOrRocmCheck( + int cuda_major, int cuda_minor, + absl::AnyInvocable rocm_fun) { + return CudaOrRocmCheck( + [cuda_major, cuda_minor](const se::CudaComputeCapability& cc) { + return cc.IsAtLeast(cuda_major, cuda_minor); }, - GpuComputeComp()); + std::move(rocm_fun)); } - // most generic check: passes if NULL function is specified + // The most generic version for both platforms bool CudaOrRocmCheck( absl::AnyInvocable cuda_fun, absl::AnyInvocable rocm_fun) { - return std::visit( - VariantVisitor{[&cuda_fun](const se::CudaComputeCapability& cc) { - return (cuda_fun ? cuda_fun(cc) : true); - }, - [&rocm_fun](const se::RocmComputeCapability& cc) { - return (rocm_fun ? rocm_fun(cc) : true); - }}, - GpuComputeComp()); + return std::visit(VariantVisitor{std::move(cuda_fun), std::move(rocm_fun)}, + GpuComputeComp()); } DebugOptions GetDebugOptionsForTest() override { @@ -4546,76 +4544,137 @@ ENTRY test { } class ParameterizedFp8GemmRewriteTest : public ParameterizedGemmRewriteTest { + public: + ParameterizedFp8GemmRewriteTest() { + replacements_[kF8E4M3DatatypePlaceholder] = +#if GOOGLE_CUDA + "f8e4m3fn"; +#else + "f8e4m3fnuz"; +#endif + replacements_[kF8E5M2DatatypePlaceholder] = +#if GOOGLE_CUDA + "f8e5m2"; +#else + "f8e5m2fnuz"; +#endif + replacements_[kF8E4M3AmaxPlaceholder] = +#if GOOGLE_CUDA + "448."; +#else + "240."; +#endif + } + protected: // Check the HLO runs and has an FP8 cuBLAS LT custom call on supported // architectures (Ada, Hopper, and later). void CheckFp8IfSupported(absl::string_view hlo_text, ErrorSpec error_spec = ErrorSpec{1e-2, 1e-2}) { - if (!CudaOrRocmCheck(8, 9, Switch::False)) { + if (!CudaOrRocmCheck(8, 9, [](const se::RocmComputeCapability& cc) { + return cc.has_fp8_support(); + })) { return; } - EXPECT_TRUE(RunAndCompare(hlo_text, error_spec)); + std::string replaced_hlo_text = + absl::StrReplaceAll(hlo_text, replacements_); + EXPECT_TRUE(RunAndCompare(absl::StrReplaceAll(hlo_text, replacements_), + error_spec)); // Most FP8 tests directly create a GemmRewriter and check the output. // Here, also run the entire HLO pass pipeline to ensure no other passes // interfere with GemmRewriter's pattern matching. 
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr optimized_module, - GetOptimizedModule(hlo_text)); + GetOptimizedModule(replaced_hlo_text)); const HloInstruction* call = FindInstruction(optimized_module.get(), HloOpcode::kCustomCall); ASSERT_NE(call, nullptr); EXPECT_EQ(call->custom_call_target(), "__cublas$lt$matmul$f8"); } - void SetUp() override { - if (CudaOrRocmCheck(Switch::False, Switch::True)) { - GTEST_SKIP() << "F8 gemm rewrite is not yet supported on ROCm platform"; + + void MatchOptimizedHlo(absl::string_view hlo, const absl::string_view pattern, + bool print_operand_shape = false) { + GemmRewriteTest::MatchOptimizedHlo( + absl::StrReplaceAll(hlo, replacements_), + absl::StrReplaceAll(pattern, replacements_), print_operand_shape); + } + + void RunAndFilecheckHloRewrite( + absl::string_view hlo, HloPassInterface&& hlo_pass, + std::optional expected, + std::function after_pass_checks = nullptr, + const HloModuleConfig* config = nullptr) { + if (expected.has_value()) { + std::string replaced_pattern = + absl::StrReplaceAll(expected.value(), replacements_); + GemmRewriteTest::RunAndFilecheckHloRewrite( + absl::StrReplaceAll(hlo, replacements_), std::move(hlo_pass), + replaced_pattern, after_pass_checks, config); } } + + StatusOr> ParseAndReturnVerifiedModule( + absl::string_view hlo_text, int64_t replica_count = 1, + int64_t num_partitions = 1) { + return GemmRewriteTest::ParseAndReturnVerifiedModule( + absl::StrReplaceAll(hlo_text, replacements_)); + } + + private: + static constexpr const char* kF8E4M3DatatypePlaceholder{"<>"}; + static constexpr const char* kF8E5M2DatatypePlaceholder{"<>"}; + static constexpr const char* kF8E4M3AmaxPlaceholder{"<>"}; }; TEST_P(ParameterizedFp8GemmRewriteTest, DoNotRewriteToF8OnPreAda) { - if (CudaOrRocmCheck(8, 9, Switch::False)) { - GTEST_SKIP() << "Test requires a pre-Ada GPU."; + if (CudaOrRocmCheck(8, 9, [](const se::RocmComputeCapability& cc) { + return cc.has_fp8_support(); + })) { + GTEST_SKIP() << "Test requires a pre-Ada GPU or an AMD GPU prior to MI300."; } const char* hlo_text = R"( HloModule test ENTRY PreAdaTest { - x = f8e4m3fn[16,32] parameter(0) - y = f8e4m3fn[32,16] parameter(1) - ROOT out = f8e4m3fn[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + x = <>[16,32] parameter(0) + y = <>[32,16] parameter(1) + ROOT out = <>[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} } )"; - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-2, 1e-2})); + EXPECT_TRUE(RunAndCompare(absl::StrReplaceAll(hlo_text, replacements_), + ErrorSpec{1e-2, 1e-2})); MatchOptimizedHlo(hlo_text, R"( -; CHECK-LABEL: ENTRY %PreAdaTest ({{.*}}: f8e4m3fn[16,32], {{.*}}: f8e4m3fn[32,16]) -> f8e4m3fn[16,16] { +; CHECK-LABEL: ENTRY %PreAdaTest ({{.*}}: <>[16,32], {{.*}}: <>[32,16]) -> <>[16,16] { ; CHECK: {{.*}} = {{.*}} custom-call({{.*}}, {{.*}}) ; CHECK-DAG: custom_call_target="<>" )"); } TEST_P(ParameterizedFp8GemmRewriteTest, DoNotRewriteOnPreAdaWithF32Output) { - if (CudaOrRocmCheck(8, 9, Switch::False)) { - GTEST_SKIP() << "Test requires a pre-Ada GPU."; + if (CudaOrRocmCheck(8, 9, [](const se::RocmComputeCapability& cc) { + return cc.has_fp8_support(); + })) { + GTEST_SKIP() << "Test requires a pre-Ada GPU or an AMD GPU prior to MI300."; } const char* hlo_text = R"( HloModule test ENTRY PreAdaTest { - x = f8e4m3fn[16,32] parameter(0) - y = f8e4m3fn[32,16] parameter(1) + x = <>[16,32] parameter(0) + y = <>[32,16] parameter(1) ROOT out = f32[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} } )"; - 
EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-2, 1e-2})); + EXPECT_TRUE(RunAndCompare(absl::StrReplaceAll(hlo_text, replacements_), + ErrorSpec{1e-2, 1e-2})); MatchOptimizedHlo(hlo_text, R"( -; CHECK-LABEL: ENTRY %PreAdaTest ({{.*}}: f8e4m3fn[16,32], {{.*}}: f8e4m3fn[32,16]) -> f32[16,16] { +; CHECK-LABEL: ENTRY %PreAdaTest ({{.*}}: <>[16,32], {{.*}}: <>[32,16]) -> f32[16,16] { ; CHECK: {{.*}} = {{.*}} custom-call({{.*}}, {{.*}}) ; CHECK-DAG: custom_call_target="<>" )"); @@ -4626,28 +4685,33 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnsupportedTypesF8) { GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; #endif // CUDA_VERSION < 12000 +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + // Test with types unsupported by cuBLAS LT when FP8 is used. cuBLAS LT with // FP8 requires one of the operands to be F8E4M3FN. const char* hlo_text = R"( HloModule test ENTRY unsupported_types { - x = f8e5m2[16,16] parameter(0) - y = f8e5m2[16,16] parameter(1) - ROOT out = f8e5m2[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + x = <>[16,16] parameter(0) + y = <>[16,16] parameter(1) + ROOT out = <>[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} } )"; - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-2, 1e-2})); + EXPECT_TRUE(RunAndCompare(absl::StrReplaceAll(hlo_text, replacements_), + ErrorSpec{1e-2, 1e-2})); RunAndFilecheckHloRewrite(hlo_text, GemmRewriter(GpuComputeComp(), /*f8_rewrite=*/true), R"( -; CHECK-LABEL: ENTRY %unsupported_types ({{.*}}: f8e5m2[16,16], {{.*}}: f8e5m2[16,16]) -> f8e5m2[16,16] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f8e5m2[16,16]{1,0} parameter(0) +; CHECK-LABEL: ENTRY %unsupported_types ({{.*}}: <>[16,16], {{.*}}: <>[16,16]) -> <>[16,16] { +; CHECK-NEXT: [[P0:%[^ ]+]] = <>[16,16]{1,0} parameter(0) ; CHECK-NEXT: [[P0_CONVERT:%[^ ]+]] = f16[16,16]{1,0} convert([[P0]]) -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e5m2[16,16]{1,0} parameter(1) +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[16,16]{1,0} parameter(1) ; CHECK-NEXT: [[P1_CONVERT:%[^ ]+]] = f16[16,16]{1,0} convert([[P1]]) ; CHECK-NEXT: [[DOT:%[^ ]+]] = f16[16,16]{1,0} dot([[P0_CONVERT]], [[P1_CONVERT]]), lhs_contracting_dims={1}, rhs_contracting_dims={0} -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f8e5m2[16,16]{1,0} convert([[DOT]]) +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = <>[16,16]{1,0} convert([[DOT]]) )"); } @@ -4655,27 +4719,33 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDF8) { #if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; #endif // CUDA_VERSION < 12000 + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[16,32] parameter(0) - y = f8e4m3fn[32,16] parameter(1) - ROOT out = f8e4m3fn[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + x = <>[16,32] parameter(0) + y = <>[32,16] parameter(1) + ROOT out = <>[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} } )"; CheckFp8IfSupported(hlo_text); RunAndFilecheckHloRewrite( - hlo_text, GemmRewriter(CudaHopperOrRocm(), /*f8_rewrite=*/true), + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), R"( -; CHECK-LABEL: ENTRY %test ({{.*}}: f8e4m3fn[16,32], {{.*}}: f8e4m3fn[32,16]) -> f8e4m3fn[16,16] { -; CHECK-NEXT: 
[[P0:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[32,16]{1,0} parameter(1) -; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[16,32]{1,0} transpose([[P1]]), dimensions={1,0} +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16]) -> <>[16,16] { +; CHECK-NEXT: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[32,16]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[C1:[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f8e4m3fn[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]], [[C1]], /*index=5*/[[C1]]), +; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = f32[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]], [[C1]], /*index=5*/[[C1]]), +; CHECK-PTX-NEXT: ROOT [[OUT:%[^ ]+]] = <>[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]], [[C1]], /*index=5*/[[C1]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -4699,12 +4769,17 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8) { #if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; #endif // CUDA_VERSION < 12000 + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[16,32] parameter(0) - y = f8e4m3fn[32,16] parameter(1) + x = <>[16,32] parameter(0) + y = <>[32,16] parameter(1) x_f32 = f32[16,32] convert(x) y_f32 = f32[32,16] convert(y) x_scale = f32[] parameter(2) @@ -4720,12 +4795,12 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8) { CheckFp8IfSupported(hlo_text); RunAndFilecheckHloRewrite( - hlo_text, GemmRewriter(CudaHopperOrRocm(), /*f8_rewrite=*/true), + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), R"( -; CHECK-LABEL: ENTRY %test ({{.*}}: f8e4m3fn[16,32], {{.*}}: f8e4m3fn[32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[32,16]{1,0} parameter(1) -; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[16,32]{1,0} transpose([[P1]]), dimensions={1,0} +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] { +; CHECK-NEXT: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[32,16]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) ; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) @@ -4753,12 +4828,17 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDPaddedF8) { #if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; #endif // CUDA_VERSION < 12000 + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[13,17] parameter(0) - y = f8e4m3fn[17,31] parameter(1) + x = <>[13,17] parameter(0) + y = <>[17,31] parameter(1) x_f32 = f32[13,17] convert(x) y_f32 = f32[17,31] convert(y) x_scale = f32[] parameter(2) @@ -4774,16 
+4854,16 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDPaddedF8) { CheckFp8IfSupported(hlo_text); RunAndFilecheckHloRewrite( - hlo_text, GemmRewriter(CudaHopperOrRocm(), /*f8_rewrite=*/true), + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), R"( -; CHECK-LABEL: ENTRY %test ({{.*}}: f8e4m3fn[13,17], {{.*}}: f8e4m3fn[17,31], {{.*}}: f32[], {{.*}}: f32[]) -> f32[13,31] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[13,17]{1,0} parameter(0) -; CHECK-NEXT: [[C0:%[^ ]+]] = f8e4m3fn[] constant(0) -; CHECK-NEXT: [[P0_PADDED:%[^ ]+]] = f8e4m3fn[16,32]{1,0} pad([[P0]], [[C0]]), padding=0_3x0_15 -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[17,31]{1,0} parameter(1) -; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[31,17]{1,0} transpose([[P1]]), dimensions={1,0} -; CHECK-NEXT: [[C1:%[^ ]+]] = f8e4m3fn[] constant(0) -; CHECK-NEXT: [[P1_TRANSPOSE_PADDED:%[^ ]+]] = f8e4m3fn[32,32]{1,0} pad([[P1_TRANSPOSE]], [[C1]]) +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[13,17], {{.*}}: <>[17,31], {{.*}}: f32[], {{.*}}: f32[]) -> f32[13,31] { +; CHECK-NEXT: [[P0:%[^ ]+]] = <>[13,17]{1,0} parameter(0) +; CHECK-NEXT: [[C0:%[^ ]+]] = <>[] constant(0) +; CHECK-NEXT: [[P0_PADDED:%[^ ]+]] = <>[16,32]{1,0} pad([[P0]], [[C0]]), padding=0_3x0_15 +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[17,31]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[31,17]{1,0} transpose([[P1]]), dimensions={1,0} +; CHECK-NEXT: [[C1:%[^ ]+]] = <>[] constant(0) +; CHECK-NEXT: [[P1_TRANSPOSE_PADDED:%[^ ]+]] = <>[32,32]{1,0} pad([[P1_TRANSPOSE]], [[C1]]) ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) ; CHECK-NEXT: [[C4:%[^ ]+]] = f32[] constant(1) @@ -4812,12 +4892,17 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDBitcastF8) { #if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; #endif // CUDA_VERSION < 12000 + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[2,8,16] parameter(0) - y = f8e4m3fn[16,16] parameter(1) + x = <>[2,8,16] parameter(0) + y = <>[16,16] parameter(1) x_f32 = f32[2,8,16] convert(x) y_f32 = f32[16,16] convert(y) x_scale = f32[] parameter(2) @@ -4834,7 +4919,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDBitcastF8) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); - GemmRewriter pass(CudaHopperOrRocm(), /*f8_rewrite=*/true); + GemmRewriter pass(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); EXPECT_TRUE(changed); @@ -4848,12 +4933,17 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDUnaryOpsF8) { #if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; #endif // CUDA_VERSION < 12000 + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[3] parameter(0) - y = f8e4m3fn[32,16] parameter(1) + x = <>[3] parameter(0) + y = <>[32,16] parameter(1) x_f32 = f32[3] convert(x) y_f32 = f32[32,16] convert(y) x_scale = f32[] parameter(2) @@ -4871,21 +4961,22 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDUnaryOpsF8) { } )"; + 
CheckFp8IfSupported(hlo_text); RunAndFilecheckHloRewrite( - hlo_text, GemmRewriter(CudaHopperOrRocm(), /*f8_rewrite=*/true), + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), R"( -; CHECK-LABEL: ENTRY %test ({{.*}}: f8e4m3fn[3], {{.*}}: f8e4m3fn[32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[3]{0} parameter(0) +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[3], {{.*}}: <>[32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] { +; CHECK-NEXT: [[P0:%[^ ]+]] = <>[3]{0} parameter(0) ; CHECK-NEXT: [[C0:%[^ ]+]] = f32[] constant(0) -; CHECK-NEXT: [[C0_CONVERT:%[^ ]+]] = f8e4m3fn[] convert([[C0]]) -; CHECK-NEXT: [[P0_U0:%[^ ]+]] = f8e4m3fn[30]{0} pad([[P0]], [[C0_CONVERT]]), padding=0_27 -; CHECK-NEXT: [[P0_U1:%[^ ]+]] = f8e4m3fn[30,8,5]{2,1,0} broadcast([[P0_U0]]), dimensions={0} -; CHECK-NEXT: [[P0_U2:%[^ ]+]] = f8e4m3fn[16,8,4]{2,1,0} slice([[P0_U1]]), slice={[2:18], [0:8], [0:4]} -; CHECK-NEXT: [[P0_U3:%[^ ]+]] = f8e4m3fn[16,32]{1,0} reshape([[P0_U2]]) -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[32,16]{1,0} parameter(1) -; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[16,32]{1,0} transpose([[P1]]), dimensions={1,0} +; CHECK-NEXT: [[C0_CONVERT:%[^ ]+]] = <>[] convert([[C0]]) +; CHECK-NEXT: [[P0_U0:%[^ ]+]] = <>[30]{0} pad([[P0]], [[C0_CONVERT]]), padding=0_27 +; CHECK-NEXT: [[P0_U1:%[^ ]+]] = <>[30,8,5]{2,1,0} broadcast([[P0_U0]]), dimensions={0} +; CHECK-NEXT: [[P0_U2:%[^ ]+]] = <>[16,8,4]{2,1,0} slice([[P0_U1]]), slice={[2:18], [0:8], [0:4]} +; CHECK-NEXT: [[P0_U3:%[^ ]+]] = <>[16,32]{1,0} reshape([[P0_U2]]) +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[32,16]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) ; CHECK-NEXT: [[C2:%[^ ]+]] = f32[] constant(1) @@ -4913,12 +5004,17 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDDynamicSliceF8) { #if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; #endif // CUDA_VERSION < 12000 + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[32,32] parameter(0) - y = f8e4m3fn[16,32] parameter(1) + x = <>[32,32] parameter(0) + y = <>[16,32] parameter(1) zero = s32[] constant(0) x_f32 = f32[32,32] convert(x) y_f32 = f32[16,32] convert(y) @@ -4932,21 +5028,22 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDDynamicSliceF8) { ROOT dot_a = f32[16,16] dot(dyn_slice, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={1} } )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); - GemmRewriter pass(CudaHopperOrRocm(), /*f8_rewrite=*/true); + GemmRewriter pass(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); EXPECT_TRUE(changed); CheckFp8IfSupported(hlo_text); RunAndFilecheckHloRewrite( - hlo_text, GemmRewriter(CudaHopperOrRocm(), /*f8_rewrite=*/true), + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), R"( -; CHECK-LABEL: ENTRY %test ({{.*}}: f8e4m3fn[32,32], {{.*}}: f8e4m3fn[16,32], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[32,32]{1,0} parameter(0) +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[32,32], {{.*}}: <>[16,32], 
{{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] { +; CHECK-NEXT: [[P0:%[^ ]+]] = <>[32,32]{1,0} parameter(0) ; CHECK-NEXT: [[C0:%[^ ]+]] = s32[] constant(0) -; CHECK-NEXT: [[DYN_SLICE:%[^ ]+]] = f8e4m3fn[16,32]{1,0} dynamic-slice([[P0]], [[C0]], [[C0]]), dynamic_slice_sizes={16,32} -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(1) +; CHECK-NEXT: [[DYN_SLICE:%[^ ]+]] = <>[16,32]{1,0} dynamic-slice([[P0]], [[C0]], [[C0]]), dynamic_slice_sizes={16,32} +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[16,32]{1,0} parameter(1) ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) ; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) @@ -4974,12 +5071,17 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDSelectF8) { #if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; #endif // CUDA_VERSION < 12000 + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[16,32] parameter(0) - y = f8e4m3fn[16,32] parameter(1) + x = <>[16,32] parameter(0) + y = <>[16,32] parameter(1) x_f32 = f32[16,32] convert(x) y_f32 = f32[16,32] convert(y) x_scale = f32[] parameter(2) @@ -4995,24 +5097,25 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDSelectF8) { ROOT dot_a = f32[16,16] dot(x_unscaled, select_a), lhs_contracting_dims={1}, rhs_contracting_dims={1} } )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); - GemmRewriter pass(CudaHopperOrRocm(), /*f8_rewrite=*/true); + GemmRewriter pass(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); EXPECT_TRUE(changed); CheckFp8IfSupported(hlo_text); RunAndFilecheckHloRewrite( - hlo_text, GemmRewriter(CudaHopperOrRocm(), /*f8_rewrite=*/true), + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), R"( -; CHECK-LABEL: ENTRY %test ({{.*}}: f8e4m3fn[16,32], {{.*}}: f8e4m3fn[16,32], {{.*}}: f32[], {{.*}}: f32[], {{.*}}: pred[16,32]) -> f32[16,16] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(0) +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[16,32], {{.*}}: f32[], {{.*}}: f32[], {{.*}}: pred[16,32]) -> f32[16,16] { +; CHECK-NEXT: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) ; CHECK-NEXT: [[P4:%[^ ]+]] = pred[16,32]{1,0} parameter(4) -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(1) +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[16,32]{1,0} parameter(1) ; CHECK-NEXT: [[C0:%[^ ]+]] = f32[] constant(0) ; CHECK-NEXT: [[C0_BCAST:%[^ ]+]] = f32[16,32]{1,0} broadcast([[C0]]), dimensions={} -; CHECK-NEXT: [[C0_CONVERT:%[^ ]+]] = f8e4m3fn[16,32]{1,0} convert([[C0_BCAST]]) -; CHECK-NEXT: [[SELECT:%[^ ]+]] = f8e4m3fn[16,32]{1,0} select([[P4]], [[P1]], [[C0_CONVERT]]) +; CHECK-NEXT: [[C0_CONVERT:%[^ ]+]] = <>[16,32]{1,0} convert([[C0_BCAST]]) +; CHECK-NEXT: [[SELECT:%[^ ]+]] = <>[16,32]{1,0} select([[P4]], [[P1]], [[C0_CONVERT]]) ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) ; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) @@ -5041,12 +5144,17 @@ TEST_P(ParameterizedFp8GemmRewriteTest, #if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; #endif // CUDA_VERSION < 12000 + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + 
GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[16,32] parameter(0) - y = f8e4m3fn[16,32] parameter(1) + x = <>[16,32] parameter(0) + y = <>[16,32] parameter(1) x_f32 = f32[16,32] convert(x) y_f32 = f32[16,32] convert(y) x_scale = f32[] parameter(2) @@ -5062,9 +5170,10 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ROOT dot_a = f32[16,16] dot(x_unscaled, select_a), lhs_contracting_dims={1}, rhs_contracting_dims={1} } )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); - GemmRewriter pass(CudaHopperOrRocm(), /*f8_rewrite=*/true); + GemmRewriter pass(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); EXPECT_FALSE(changed); } @@ -5073,12 +5182,17 @@ TEST_P(ParameterizedFp8GemmRewriteTest, BatchedScaledABUnscaledDF8) { #if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; #endif // CUDA_VERSION < 12000 + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[10,16,32] parameter(0) - y = f8e4m3fn[10,32,16] parameter(1) + x = <>[10,16,32] parameter(0) + y = <>[10,32,16] parameter(1) x_f32 = f32[10,16,32] convert(x) y_f32 = f32[10,32,16] convert(y) x_scale = f32[] parameter(2) @@ -5094,12 +5208,12 @@ TEST_P(ParameterizedFp8GemmRewriteTest, BatchedScaledABUnscaledDF8) { CheckFp8IfSupported(hlo_text); RunAndFilecheckHloRewrite( - hlo_text, GemmRewriter(CudaHopperOrRocm(), /*f8_rewrite=*/true), + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), R"( -; CHECK-LABEL: ENTRY %test ({{.*}}: f8e4m3fn[10,16,32], {{.*}}: f8e4m3fn[10,32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[10,16,16] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[10,16,32]{2,1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[10,32,16]{2,1,0} parameter(1) -; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[10,16,32]{2,1,0} transpose([[P1]]), dimensions={0,2,1} +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[10,16,32], {{.*}}: <>[10,32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[10,16,16] { +; CHECK-NEXT: [[P0:%[^ ]+]] = <>[10,16,32]{2,1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[10,32,16]{2,1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[10,16,32]{2,1,0} transpose([[P1]]), dimensions={0,2,1} ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) ; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) @@ -5127,12 +5241,17 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABAlphaDF8) { #if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; #endif // CUDA_VERSION < 12000 + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[16,32] parameter(0) - y = f8e4m3fn[32,16] parameter(1) + x = <>[16,32] parameter(0) + y = <>[32,16] parameter(1) x_f32 = f32[16,32] convert(x) y_f32 = f32[32,16] convert(y) x_scale = f32[] parameter(2) @@ -5151,13 +5270,13 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABAlphaDF8) { CheckFp8IfSupported(hlo_text); 
RunAndFilecheckHloRewrite( - hlo_text, GemmRewriter(CudaHopperOrRocm(), /*f8_rewrite=*/true), + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), R"( -; CHECK-LABEL: ENTRY %test ({{.*}}: f8e4m3fn[16,32], {{.*}}: f8e4m3fn[32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[32,16]{1,0} parameter(1) -; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[16,32]{1,0} transpose([[P1]]), dimensions={1,0} +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] { +; CHECK-NEXT: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[32,16]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) ; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) @@ -5185,12 +5304,17 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDReluActivationF8) { #if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; #endif // CUDA_VERSION < 12000 + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[16,32] parameter(0) - y = f8e4m3fn[32,16] parameter(1) + x = <>[16,32] parameter(0) + y = <>[32,16] parameter(1) x_f32 = f32[16,32] convert(x) y_f32 = f32[32,16] convert(y) x_scale = f32[] parameter(2) @@ -5209,13 +5333,13 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDReluActivationF8) { CheckFp8IfSupported(hlo_text); RunAndFilecheckHloRewrite( - hlo_text, GemmRewriter(CudaHopperOrRocm(), /*f8_rewrite=*/true), + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), R"( -; CHECK-LABEL: ENTRY %test ({{.*}}: f8e4m3fn[16,32], {{.*}}: f8e4m3fn[32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[32,16]{1,0} parameter(1) -; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[16,32]{1,0} transpose([[P1]]), dimensions={1,0} +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] { +; CHECK-NEXT: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[32,16]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) ; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) @@ -5244,11 +5368,16 @@ TEST_P(ParameterizedFp8GemmRewriteTest, #if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; #endif // CUDA_VERSION < 12000 + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[16,32] parameter(0) - y = f8e4m3fn[32,16] parameter(1) + x = <>[16,32] parameter(0) + y = <>[32,16] parameter(1) x_bf16 = bf16[16,32] convert(x) y_bf16 = bf16[32,16] convert(y) x_scale = bf16[] parameter(2) @@ -5283,19 +5412,20 @@ TEST_P(ParameterizedFp8GemmRewriteTest, CheckFp8IfSupported(hlo_text); 
RunAndFilecheckHloRewrite( - hlo_text, GemmRewriter(CudaHopperOrRocm(), /*f8_rewrite=*/true), + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), R"( -; CHECK-LABEL: ENTRY %test ({{.*}}: f8e4m3fn[16,32], {{.*}}: f8e4m3fn[32,16], {{.*}}: bf16[], {{.*}}: bf16[], {{.*}}: bf16[16]) -> bf16[16,16] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[32,16]{1,0} parameter(1) -; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[16,32]{1,0} transpose([[P1]]), dimensions={1,0} +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: bf16[], {{.*}}: bf16[], {{.*}}: bf16[16]) -> bf16[16,16] { +; CHECK-NEXT: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[32,16]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[P2:%[^ ]+]] = bf16[] parameter(2) ; CHECK-NEXT: [[XS:%[^ ]+]] = f32[] convert([[P2]]) ; CHECK-NEXT: [[P3:%[^ ]+]] = bf16[] parameter(3) ; CHECK-NEXT: [[XS1:%[^ ]+]] = f32[] convert([[P3]]) ; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[B:%[^ ]+]] = bf16[16]{0} parameter(4) -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = bf16[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]], [[C1]], /*index=5*/[[C1]], [[B]]), +; CHECK-PTX-NEXT: [[B:%[^ ]+]] = bf16[16]{0} parameter(4) +; CHECK-PTX-NEXT: ROOT [[OUT:%[^ ]+]] = bf16[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]], [[C1]], /*index=5*/[[C1]], [[B]]), +; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = f32[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]], [[C1]], /*index=5*/[[C1]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -5310,7 +5440,8 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ; CHECK-DAG: "precision_config":{ ; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"] ; CHECK-DAG: } -; CHECK-DAG: "epilogue":"BIAS_GELU" +; CHECK-PTX-DAG: "epilogue":"BIAS_GELU" +; CHECK-GCN-DAG: "epilogue":"DEFAULT" ; CHECK: } )"); } @@ -5320,11 +5451,16 @@ TEST_P(ParameterizedFp8GemmRewriteTest, #if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; #endif // CUDA_VERSION < 12000 + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[16,32] parameter(0) - y = f8e4m3fn[32,16] parameter(1) + x = <>[16,32] parameter(0) + y = <>[32,16] parameter(1) x_bf16 = bf16[16,32] convert(x) y_bf16 = bf16[32,16] convert(y) x_scale = bf16[] parameter(2) @@ -5355,19 +5491,22 @@ TEST_P(ParameterizedFp8GemmRewriteTest, )"; CheckFp8IfSupported(hlo_text); + // Currently, hipBlasLt does not support output datatype bf16 for fp8 matmul. + // And no fusion was done for such cases. 
RunAndFilecheckHloRewrite( - hlo_text, GemmRewriter(CudaHopperOrRocm(), /*f8_rewrite=*/true), + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), R"( -; CHECK-LABEL: ENTRY %test ({{.*}}: f8e4m3fn[16,32], {{.*}}: f8e4m3fn[32,16], {{.*}}: bf16[], {{.*}}: bf16[]) -> bf16[16,16] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[32,16]{1,0} parameter(1) -; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[16,32]{1,0} transpose([[P1]]), dimensions={1,0} +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: bf16[], {{.*}}: bf16[]) -> bf16[16,16] { +; CHECK-NEXT: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[32,16]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[P2:%[^ ]+]] = bf16[] parameter(2) ; CHECK-NEXT: [[XS:%[^ ]+]] = f32[] convert([[P2]]) ; CHECK-NEXT: [[P3:%[^ ]+]] = bf16[] parameter(3) ; CHECK-NEXT: [[XS1:%[^ ]+]] = f32[] convert([[P3]]) ; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = bf16[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]], [[C1]], /*index=5*/[[C1]]), +; CHECK-PTX-NEXT: ROOT [[OUT:%[^ ]+]] = bf16[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]], [[C1]], /*index=5*/[[C1]]), +; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = f32[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]], [[C1]], /*index=5*/[[C1]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -5382,7 +5521,8 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ; CHECK-DAG: "precision_config":{ ; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"] ; CHECK-DAG: } -; CHECK-DAG: "epilogue":"GELU" +; CHECK-PTX-DAG: "epilogue":"GELU" +; CHECK-GCN-DAG: "epilogue":"DEFAULT" ; CHECK: } )"); } @@ -5391,12 +5531,17 @@ TEST_P(ParameterizedFp8GemmRewriteTest, InvScaledABUnscaledDF8) { #if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; #endif // CUDA_VERSION < 12000 + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[16,32] parameter(0) - y = f8e4m3fn[32,16] parameter(1) + x = <>[16,32] parameter(0) + y = <>[32,16] parameter(1) x_f32 = f32[16,32] convert(x) y_f32 = f32[32,16] convert(y) x_scale = f32[] parameter(2) @@ -5412,7 +5557,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, InvScaledABUnscaledDF8) { CheckFp8IfSupported(hlo_text); RunAndFilecheckHloRewrite( - hlo_text, GemmRewriter(CudaHopperOrRocm(), /*f8_rewrite=*/true), + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), R"( ; CHECK: custom_call_target="__cublas$lt$matmul$f8", )"); @@ -5422,12 +5567,17 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDMatrixBiasF8) { #if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; #endif // CUDA_VERSION < 12000 + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[16,32] parameter(0) - y = f8e4m3fn[32,16] parameter(1) + x = <>[16,32] parameter(0) + y = <>[32,16] parameter(1) b = f32[16,16] 
parameter(2) x_f32 = f32[16,32] convert(x) y_f32 = f32[32,16] convert(y) @@ -5445,13 +5595,13 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDMatrixBiasF8) { CheckFp8IfSupported(hlo_text); RunAndFilecheckHloRewrite( - hlo_text, GemmRewriter(CudaHopperOrRocm(), /*f8_rewrite=*/true), + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), R"( -; CHECK-LABEL: ENTRY %test ({{.*}}: f8e4m3fn[16,32], {{.*}}: f8e4m3fn[32,16], {{.*}}: f32[16,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[32,16]{1,0} parameter(1) -; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[16,32]{1,0} transpose([[P1]]), dimensions={1,0} +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f32[16,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] { +; CHECK-NEXT: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[32,16]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[C0:%[^ ]+]] = f32[16,16]{1,0} parameter(2) ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(3) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(4) @@ -5480,12 +5630,17 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDMatrixBiasPaddedF8) { #if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; #endif // CUDA_VERSION < 12000 + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[14,31] parameter(0) - y = f8e4m3fn[31,14] parameter(1) + x = <>[14,31] parameter(0) + y = <>[31,14] parameter(1) b = f32[14,14] parameter(2) x_f32 = f32[14,31] convert(x) y_f32 = f32[31,14] convert(y) @@ -5503,17 +5658,17 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDMatrixBiasPaddedF8) { CheckFp8IfSupported(hlo_text); RunAndFilecheckHloRewrite( - hlo_text, GemmRewriter(CudaHopperOrRocm(), /*f8_rewrite=*/true), + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), R"( -; CHECK-LABEL: ENTRY %test ({{.*}}: f8e4m3fn[14,31], {{.*}}: f8e4m3fn[31,14], {{.*}}: f32[14,14], {{.*}}: f32[], {{.*}}: f32[]) -> f32[14,14] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[14,31]{1,0} parameter(0) -; CHECK-NEXT: [[C0:%[^ ]+]] = f8e4m3fn[] constant(0) -; CHECK-NEXT: [[P0_PADDED:%[^ ]+]] = f8e4m3fn[16,32]{1,0} pad([[P0]], [[C0]]), padding=0_2x0_1 -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[31,14]{1,0} parameter(1) -; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[14,31]{1,0} transpose([[P1]]), dimensions={1,0} -; CHECK-NEXT: [[C1:%[^ ]+]] = f8e4m3fn[] constant(0) -; CHECK-NEXT: [[P1_TRANSPOSE_PADDED:%[^ ]+]] = f8e4m3fn[16,32]{1,0} pad([[P1_TRANSPOSE]], [[C1]]), padding=0_2x0_1 +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[14,31], {{.*}}: <>[31,14], {{.*}}: f32[14,14], {{.*}}: f32[], {{.*}}: f32[]) -> f32[14,14] { +; CHECK-NEXT: [[P0:%[^ ]+]] = <>[14,31]{1,0} parameter(0) +; CHECK-NEXT: [[C0:%[^ ]+]] = <>[] constant(0) +; CHECK-NEXT: [[P0_PADDED:%[^ ]+]] = <>[16,32]{1,0} pad([[P0]], [[C0]]), padding=0_2x0_1 +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[31,14]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[14,31]{1,0} transpose([[P1]]), dimensions={1,0} +; CHECK-NEXT: [[C1:%[^ ]+]] = <>[] constant(0) +; CHECK-NEXT: [[P1_TRANSPOSE_PADDED:%[^ ]+]] = <>[16,32]{1,0} pad([[P1_TRANSPOSE]], [[C1]]), 
padding=0_2x0_1 ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[14,14]{1,0} parameter(2) ; CHECK-NEXT: [[C2:%[^ ]+]] = f32[] constant(0) ; CHECK-NEXT: [[P2_PADDED:%[^ ]+]] = f32[16,16]{1,0} pad([[P2]], [[C2]]), padding=0_2x0_2 @@ -5545,12 +5700,17 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABScaledDF8) { #if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; #endif // CUDA_VERSION < 12000 + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[16,32] parameter(0) - y = f8e4m3fn[32,16] parameter(1) + x = <>[16,32] parameter(0) + y = <>[32,16] parameter(1) z_scale = f32[] parameter(2) z_scale_bcast = f32[16,16] broadcast(z_scale), dimensions={} dot_a = f32[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} @@ -5560,24 +5720,25 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABScaledDF8) { c2 = f32[] constant(448.) c2_bcast = f32[16,16] broadcast(c2), dimensions={} dot_a_clamped = f32[16,16] clamp(c1_bcast, dot_a_scaled, c2_bcast) - ROOT dot_a_f8 = f8e4m3fn[16,16] convert(dot_a_clamped) + ROOT dot_a_f8 = <>[16,16] convert(dot_a_clamped) } )"; CheckFp8IfSupported(hlo_text, ErrorSpec{1e-2, 1e-1}); RunAndFilecheckHloRewrite( - hlo_text, GemmRewriter(CudaHopperOrRocm(), /*f8_rewrite=*/true), + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), R"( -; CHECK-LABEL: ENTRY %test ({{.*}}: f8e4m3fn[16,32], {{.*}}: f8e4m3fn[32,16], {{.*}}: f32[]) -> f8e4m3fn[16,16] { -; CHECK: [[P0:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[32,16]{1,0} parameter(1) -; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[16,32]{1,0} transpose([[P1]]), dimensions={1,0} +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f32[]) -> <>[16,16] { +; CHECK: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[32,16]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[C2:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) -; CHECK-NEXT: [[P2_INV:%[^ ]+]] = f32[] divide([[C2]], [[P2]]) -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f8e4m3fn[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]], [[C1]], /*index=5*/[[P2_INV]]), +; CHECK-PTX-NEXT: [[C2:%[^ ]+]] = f32[] constant(1) +; CHECK-PTX-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) +; CHECK-PTX-NEXT: [[P2_INV:%[^ ]+]] = f32[] divide([[C2]], [[P2]]) +; CHECK-PTX-NEXT: ROOT [[OUT:%[^ ]+]] = <>[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]], [[C1]], /*index=5*/[[P2_INV]]), +; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = f32[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]], [[C1]], /*index=5*/[[C1]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -5601,12 +5762,17 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDF8) { #if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; #endif // CUDA_VERSION < 12000 + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test ENTRY test { - x = 
f8e4m3fn[16,32] parameter(0) - y = f8e4m3fn[32,16] parameter(1) + x = <>[16,32] parameter(0) + y = <>[32,16] parameter(1) x_f32 = f32[16,32] convert(x) y_f32 = f32[32,16] convert(y) x_scale = f32[] parameter(2) @@ -5619,31 +5785,32 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDF8) { y_unscaled = f32[32,16] multiply(y_f32, y_scale_bcast) dot_a = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0} dot_a_scaled = f32[16,16] divide(dot_a, z_scale_bcast) - c1 = f32[] constant(-448.) + c1 = f32[] constant(-<>) c1_bcast = f32[16,16] broadcast(c1), dimensions={} - c2 = f32[] constant(448.) + c2 = f32[] constant(<>) c2_bcast = f32[16,16] broadcast(c2), dimensions={} dot_a_clamped = f32[16,16] clamp(c1_bcast, dot_a_scaled, c2_bcast) - ROOT dot_a_f8 = f8e4m3fn[16,16] convert(dot_a_clamped) + ROOT dot_a_f8 = <>[16,16] convert(dot_a_clamped) } )"; CheckFp8IfSupported(hlo_text); RunAndFilecheckHloRewrite( - hlo_text, GemmRewriter(CudaHopperOrRocm(), /*f8_rewrite=*/true), + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), R"( -; CHECK-LABEL: ENTRY %test ({{.*}}: f8e4m3fn[16,32], {{.*}}: f8e4m3fn[32,16], {{.*}}: f32[], {{.*}}: f32[], {{.*}}: f32[]) -> f8e4m3fn[16,16] { -; CHECK: [[P0:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[32,16]{1,0} parameter(1) -; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[16,32]{1,0} transpose([[P1]]), dimensions={1,0} +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f32[], {{.*}}: f32[], {{.*}}: f32[]) -> <>[16,16] { +; CHECK: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[32,16]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) ; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[C2:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[P4:%[^ ]+]] = f32[] parameter(4) -; CHECK-NEXT: [[P4_INV:%[^ ]+]] = f32[] divide([[C2]], [[P4]]) -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f8e4m3fn[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[P4_INV]]), +; CHECK-PTX-NEXT: [[C2:%[^ ]+]] = f32[] constant(1) +; CHECK-PTX-NEXT: [[P4:%[^ ]+]] = f32[] parameter(4) +; CHECK-PTX-NEXT: [[P4_INV:%[^ ]+]] = f32[] divide([[C2]], [[P4]]) +; CHECK-PTX-NEXT: ROOT [[OUT:%[^ ]+]] = <>[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[P4_INV]]), +; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = f32[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -5667,12 +5834,17 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABInvScaledDF8) { #if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; #endif // CUDA_VERSION < 12000 + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[16,32] parameter(0) - y = f8e4m3fn[32,16] parameter(1) + x = <>[16,32] parameter(0) + y = <>[32,16] parameter(1) x_f32 = f32[16,32] convert(x) y_f32 = f32[32,16] convert(y) x_scale = f32[] parameter(2) @@ -5685,19 +5857,19 @@ TEST_P(ParameterizedFp8GemmRewriteTest, 
ScaledABInvScaledDF8) { y_unscaled = f32[32,16] multiply(y_f32, y_scale_bcast) dot_a = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0} dot_a_scaled = f32[16,16] multiply(dot_a, z_scale_bcast) - c1 = f32[] constant(-448.) + c1 = f32[] constant(-<>) c1_bcast = f32[16,16] broadcast(c1), dimensions={} - c2 = f32[] constant(448.) + c2 = f32[] constant(<>) c2_bcast = f32[16,16] broadcast(c2), dimensions={} dot_a_clamped = f32[16,16] clamp(c1_bcast, dot_a_scaled, c2_bcast) - ROOT dot_a_f8 = f8e4m3fn[16,16] convert(dot_a_clamped) + ROOT dot_a_f8 = <>[16,16] convert(dot_a_clamped) } )"; CheckFp8IfSupported(hlo_text); RunAndFilecheckHloRewrite( - hlo_text, GemmRewriter(CudaHopperOrRocm(), /*f8_rewrite=*/true), + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), R"( ; CHECK-NOT: divide @@ -5711,11 +5883,16 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDReluActivationF8) { #if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; #endif // CUDA_VERSION < 12000 + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[16,32] parameter(0) - y = f8e4m3fn[32,16] parameter(1) + x = <>[16,32] parameter(0) + y = <>[32,16] parameter(1) x_f32 = f32[16,32] convert(x) y_f32 = f32[32,16] convert(y) x_scale = f32[] parameter(2) @@ -5731,30 +5908,31 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDReluActivationF8) { dot_a = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0} relu_a = f32[16,16] maximum(dot_a, c_bcast) relu_a_scaled = f32[16,16] divide(relu_a, z_scale_bcast) - c1 = f32[] constant(-448.) + c1 = f32[] constant(-<>) c1_bcast = f32[16,16] broadcast(c1), dimensions={} - c2 = f32[] constant(448.) 
+ c2 = f32[] constant(<>) c2_bcast = f32[16,16] broadcast(c2), dimensions={} relu_a_clamped = f32[16,16] clamp(c1_bcast, relu_a_scaled, c2_bcast) - ROOT out = f8e4m3fn[16,16] convert(relu_a_clamped) + ROOT out = <>[16,16] convert(relu_a_clamped) } )"; CheckFp8IfSupported(hlo_text); RunAndFilecheckHloRewrite( - hlo_text, GemmRewriter(CudaHopperOrRocm(), /*f8_rewrite=*/true), + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), R"( -; CHECK-LABEL: ENTRY %test ({{.*}}: f8e4m3fn[16,32], {{.*}}: f8e4m3fn[32,16], {{.*}}: f32[], {{.*}}: f32[], {{.*}}: f32[]) -> f8e4m3fn[16,16] { -; CHECK: [[P0:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[32,16]{1,0} parameter(1) -; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[16,32]{1,0} transpose([[P1]]), dimensions={1,0} +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f32[], {{.*}}: f32[], {{.*}}: f32[]) -> <>[16,16] { +; CHECK: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[32,16]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) ; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[C2:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[P4:%[^ ]+]] = f32[] parameter(4) -; CHECK-NEXT: [[P4_INV:%[^ ]+]] = f32[] divide([[C2]], [[P4]]) -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f8e4m3fn[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[P4_INV]]), +; CHECK-PTX-NEXT: [[C2:%[^ ]+]] = f32[] constant(1) +; CHECK-PTX-NEXT: [[P4:%[^ ]+]] = f32[] parameter(4) +; CHECK-PTX-NEXT: [[P4_INV:%[^ ]+]] = f32[] divide([[C2]], [[P4]]) +; CHECK-PTX-NEXT: ROOT [[OUT:%[^ ]+]] = <>[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[P4_INV]]), +; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = f32[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -5778,12 +5956,17 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDMatrixBiasF8) { #if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; #endif // CUDA_VERSION < 12000 + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[16,32] parameter(0) - y = f8e4m3fn[32,16] parameter(1) + x = <>[16,32] parameter(0) + y = <>[32,16] parameter(1) x_f16 = f16[16,32] convert(x) y_f16 = f16[32,16] convert(y) b = f16[16,16] parameter(2) @@ -5798,31 +5981,32 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDMatrixBiasF8) { dot_a = f16[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0} dot_a_bias = f16[16,16] add(dot_a, b) dot_a_scaled = f16[16,16] divide(dot_a_bias, z_scale_bcast) - c1 = f16[] constant(-448.) + c1 = f16[] constant(-<>) c1_bcast = f16[16,16] broadcast(c1), dimensions={} - c2 = f16[] constant(448.) 
+ c2 = f16[] constant(<>) c2_bcast = f16[16,16] broadcast(c2), dimensions={} dot_a_clamped = f16[16,16] clamp(c1_bcast, dot_a_scaled, c2_bcast) - ROOT dot_a_f8 = f8e4m3fn[16,16] convert(dot_a_clamped) + ROOT dot_a_f8 = <>[16,16] convert(dot_a_clamped) } )"; CheckFp8IfSupported(hlo_text, ErrorSpec{0.1, 0.1}); RunAndFilecheckHloRewrite( - hlo_text, GemmRewriter(CudaHopperOrRocm(), /*f8_rewrite=*/true), + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), R"( -; CHECK-LABEL: ENTRY %test ({{.*}}: f8e4m3fn[16,32], {{.*}}: f8e4m3fn[32,16], {{.*}}: f16[16,16], {{.*}}: f16[], {{.*}}: f16[], {{.*}}: f16[]) -> f8e4m3fn[16,16] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[32,16]{1,0} parameter(1) -; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[16,32]{1,0} transpose([[P1]]), dimensions={1,0} +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f16[16,16], {{.*}}: f16[], {{.*}}: f16[], {{.*}}: f16[]) -> <>[16,16] { +; CHECK: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[32,16]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[C0:%[^ ]+]] = f16[16,16]{1,0} parameter(2) ; CHECK-NEXT: [[P2:%[^ ]+]] = f16[] parameter(3) ; CHECK: [[P3:%[^ ]+]] = f16[] parameter(4) ; CHECK: [[C1:%[^ ]+]] = f32[] constant(1) -; CHECK: [[P4:%[^ ]+]] = f16[] parameter(5) -; CHECK: ROOT [[OUT:%[^ ]+]] = f8e4m3fn[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[C0]], [[DUMMY0:%[^ ]+]], [[DUMMY1:%[^ ]+]], /*index=5*/[[C1]], [[DUMMY2:%[^ ]+]]), +; CHECK-PTX: [[P4:%[^ ]+]] = f16[] parameter(5) +; CHECK-PTX: ROOT [[OUT:%[^ ]+]] = <>[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[C0]], [[DUMMY0:%[^ ]+]], [[DUMMY1:%[^ ]+]], /*index=5*/[[C1]], [[DUMMY2:%[^ ]+]]), +; CHECK-GCN: [[OUT:%[^ ]+]] = f16[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[C0]], [[DUMMY0:%[^ ]+]], [[DUMMY1:%[^ ]+]], /*index=5*/[[C1]], [[DUMMY2:%[^ ]+]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -5846,12 +6030,17 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDVectorBiasF8) { #if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; #endif // CUDA_VERSION < 12000 + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[16,32] parameter(0) - y = f8e4m3fn[32,16] parameter(1) + x = <>[16,32] parameter(0) + y = <>[32,16] parameter(1) x_f16 = f16[16,32] convert(x) y_f16 = f16[32,16] convert(y) b = f16[16] parameter(2) @@ -5867,36 +6056,37 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDVectorBiasF8) { dot_a = f16[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0} dot_a_bias = f16[16,16] add(dot_a, b_bcast) dot_a_scaled = f16[16,16] divide(dot_a_bias, z_scale_bcast) - c1 = f16[] constant(-448.) + c1 = f16[] constant(-<>) c1_bcast = f16[16,16] broadcast(c1), dimensions={} - c2 = f16[] constant(448.) 
+ c2 = f16[] constant(<>) c2_bcast = f16[16,16] broadcast(c2), dimensions={} dot_a_clamped = f16[16,16] clamp(c1_bcast, dot_a_scaled, c2_bcast) - ROOT dot_a_f8 = f8e4m3fn[16,16] convert(dot_a_clamped) + ROOT dot_a_f8 = <>[16,16] convert(dot_a_clamped) } )"; CheckFp8IfSupported(hlo_text, ErrorSpec{0.1, 0.1}); RunAndFilecheckHloRewrite( - hlo_text, GemmRewriter(CudaHopperOrRocm(), /*f8_rewrite=*/true), + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), R"( -; CHECK-LABEL: ENTRY %test ({{.*}}: f8e4m3fn[16,32], {{.*}}: f8e4m3fn[32,16], {{.*}}: f16[16], {{.*}}: f16[], {{.*}}: f16[], {{.*}}: f16[]) -> f8e4m3fn[16,16] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[32,16]{1,0} parameter(1) -; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[16,32]{1,0} transpose([[P1]]), dimensions={1,0} +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f16[16], {{.*}}: f16[], {{.*}}: f16[], {{.*}}: f16[]) -> <>[16,16] { +; CHECK: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[32,16]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[P2:%[^ ]+]] = f16[] parameter(3) ; CHECK-NEXT: [[CV:%[^ ]+]] = f32[] convert([[P2]]) ; CHECK-NEXT: [[P3:%[^ ]+]] = f16[] parameter(4) ; CHECK-NEXT: [[CV1:%[^ ]+]] = f32[] convert([[P3]]) ; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[C2:%[^ ]+]] = f16[] constant(1) -; CHECK-NEXT: [[P4:%[^ ]+]] = f16[] parameter(5) -; CHECK-NEXT: [[DV:%[^ ]+]] = f16[] divide([[C2]], [[P4]]) -; CHECK-NEXT: [[CV2:%[^ ]+]] = f32[] convert([[DV]]) +; CHECK-PTX-NEXT: [[C2:%[^ ]+]] = f16[] constant(1) +; CHECK-PTX-NEXT: [[P4:%[^ ]+]] = f16[] parameter(5) +; CHECK-PTX-NEXT: [[DV:%[^ ]+]] = f16[] divide([[C2]], [[P4]]) +; CHECK-PTX-NEXT: [[CV2:%[^ ]+]] = f32[] convert([[DV]]) ; CHECK-NEXT: [[VB:%[^ ]+]] = f16[16]{0} parameter(2) -; CHECK: ROOT [[OUT:%[^ ]+]] = f8e4m3fn[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[CV]], [[CV1]], [[C]], /*index=5*/[[CV2]], [[VB]]), +; CHECK-PTX: ROOT [[OUT:%[^ ]+]] = <>[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[CV]], [[CV1]], [[C]], /*index=5*/[[CV2]], [[VB]]), +; CHECK-GCN: [[OUT:%[^ ]+]] = f16[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[CV]], [[CV1]], [[C]], /*index=5*/[[C]], [[VB]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -5920,12 +6110,17 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF32VectorBiasF8) { #if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; #endif // CUDA_VERSION < 12000 + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[16,32] parameter(0) - y = f8e4m3fn[32,16] parameter(1) + x = <>[16,32] parameter(0) + y = <>[32,16] parameter(1) x_f32 = f32[16,32] convert(x) y_f32 = f32[32,16] convert(y) b = f32[16] parameter(2) @@ -5946,12 +6141,12 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF32VectorBiasF8) { CheckFp8IfSupported(hlo_text); RunAndFilecheckHloRewrite( - hlo_text, GemmRewriter(CudaHopperOrRocm(), /*f8_rewrite=*/true), + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), R"( -; CHECK-LABEL: ENTRY %test ({{.*}}: f8e4m3fn[16,32], {{.*}}: 
f8e4m3fn[32,16], {{.*}}: f32[16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] { -; CHECK: [[P0:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[32,16]{1,0} parameter(1) -; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[16,32]{1,0} transpose([[P1]]), dimensions={1,0} +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f32[16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] { +; CHECK: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[32,16]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(3) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(4) ; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1) @@ -5982,12 +6177,17 @@ TEST_P(ParameterizedFp8GemmRewriteTest, #if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; #endif // CUDA_VERSION < 12000 + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[16,32] parameter(0) - y = f8e4m3fn[32,16] parameter(1) + x = <>[16,32] parameter(0) + y = <>[32,16] parameter(1) b = f16[16] parameter(2) b_bcast = f16[16,16] broadcast(b), dimensions={1} x_f32 = f16[16,32] convert(x) @@ -6008,12 +6208,12 @@ TEST_P(ParameterizedFp8GemmRewriteTest, CheckFp8IfSupported(hlo_text, ErrorSpec{2e-3, 0.}); RunAndFilecheckHloRewrite( - hlo_text, GemmRewriter(CudaHopperOrRocm(), /*f8_rewrite=*/true), + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), R"( -; CHECK-LABEL: ENTRY %test ({{.*}}: f8e4m3fn[16,32], {{.*}}: f8e4m3fn[32,16], {{.*}}: f16[16], {{.*}}: f16[], {{.*}}: f16[]) -> f16[16,16] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[32,16]{1,0} parameter(1) -; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[16,32]{1,0} transpose([[P1]]), dimensions={1,0} +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f16[16], {{.*}}: f16[], {{.*}}: f16[]) -> f16[16,16] { +; CHECK-NEXT: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[32,16]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[P2:%[^ ]+]] = f16[] parameter(3) ; CHECK-NEXT: [[CV:%[^ ]+]] = f32[] convert([[P2]]) ; CHECK-NEXT: [[P3:%[^ ]+]] = f16[] parameter(4) @@ -6044,11 +6244,16 @@ TEST_P(ParameterizedFp8GemmRewriteTest, Rank3ScaledABUnscaledDVectorBiasF8) { #if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "A matrix bias on a matmul is only supported in CUDA 12"; #endif // CUDA_VERSION < 12000 + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[4,16,16] parameter(0) - y = f8e4m3fn[16,32] parameter(1) + x = <>[4,16,16] parameter(0) + y = <>[16,32] parameter(1) b = f32[32] parameter(2) b_f16 = f16[32] convert(b) b_bcast = f16[4,16,32] broadcast(b_f16), dimensions={2} @@ -6069,7 +6274,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, Rank3ScaledABUnscaledDVectorBiasF8) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); - GemmRewriter pass(CudaHopperOrRocm(), 
/*f8_rewrite=*/true); + GemmRewriter pass(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); EXPECT_TRUE(changed); @@ -6079,13 +6284,13 @@ TEST_P(ParameterizedFp8GemmRewriteTest, Rank3ScaledABUnscaledDVectorBiasF8) { .WithShape(F16, {4, 16, 32}))); RunAndFilecheckHloRewrite( - hlo_text, GemmRewriter(CudaHopperOrRocm(), /*f8_rewrite=*/true), + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), R"( -; CHECK-LABEL: ENTRY %test ({{.*}}: f8e4m3fn[4,16,16], {{.*}}: f8e4m3fn[16,32], {{.*}}: f32[32], {{.*}}: f16[], {{.*}}: f16[]) -> f16[4,16,32] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[4,16,16]{2,1,0} parameter(0) -; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f8e4m3fn[64,16]{1,0} bitcast([[P0]]) -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(1) -; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[32,16]{1,0} transpose([[P1]]), dimensions={1,0} +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[4,16,16], {{.*}}: <>[16,32], {{.*}}: f32[32], {{.*}}: f16[], {{.*}}: f16[]) -> f16[4,16,32] { +; CHECK-NEXT: [[P0:%[^ ]+]] = <>[4,16,16]{2,1,0} parameter(0) +; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = <>[64,16]{1,0} bitcast([[P0]]) +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[16,32]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[32,16]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[P2:%[^ ]+]] = f16[] parameter(3) ; CHECK-NEXT: [[P2_CV:%[^ ]+]] = f32[] convert([[P2]]) ; CHECK-NEXT: [[P3:%[^ ]+]] = f16[] parameter(4) @@ -6119,11 +6324,16 @@ TEST_P(ParameterizedFp8GemmRewriteTest, #if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "A matrix bias on a matmul is only supported in CUDA 12"; #endif + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[4,15,15] parameter(0) - y = f8e4m3fn[15,31] parameter(1) + x = <>[4,15,15] parameter(0) + y = <>[15,31] parameter(1) b = f32[31] parameter(2) b_f16 = f16[31] convert(b) b_bcast = f16[4,15,31] broadcast(b_f16), dimensions={2} @@ -6144,7 +6354,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); - GemmRewriter pass(CudaHopperOrRocm(), /*f8_rewrite=*/true); + GemmRewriter pass(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); EXPECT_TRUE(changed); @@ -6156,17 +6366,17 @@ TEST_P(ParameterizedFp8GemmRewriteTest, .WithShape(F16, {4, 15, 31}))); RunAndFilecheckHloRewrite( - hlo_text, GemmRewriter(CudaHopperOrRocm(), /*f8_rewrite=*/true), + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), R"( -; CHECK-LABEL: ENTRY %test ({{.*}}: f8e4m3fn[4,15,15], {{.*}}: f8e4m3fn[15,31], {{.*}}: f32[31], {{.*}}: f16[], {{.*}}: f16[]) -> f16[4,15,31] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[4,15,15]{2,1,0} parameter(0) -; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f8e4m3fn[60,15]{1,0} bitcast([[P0]]) -; CHECK-NEXT: [[C1:%[^ ]+]] = f8e4m3fn[] constant(0) -; CHECK-NEXT: [[P0_PAD:%[^ ]+]] = f8e4m3fn[64,16]{1,0} pad([[P0_BITCAST]], [[C1]]), padding=0_4x0_1 -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[15,31]{1,0} parameter(1) -; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[31,15]{1,0} transpose([[P1]]), dimensions={1,0} -; CHECK-NEXT: [[C2:%[^ ]+]] = f8e4m3fn[] constant(0) -; CHECK-NEXT: [[P1_PAD:%[^ ]+]] = f8e4m3fn[32,16]{1,0} pad([[P1_TRANSPOSE]], 
[[C2]]), padding=0_1x0_1 +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[4,15,15], {{.*}}: <>[15,31], {{.*}}: f32[31], {{.*}}: f16[], {{.*}}: f16[]) -> f16[4,15,31] { +; CHECK-NEXT: [[P0:%[^ ]+]] = <>[4,15,15]{2,1,0} parameter(0) +; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = <>[60,15]{1,0} bitcast([[P0]]) +; CHECK-NEXT: [[C1:%[^ ]+]] = <>[] constant(0) +; CHECK-NEXT: [[P0_PAD:%[^ ]+]] = <>[64,16]{1,0} pad([[P0_BITCAST]], [[C1]]), padding=0_4x0_1 +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[15,31]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[31,15]{1,0} transpose([[P1]]), dimensions={1,0} +; CHECK-NEXT: [[C2:%[^ ]+]] = <>[] constant(0) +; CHECK-NEXT: [[P1_PAD:%[^ ]+]] = <>[32,16]{1,0} pad([[P1_TRANSPOSE]], [[C2]]), padding=0_1x0_1 ; CHECK-NEXT: [[P2:%[^ ]+]] = f16[] parameter(3) ; CHECK-NEXT: [[P2_CV:%[^ ]+]] = f32[] convert([[P2]]) ; CHECK-NEXT: [[P3:%[^ ]+]] = f16[] parameter(4) @@ -6202,11 +6412,16 @@ TEST_P(ParameterizedFp8GemmRewriteTest, Rank3ScaledABUnscaledDMatrixBiasF8) { #if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "A matrix bias on a matmul is only supported in CUDA 12"; #endif + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[4,16,16] parameter(0) - y = f8e4m3fn[16,32] parameter(1) + x = <>[4,16,16] parameter(0) + y = <>[16,32] parameter(1) b = f32[4,16,32] parameter(2) x_f32 = f32[4,16,16] convert(x) y_f32 = f32[16,32] convert(y) @@ -6225,7 +6440,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, Rank3ScaledABUnscaledDMatrixBiasF8) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); - GemmRewriter pass(CudaHopperOrRocm(), /*f8_rewrite=*/true); + GemmRewriter pass(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); EXPECT_TRUE(changed); @@ -6235,13 +6450,13 @@ TEST_P(ParameterizedFp8GemmRewriteTest, Rank3ScaledABUnscaledDMatrixBiasF8) { .WithShape(F32, {4, 16, 32}))); RunAndFilecheckHloRewrite( - hlo_text, GemmRewriter(CudaHopperOrRocm(), /*f8_rewrite=*/true), + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), R"( -; CHECK-LABEL: ENTRY %test ({{.*}}: f8e4m3fn[4,16,16], {{.*}}: f8e4m3fn[16,32], {{.*}}: f32[4,16,32], {{.*}}: f32[], {{.*}}: f32[]) -> f32[4,16,32] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[4,16,16]{2,1,0} parameter(0) -; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f8e4m3fn[64,16]{1,0} bitcast([[P0]]) -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(1) -; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[32,16]{1,0} transpose([[P1]]), dimensions={1,0} +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[4,16,16], {{.*}}: <>[16,32], {{.*}}: f32[4,16,32], {{.*}}: f32[], {{.*}}: f32[]) -> f32[4,16,32] { +; CHECK-NEXT: [[P0:%[^ ]+]] = <>[4,16,16]{2,1,0} parameter(0) +; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = <>[64,16]{1,0} bitcast([[P0]]) +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[16,32]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[32,16]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[B:%[^ ]+]] = f32[4,16,32]{2,1,0} parameter(2) ; CHECK-NEXT: [[B_BITCAST:%[^ ]+]] = f32[64,32]{1,0} bitcast([[B]]) ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(3) @@ -6273,11 +6488,16 @@ TEST_P(ParameterizedFp8GemmRewriteTest, #if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "A matrix bias on a matmul is only supported in CUDA 12"; #endif + +#if TENSORFLOW_USE_ROCM && 
TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[3,15,15] parameter(0) - y = f8e4m3fn[15,31] parameter(1) + x = <>[3,15,15] parameter(0) + y = <>[15,31] parameter(1) b = f32[3,15,31] parameter(2) x_f32 = f32[3,15,15] convert(x) y_f32 = f32[15,31] convert(y) @@ -6296,7 +6516,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); - GemmRewriter pass(CudaHopperOrRocm(), /*f8_rewrite=*/true); + GemmRewriter pass(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); EXPECT_TRUE(changed); @@ -6308,17 +6528,17 @@ TEST_P(ParameterizedFp8GemmRewriteTest, .WithShape(F32, {3, 15, 31}))); RunAndFilecheckHloRewrite( - hlo_text, GemmRewriter(CudaHopperOrRocm(), /*f8_rewrite=*/true), + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), R"( -; CHECK-LABEL: ENTRY %test ({{.*}}: f8e4m3fn[3,15,15], {{.*}}: f8e4m3fn[15,31], {{.*}}: f32[3,15,31], {{.*}}: f32[], {{.*}}: f32[]) -> f32[3,15,31] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[3,15,15]{2,1,0} parameter(0) -; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f8e4m3fn[45,15]{1,0} bitcast([[P0]]) -; CHECK-NEXT: [[C1:%[^ ]+]] = f8e4m3fn[] constant(0) -; CHECK-NEXT: [[P0_PADDED:%[^ ]+]] = f8e4m3fn[48,16]{1,0} pad([[P0_BITCAST]], [[C1]]), padding=0_3x0_1 -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[15,31]{1,0} parameter(1) -; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[31,15]{1,0} transpose([[P1]]), dimensions={1,0} -; CHECK-NEXT: [[C2:%[^ ]+]] = f8e4m3fn[] constant(0) -; CHECK-NEXT: [[P1_PADDED:%[^ ]+]] = f8e4m3fn[32,16]{1,0} pad([[P1_TRANSPOSE]], [[C2]]), padding=0_1x0_1 +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[3,15,15], {{.*}}: <>[15,31], {{.*}}: f32[3,15,31], {{.*}}: f32[], {{.*}}: f32[]) -> f32[3,15,31] { +; CHECK-NEXT: [[P0:%[^ ]+]] = <>[3,15,15]{2,1,0} parameter(0) +; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = <>[45,15]{1,0} bitcast([[P0]]) +; CHECK-NEXT: [[C1:%[^ ]+]] = <>[] constant(0) +; CHECK-NEXT: [[P0_PADDED:%[^ ]+]] = <>[48,16]{1,0} pad([[P0_BITCAST]], [[C1]]), padding=0_3x0_1 +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[15,31]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[31,15]{1,0} transpose([[P1]]), dimensions={1,0} +; CHECK-NEXT: [[C2:%[^ ]+]] = <>[] constant(0) +; CHECK-NEXT: [[P1_PADDED:%[^ ]+]] = <>[32,16]{1,0} pad([[P1_TRANSPOSE]], [[C2]]), padding=0_1x0_1 ; CHECK-NEXT: [[B:%[^ ]+]] = f32[3,15,31]{2,1,0} parameter(2) ; CHECK-NEXT: [[B_BITCAST:%[^ ]+]] = f32[45,31]{1,0} bitcast([[B]]) ; CHECK-NEXT: [[C3:%[^ ]+]] = f32[] constant(0) @@ -6355,11 +6575,16 @@ TEST_P(ParameterizedFp8GemmRewriteTest, #if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "A matrix bias on a matmul is only supported in CUDA 12"; #endif + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[48,16] parameter(0) - y = f8e4m3fn[16,32] parameter(1) + x = <>[48,16] parameter(0) + y = <>[16,32] parameter(1) b = f32[32,16] parameter(2) x_f32 = f32[48,16] convert(x) y_f32 = f32[16,32] convert(y) @@ -6377,17 +6602,17 @@ TEST_P(ParameterizedFp8GemmRewriteTest, TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); - GemmRewriter pass(CudaHopperOrRocm(), 
/*f8_rewrite=*/true); + GemmRewriter pass(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); EXPECT_TRUE(changed); RunAndFilecheckHloRewrite( - hlo_text, GemmRewriter(CudaHopperOrRocm(), /*f8_rewrite=*/true), + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), R"( -; CHECK-LABEL: ENTRY %test ({{.*}}: f8e4m3fn[48,16], {{.*}}: f8e4m3fn[16,32], {{.*}}: f32[32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[32,16] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[48,16]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(1) -; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[32,16]{1,0} transpose([[P1]]), dimensions={1,0} +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[48,16], {{.*}}: <>[16,32], {{.*}}: f32[32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[32,16] { +; CHECK-NEXT: [[P0:%[^ ]+]] = <>[48,16]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[16,32]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[32,16]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(3) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(4) ; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1) @@ -6418,12 +6643,17 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDWithAllGatherF8) { #if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "A matrix bias on a matmul is only supported in CUDA 12"; #endif + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + absl::string_view hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[16,32] parameter(0) - y = f8e4m3fn[16,32] parameter(1) + x = <>[16,32] parameter(0) + y = <>[16,32] parameter(1) x_f32 = f32[16,32] convert(x) y_f32 = f32[16,32] convert(y) x_scale = f32[] parameter(2) @@ -6443,14 +6673,14 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDWithAllGatherF8) { config.set_num_partitions(8); RunAndFilecheckHloRewrite( - hlo_text, GemmRewriter(CudaHopperOrRocm(), /*f8_rewrite=*/true), + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), R"( -; CHECK-LABEL: ENTRY %test ({{.*}}: f8e4m3fn[16,32], {{.*}}: f8e4m3fn[16,32], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,32] { -; CHECK: [[P0:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(0) -; CHECK: [[AG:%[^ ]+]] = f8e4m3fn[16,64]{1,0} all-gather([[P0]]), {{[^ ]+}} -; CHECK: [[P1:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(1) -; CHECK: [[AG1:%[^ ]+]] = f8e4m3fn[64,32]{1,0} all-gather([[P1]]), {{[^ ]+}} -; CHECK: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[32,64]{1,0} transpose([[AG1]]), dimensions={1,0} +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[16,32], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,32] { +; CHECK: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) +; CHECK: [[AG:%[^ ]+]] = <>[16,64]{1,0} all-gather([[P0]]), {{[^ ]+}} +; CHECK: [[P1:%[^ ]+]] = <>[16,32]{1,0} parameter(1) +; CHECK: [[AG1:%[^ ]+]] = <>[64,32]{1,0} all-gather([[P1]]), {{[^ ]+}} +; CHECK: [[P1_TRANSPOSE:%[^ ]+]] = <>[32,64]{1,0} transpose([[AG1]]), dimensions={1,0} ; CHECK: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK: [[P3:%[^ ]+]] = f32[] parameter(3) ; CHECK: [[C:%[^ ]+]] = f32[] constant(1) @@ -6479,12 +6709,17 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDWithAllToAllF8) { #if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "A matrix bias on a matmul is only supported in CUDA 12"; #endif + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() 
<< "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + absl::string_view hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[16,32] parameter(0) - y = f8e4m3fn[16,32] parameter(1) + x = <>[16,32] parameter(0) + y = <>[16,32] parameter(1) x_f32 = f32[16,32] convert(x) y_f32 = f32[16,32] convert(y) x_scale = f32[] parameter(2) @@ -6503,12 +6738,12 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDWithAllToAllF8) { config.set_num_partitions(8); RunAndFilecheckHloRewrite( - hlo_text, GemmRewriter(CudaHopperOrRocm(), /*f8_rewrite=*/true), + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), R"( -; CHECK-LABEL: ENTRY %test ({{.*}}: f8e4m3fn[16,32], {{.*}}: f8e4m3fn[16,32], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] { -; CHECK: [[P0:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(0) -; CHECK: [[AA:%[^ ]+]] = f8e4m3fn[16,32]{1,0} all-to-all([[P0]]), {{[^ ]+}} -; CHECK: [[P1:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(1) +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[16,32], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] { +; CHECK: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) +; CHECK: [[AA:%[^ ]+]] = <>[16,32]{1,0} all-to-all([[P0]]), {{[^ ]+}} +; CHECK: [[P1:%[^ ]+]] = <>[16,32]{1,0} parameter(1) ; CHECK: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK: [[P3:%[^ ]+]] = f32[] parameter(3) ; CHECK: [[C:%[^ ]+]] = f32[] constant(1) @@ -6538,12 +6773,17 @@ TEST_P(ParameterizedFp8GemmRewriteTest, #if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; #endif + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + absl::string_view hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[16,32] parameter(0) - y = f8e4m3fn[16,32] parameter(1) + x = <>[16,32] parameter(0) + y = <>[16,32] parameter(1) x_f32 = f32[16,32] convert(x) y_f32 = f32[16,32] convert(y) x_scale = f32[] parameter(2) @@ -6562,12 +6802,12 @@ TEST_P(ParameterizedFp8GemmRewriteTest, config.set_num_partitions(8); RunAndFilecheckHloRewrite( - hlo_text, GemmRewriter(CudaHopperOrRocm(), /*f8_rewrite=*/true), + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), R"( -; CHECK-LABEL: ENTRY %test ({{.*}}: f8e4m3fn[16,32], {{.*}}: f8e4m3fn[16,32], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] { -; CHECK: [[P0:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(0) -; CHECK: [[AA:%[^ ]+]] = f8e4m3fn[16,32]{1,0} collective-permute([[P0]]), {{[^ ]+}} -; CHECK: [[P1:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(1) +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[16,32], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] { +; CHECK: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) +; CHECK: [[AA:%[^ ]+]] = <>[16,32]{1,0} collective-permute([[P0]]), {{[^ ]+}} +; CHECK: [[P1:%[^ ]+]] = <>[16,32]{1,0} parameter(1) ; CHECK: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK: [[P3:%[^ ]+]] = f32[] parameter(3) ; CHECK: [[C:%[^ ]+]] = f32[] constant(1) @@ -6597,12 +6837,17 @@ TEST_P(ParameterizedFp8GemmRewriteTest, #if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; #endif + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[16,32] parameter(0) - y = f8e4m3fn[32,16] 
parameter(1) + x = <>[16,32] parameter(0) + y = <>[32,16] parameter(1) x_f16 = f16[16,32] convert(x) y_f16 = f16[32,16] convert(y) b = f16[16] parameter(2) @@ -6620,14 +6865,15 @@ TEST_P(ParameterizedFp8GemmRewriteTest, } )"; + CheckFp8IfSupported(hlo_text, ErrorSpec{2e-3, 0.}); RunAndFilecheckHloRewrite( - hlo_text, GemmRewriter(CudaHopperOrRocm(), /*f8_rewrite=*/true), + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), R"( -; CHECK-LABEL: ENTRY %test ({{.*}}: f8e4m3fn[16,32], {{.*}}: f8e4m3fn[32,16], {{.*}}: f16[16], {{.*}}: f16[16,16], {{.*}}: f16[], {{.*}}: f16[]) -> f16[16,16] { -; CHECK-DAG: [[P0:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[32,16]{1,0} parameter(1) -; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[16,32]{1,0} transpose([[P1]]), dimensions={1,0} +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f16[16], {{.*}}: f16[16,16], {{.*}}: f16[], {{.*}}: f16[]) -> f16[16,16] { +; CHECK-DAG: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[32,16]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[MB:%[^ ]+]] = f16[16,16]{1,0} parameter(3) ; CHECK-NEXT: [[P2:%[^ ]+]] = f16[] parameter(4) ; CHECK-NEXT: [[CV0:%[^ ]+]] = f32[] convert([[P2]]) @@ -6661,6 +6907,11 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDWithDAmaxF8) { #if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; #endif + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test @@ -6671,8 +6922,8 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDWithDAmaxF8) { } ENTRY test { - x = f8e4m3fn[16,32] parameter(0) - y = f8e4m3fn[32,16] parameter(1) + x = <>[16,32] parameter(0) + y = <>[32,16] parameter(1) x_f32 = f32[16,32] convert(x) y_f32 = f32[32,16] convert(y) x_scale = f32[] parameter(2) @@ -6688,32 +6939,33 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDWithDAmaxF8) { c0 = f32[] constant(-inf) amax = f32[] reduce(abs_dot_a, c0), dimensions={0,1}, to_apply=apply dot_a_scaled = f32[16,16] divide(dot_a, z_scale_bcast) - c1 = f32[] constant(-448.) + c1 = f32[] constant(-<>) c1_bcast = f32[16,16] broadcast(c1), dimensions={} - c2 = f32[] constant(448.) 
+ c2 = f32[] constant(<>) c2_bcast = f32[16,16] broadcast(c2), dimensions={} dot_a_clamped = f32[16,16] clamp(c1_bcast, dot_a_scaled, c2_bcast) - dot_a_f8 = f8e4m3fn[16,16] convert(dot_a_clamped) - ROOT out = (f8e4m3fn[16,16], f32[]) tuple(dot_a_f8, amax) + dot_a_f8 = <>[16,16] convert(dot_a_clamped) + ROOT out = (<>[16,16], f32[]) tuple(dot_a_f8, amax) } )"; CheckFp8IfSupported(hlo_text); RunAndFilecheckHloRewrite( - hlo_text, GemmRewriter(CudaHopperOrRocm(), /*f8_rewrite=*/true), + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), R"( -; CHECK-LABEL: ENTRY %test ({{.*}}: f8e4m3fn[16,32], {{.*}}: f8e4m3fn[32,16], {{.*}}: f32[], {{.*}}: f32[], {{.*}}: f32[]) -> (f8e4m3fn[16,16], f32[]) { -; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[32,16]{1,0} parameter(1) -; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[16,32]{1,0} transpose([[P1]]) +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f32[], {{.*}}: f32[], {{.*}}: f32[]) -> (<>[16,16], f32[]) { +; CHECK: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[32,16]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]) ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) ; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[C2:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[P4:%[^ ]+]] = f32[] parameter(4) -; CHECK-NEXT: [[P4_INV:%[^ ]+]] = f32[] divide([[C2]], [[P4]]) -; CHECK-NEXT: [[OUT:%[^ ]+]] = (f8e4m3fn[16,16]{1,0}, f32[]) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[P4_INV]]), +; CHECK-PTX-NEXT: [[C2:%[^ ]+]] = f32[] constant(1) +; CHECK-PTX-NEXT: [[P4:%[^ ]+]] = f32[] parameter(4) +; CHECK-PTX-NEXT: [[P4_INV:%[^ ]+]] = f32[] divide([[C2]], [[P4]]) +; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, f32[]) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[P4_INV]]), +; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = f32[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -6738,6 +6990,11 @@ TEST_P(ParameterizedFp8GemmRewriteTest, #if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; #endif + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + // This is the same as ScaledABScaledDWithDAmaxF8, but uses F16 intermediate // values instead of F32 intermediate values. const char* hlo_text = R"( @@ -6750,8 +7007,8 @@ TEST_P(ParameterizedFp8GemmRewriteTest, } ENTRY test { - x = f8e4m3fn[16,32] parameter(0) - y = f8e4m3fn[32,16] parameter(1) + x = <>[16,32] parameter(0) + y = <>[32,16] parameter(1) x_f16 = f16[16,32] convert(x) y_f16 = f16[32,16] convert(y) x_scale = f16[] parameter(2) @@ -6767,35 +7024,36 @@ TEST_P(ParameterizedFp8GemmRewriteTest, c0 = f16[] constant(-inf) amax = f16[] reduce(abs_dot_a, c0), dimensions={0,1}, to_apply=apply dot_a_scaled = f16[16,16] divide(dot_a, z_scale_bcast) - c1 = f16[] constant(-448.) + c1 = f16[] constant(-<>) c1_bcast = f16[16,16] broadcast(c1), dimensions={} - c2 = f16[] constant(448.) 
+ c2 = f16[] constant(<>) c2_bcast = f16[16,16] broadcast(c2), dimensions={} dot_a_clamped = f16[16,16] clamp(c1_bcast, dot_a_scaled, c2_bcast) - dot_a_f8 = f8e4m3fn[16,16] convert(dot_a_clamped) - ROOT out = (f8e4m3fn[16,16], f16[]) tuple(dot_a_f8, amax) + dot_a_f8 = <>[16,16] convert(dot_a_clamped) + ROOT out = (<>[16,16], f16[]) tuple(dot_a_f8, amax) } )"; CheckFp8IfSupported(hlo_text); RunAndFilecheckHloRewrite( - hlo_text, GemmRewriter(CudaHopperOrRocm(), /*f8_rewrite=*/true), + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), R"( -; CHECK-LABEL: ENTRY %test ({{.*}}: f8e4m3fn[16,32], {{.*}}: f8e4m3fn[32,16], {{.*}}: f16[], {{.*}}: f16[], {{.*}}: f16[]) -> (f8e4m3fn[16,16], f16[]) { -; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[32,16]{1,0} parameter(1) -; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[16,32]{1,0} transpose([[P1]]) +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f16[], {{.*}}: f16[], {{.*}}: f16[]) -> (<>[16,16], f16[]) { +; CHECK: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[32,16]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]) ; CHECK-NEXT: [[P2:%[^ ]+]] = f16[] parameter(2) ; CHECK-NEXT: [[P2_CONVERT:%[^ ]+]] = f32[] convert([[P2]]) ; CHECK-NEXT: [[P3:%[^ ]+]] = f16[] parameter(3) ; CHECK-NEXT: [[P3_CONVERT:%[^ ]+]] = f32[] convert([[P3]]) ; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[C2:%[^ ]+]] = f16[] constant(1) -; CHECK-NEXT: [[P4:%[^ ]+]] = f16[] parameter(4) -; CHECK-NEXT: [[P4_INV:%[^ ]+]] = f16[] divide([[C2]], [[P4]]) -; CHECK-NEXT: [[P4_INV_CONVERT:%[^ ]+]] = f32[] convert([[P4_INV]]) -; CHECK-NEXT: [[OUT:%[^ ]+]] = (f8e4m3fn[16,16]{1,0}, f32[]) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_CONVERT]], [[P3_CONVERT]], [[C1]], /*index=5*/[[P4_INV_CONVERT]]), +; CHECK-PTX-NEXT: [[C2:%[^ ]+]] = f16[] constant(1) +; CHECK-PTX-NEXT: [[P4:%[^ ]+]] = f16[] parameter(4) +; CHECK-PTX-NEXT: [[P4_INV:%[^ ]+]] = f16[] divide([[C2]], [[P4]]) +; CHECK-PTX-NEXT: [[P4_INV_CONVERT:%[^ ]+]] = f32[] convert([[P4_INV]]) +; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, f32[]) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_CONVERT]], [[P3_CONVERT]], [[C1]], /*index=5*/[[P4_INV_CONVERT]]), +; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = f16[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_CONVERT]], [[P3_CONVERT]], [[C1]], /*index=5*/[[C1]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -6820,6 +7078,11 @@ TEST_P(ParameterizedFp8GemmRewriteTest, #if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; #endif + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test @@ -6830,8 +7093,8 @@ TEST_P(ParameterizedFp8GemmRewriteTest, } ENTRY test { - x = f8e4m3fn[16,32] parameter(0) - y = f8e4m3fn[32,16] parameter(1) + x = <>[16,32] parameter(0) + y = <>[32,16] parameter(1) x_f32 = f32[16,32] convert(x) y_f32 = f32[32,16] convert(y) x_scale = f32[] parameter(2) @@ -6849,32 +7112,33 @@ TEST_P(ParameterizedFp8GemmRewriteTest, c0 = f32[] constant(-inf) amax = f32[] reduce(dot_a_relu, c0), dimensions={0,1}, to_apply=apply dot_a_scaled = f32[16,16] divide(dot_a_relu, z_scale_bcast) - c1 = f32[] constant(-448.) 
+ c1 = f32[] constant(-<>) c1_bcast = f32[16,16] broadcast(c1), dimensions={} - c2 = f32[] constant(448.) + c2 = f32[] constant(<>) c2_bcast = f32[16,16] broadcast(c2), dimensions={} dot_a_clamped = f32[16,16] clamp(c1_bcast, dot_a_scaled, c2_bcast) - dot_a_f8 = f8e4m3fn[16,16] convert(dot_a_clamped) - ROOT out = (f8e4m3fn[16,16], f32[]) tuple(dot_a_f8, amax) + dot_a_f8 = <>[16,16] convert(dot_a_clamped) + ROOT out = (<>[16,16], f32[]) tuple(dot_a_f8, amax) } )"; CheckFp8IfSupported(hlo_text); RunAndFilecheckHloRewrite( - hlo_text, GemmRewriter(CudaHopperOrRocm(), /*f8_rewrite=*/true), + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), R"( -; CHECK-LABEL: ENTRY %test ({{.*}}: f8e4m3fn[16,32], {{.*}}: f8e4m3fn[32,16], {{.*}}: f32[], {{.*}}: f32[], {{.*}}: f32[]) -> (f8e4m3fn[16,16], f32[]) { -; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[32,16]{1,0} parameter(1) -; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[16,32]{1,0} transpose([[P1]]) +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f32[], {{.*}}: f32[], {{.*}}: f32[]) -> (<>[16,16], f32[]) { +; CHECK: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[32,16]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]) ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) ; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[C2:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[P4:%[^ ]+]] = f32[] parameter(4) -; CHECK-NEXT: [[P4_INV:%[^ ]+]] = f32[] divide([[C2]], [[P4]]) -; CHECK-NEXT: [[OUT:%[^ ]+]] = (f8e4m3fn[16,16]{1,0}, f32[]) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[P4_INV]]), +; CHECK-PTX-NEXT: [[C2:%[^ ]+]] = f32[] constant(1) +; CHECK-PTX-NEXT: [[P4:%[^ ]+]] = f32[] parameter(4) +; CHECK-PTX-NEXT: [[P4_INV:%[^ ]+]] = f32[] divide([[C2]], [[P4]]) +; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, f32[]) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[P4_INV]]), +; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = f32[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -6895,15 +7159,20 @@ TEST_P(ParameterizedFp8GemmRewriteTest, } TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDPrecisionF8) { -#if CUDA_VERSION < 12000 +#if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; #endif // CUDA_VERSION < 12000 - const char* hlo_template = R"( + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + + const char* raw_hlo_template = R"( HloModule test ENTRY test { - x = f8e4m3fn[1600,3200] parameter(0) - y = f8e4m3fn[3200,1600] parameter(1) + x = <>[1600,3200] parameter(0) + y = <>[3200,1600] parameter(1) x_f32 = f32[1600,3200] convert(x) y_f32 = f32[3200,1600] convert(y) x_scale = f32[] parameter(2) @@ -6916,6 +7185,9 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDPrecisionF8) { } )"; + std::string hlo_template = + absl::StrReplaceAll(raw_hlo_template, replacements_); + absl::flat_hash_map replacements; replacements["<>"] = "default"; const auto hlo_text_default = absl::StrReplaceAll(hlo_template, replacements); @@ -6930,6 +7202,11 @@ 
TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8Parameterized) { #if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; #endif + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + std::array, 32> combinations; int i = 0; @@ -6961,12 +7238,12 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8Parameterized) { const char* hlo_template = R"( HloModule test ENTRY test { - x = f8e4m3fn<><> parameter(0) + x = <><><> parameter(0) x_f32 = f32<><> convert(x) x_scale = f32[] parameter(2) x_scale_bcast = f32<> broadcast(x_scale), dimensions={} x_unscaled = f32<> multiply(x_f32, x_scale_bcast) - y = f8e4m3fn<><> parameter(1) + y = <><><> parameter(1) y_f32 = f32<><> convert(y) y_scale = f32[] parameter(3) y_scale_bcast = f32<> broadcast(y_scale), dimensions={} @@ -6987,7 +7264,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8Parameterized) { CheckFp8IfSupported(hlo_text); RunAndFilecheckHloRewrite( - hlo_text, GemmRewriter(CudaHopperOrRocm(), /*f8_rewrite=*/true), + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), R"( ; CHECK: custom_call_target="__cublas$lt$matmul$f8", )"); @@ -6999,6 +7276,11 @@ TEST_P(ParameterizedFp8GemmRewriteTest, #if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; #endif + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + // TODO(wenscarl): For batched matmul, not all combinations of A, B and // output layouts get pattern matched successfully to FP8 custom call. Only // a handful of cases are tested here. 
@@ -7026,13 +7308,13 @@ TEST_P(ParameterizedFp8GemmRewriteTest, const char* hlo_template = R"( HloModule m ENTRY f { - x_q = f8e4m3fn<><> parameter(0) + x_q = <><><> parameter(0) x_scale = f32[] parameter(2) x_scale_broadcast = f32<><> broadcast(x_scale), dimensions={} x_q_convert = f32<><> convert(x_q) x_qdq = f32<><> multiply(x_q_convert, x_scale_broadcast) - y_q = f8e4m3fn<><> parameter(1) + y_q = <><><> parameter(1) y_scale = f32[] parameter(3) y_scale_broadcast = f32<><> broadcast(y_scale), dimensions={} y_q_convert = f32<><> convert(y_q) @@ -7041,6 +7323,7 @@ ENTRY f { ROOT out = f32[2,64,16]<> dot(x_qdq, y_qdq), lhs_batch_dims={0}, lhs_contracting_dims=<>, rhs_batch_dims={0}, rhs_contracting_dims=<> } )"; + for (const auto& combination : combinations) { absl::flat_hash_map replacements; replacements["<>"] = std::get<0>(combination); @@ -7055,7 +7338,7 @@ ENTRY f { CheckFp8IfSupported(hlo_text); RunAndFilecheckHloRewrite( - hlo_text, GemmRewriter(CudaHopperOrRocm(), /*f8_rewrite=*/true), + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), R"( ; CHECK: custom_call_target="__cublas$lt$matmul$f8", )"); @@ -7066,12 +7349,17 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8TF32E5M2) { #if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; #endif + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[16,32] parameter(0) - y = f8e5m2[32,16] parameter(1) + x = <>[16,32] parameter(0) + y = <>[32,16] parameter(1) x_f32 = f32[16,32] convert(x) y_f32 = f32[32,16] convert(y) x_scale = f32[] parameter(2) @@ -7087,7 +7375,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8TF32E5M2) { CheckFp8IfSupported(hlo_text); RunAndFilecheckHloRewrite( - hlo_text, GemmRewriter(CudaHopperOrRocm(), /*f8_rewrite=*/true), + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), R"( ; CHECK: custom_call_target="__cublas$lt$matmul$f8", )"); @@ -7097,6 +7385,11 @@ TEST_P(ParameterizedFp8GemmRewriteTest, FnuzTypeF8) { #if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; #endif + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + // Test that FNUZ FP8 gemms are not rewritten, as cuBLAS does not support them const char* hlo_text = R"( HloModule test @@ -7115,11 +7408,55 @@ TEST_P(ParameterizedFp8GemmRewriteTest, FnuzTypeF8) { ROOT out = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0} } )"; +#if GOOGLE_CUDA TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); - GemmRewriter pass(CudaHopperOrRocm(), /*f8_rewrite=*/true); + GemmRewriter pass(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); EXPECT_FALSE(changed); +#endif +#if TENSORFLOW_USE_ROCM + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-2, 1e-2})); + RunAndFilecheckHloRewrite( + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), + R"( +; CHECK-LABEL: ENTRY %test ({{.*}}: f8e4m3fnuz[16,32], {{.*}}: f8e4m3fnuz[32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fnuz[16,32]{1,0} 
parameter(0) +; CHECK-PTX-NEXT: [[P0_CV:%[^ ]+]] = f32[16,32]{1,0} convert([[P0]]) +; CHECK-PTX-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) +; CHECK-PTX-NEXT: [[P2_B:%[^ ]+]] = f32[16,32]{1,0} broadcast([[P2]]), dimensions={} +; CHECK-PTX-NEXT: [[P0_UNSCALED:%[^ ]+]] = f32[16,32]{1,0} multiply([[P0_CV]], [[P2_B]]) +; CHECK-PTX-NEXT: [[P1:%[^ ]+]] = f8e4m3fnuz[32,16]{1,0} parameter(1) +; CHECK-PTX-NEXT: [[P1_CV:%[^ ]+]] = f32[32,16]{1,0} convert([[P1]]) +; CHECK-PTX-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) +; CHECK-PTX-NEXT: [[P3_B:%[^ ]+]] = f32[32,16]{1,0} broadcast([[P3]]), dimensions={} +; CHECK-PTX-NEXT: [[P1_UNSCALED:%[^ ]+]] = f32[32,16]{1,0} multiply([[P1_CV]], [[P3_B]]) +; CHECK-PTX-NEXT: [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0_UNSCALED]], [[P1_UNSCALED]]), +; CHECK-GCN-NEXT: [[P1:%[^ ]+]] = f8e4m3fnuz[32,16]{1,0} parameter(1) +; CHECK-GCN-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]) +; CHECK-GCN-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) +; CHECK-GCN-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) +; CHECK-GCN-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) +; CHECK-PTX: custom_call_target="<>", +; CHECK-GCN: custom_call_target="__cublas$lt$matmul$f8", +; CHECK: backend_config={ +; CHECK-DAG: "alpha_real":1 +; CHECK-DAG: "alpha_imag":0 +; CHECK-DAG: "beta":0 +; CHECK-DAG: "dot_dimension_numbers":{ +; CHECK-DAG: "lhs_contracting_dimensions":["1"] +; CHECK-PTX-DAG: "rhs_contracting_dimensions":["0"] +; CHECK-GCN-DAG: "rhs_contracting_dimensions":["1"] +; CHECK-DAG: "lhs_batch_dimensions":[] +; CHECK-DAG: "rhs_batch_dimensions":[] +; CHECK-DAG: } +; CHECK-DAG: "precision_config":{ +; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"] +; CHECK-DAG: } +; CHECK-DAG: "epilogue":"DEFAULT" +; CHECK: } + )"); +#endif } INSTANTIATE_TEST_SUITE_P(Fp8CublasTestsBothLegacyAndLt, diff --git a/xla/stream_executor/device_description.h b/xla/stream_executor/device_description.h index 7b632c46ab49f4..b2ada3ece94f5a 100644 --- a/xla/stream_executor/device_description.h +++ b/xla/stream_executor/device_description.h @@ -171,6 +171,11 @@ class RocmComputeCapability { return absl::c_count(kList, gfx_version()) != 0; } + bool gfx9_mi300() const { + static constexpr absl::string_view kList[] = {"gfx940", "gfx941", "gfx942"}; + return absl::c_count(kList, gfx_version()) != 0; + } + bool navi21() const { return gfx_version() == "gfx1030"; } bool navi31() const { return gfx_version() == "gfx1100"; } @@ -196,6 +201,8 @@ class RocmComputeCapability { bool has_hipblaslt() const { return gfx9_mi200_or_later(); } + bool has_fp8_support() const { return gfx9_mi300(); } + RocmComputeCapabilityProto ToProto() const { RocmComputeCapabilityProto proto; proto.set_gcn_arch_name(gcn_arch_name_); diff --git a/xla/stream_executor/gpu/gpu_blas_lt.cc b/xla/stream_executor/gpu/gpu_blas_lt.cc index 6ef0f19488c17b..17b2562bd5d5ab 100644 --- a/xla/stream_executor/gpu/gpu_blas_lt.cc +++ b/xla/stream_executor/gpu/gpu_blas_lt.cc @@ -47,6 +47,10 @@ absl::StatusOr AsBlasDataType(PrimitiveType dtype) { return DataType::kF8E5M2; case PrimitiveType::F8E4M3FN: return DataType::kF8E4M3FN; + case PrimitiveType::F8E5M2FNUZ: + return DataType::kF8E5M2FNUZ; + case PrimitiveType::F8E4M3FNUZ: + return DataType::kF8E4M3FNUZ; case PrimitiveType::S8: return DataType::kInt8; case PrimitiveType::F16: @@ -76,6 +80,10 @@ absl::StatusOr AsXlaPrimitiveType(DataType dtype) { return PrimitiveType::F8E5M2; case DataType::kF8E4M3FN: return PrimitiveType::F8E4M3FN; + case DataType::kF8E5M2FNUZ: + return PrimitiveType::F8E5M2FNUZ; + case 
DataType::kF8E4M3FNUZ: + return PrimitiveType::F8E4M3FNUZ; case DataType::kInt8: return PrimitiveType::S8; case DataType::kHalf: @@ -131,9 +139,11 @@ absl::StatusOr GetBlasComputationType( xla::PrimitiveType output_dtype, int64_t compute_precision) { if (algorithm == xla::PrecisionConfig::ALG_UNSET) { switch (output_dtype) { - case PrimitiveType::F8E5M2: // fall-through - case PrimitiveType::F8E4M3FN: // fall-through - case PrimitiveType::F16: // fall-through + case PrimitiveType::F8E5M2: // fall-through + case PrimitiveType::F8E4M3FN: // fall-through + case PrimitiveType::F8E5M2FNUZ: // fall-through + case PrimitiveType::F8E4M3FNUZ: // fall-through + case PrimitiveType::F16: // fall-through case PrimitiveType::BF16: // Accumulate in f32 precision. return ComputationType::kF32; diff --git a/xla/stream_executor/rocm/hip_blas_lt.cc b/xla/stream_executor/rocm/hip_blas_lt.cc index e8d019508c7d9c..f29506050fbdf6 100644 --- a/xla/stream_executor/rocm/hip_blas_lt.cc +++ b/xla/stream_executor/rocm/hip_blas_lt.cc @@ -313,6 +313,30 @@ auto BlasLt::GetMatmulPlan(const gpu::GemmConfig& cfg, Epilogue epilogue) const TF_ASSIGN_OR_RETURN(auto c_desc, MatrixLayout::Create(c_layout)); TF_ASSIGN_OR_RETURN(auto d_desc, MatrixLayout::Create(output_layout)); +#if TF_ROCM_VERSION >= 60000 + // Currently, the default bias data type in hipblasLt is the same as the output + // data type for fp8 matmul, which is different from cublasLt. This is a + // workaround to match cublasLt behavior. + if (epilogue == gpu::BlasLt::Epilogue::kBias) { + auto a_dtype = a_desc.type(); + auto b_dtype = b_desc.type(); + + auto bias_dtype = d_desc.type(); + if ((a_dtype == HIP_R_8F_E4M3_FNUZ || a_dtype == HIP_R_8F_E5M2_FNUZ) && + (b_dtype == HIP_R_8F_E4M3_FNUZ || b_dtype == HIP_R_8F_E5M2_FNUZ)) { + auto d_dtype = d_desc.type(); + if (d_dtype == HIP_R_32F) { + bias_dtype = HIP_R_16BF; + } + + if (bias_dtype != d_dtype) { + TF_RETURN_IF_ERROR(SetAttr( + op_desc.get(), HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, bias_dtype)); + } + } + } +#endif // TF_ROCM_VERSION >= 60000 + // std::make_unique won't work with brace initialization in C++17 ;( return std::make_unique(*this, std::move(op_desc), std::move(a_desc), std::move(b_desc), @@ -388,10 +412,27 @@ absl::Status BlasLt::MatmulPlan::DoMatmul( op_desc_.get(), HIPBLASLT_MATMUL_DESC_BIAS_POINTER, bias.opaque())); } +#if TF_ROCM_VERSION >= 60000 + if (a_scale != nullptr) { + TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(), + HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER, + a_scale.opaque())); + } + if (b_scale != nullptr) { + TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(), + HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER, + b_scale.opaque())); + } + if (c_scale != nullptr || d_scale != nullptr) { + return absl::InternalError( + "hipblaslt does not support c_scale or d_scale."); + } +#else if ((a_scale != nullptr) || (b_scale != nullptr) || (c_scale != nullptr) || (d_scale != nullptr)) { return absl::InternalError("hipblaslt does not support scale"); } +#endif if (d_amax != nullptr) { return absl::InternalError("hipblaslt does not support amax"); @@ -430,6 +471,17 @@ namespace { template struct HipToNativeT; +#if TF_ROCM_VERSION >= 60000 +template <> +struct HipToNativeT { + using type = tsl::float8_e4m3fnuz; +}; +template <> +struct HipToNativeT { + using type = tsl::float8_e5m2fnuz; +}; +#endif // TF_ROCM_VERSION >= 60000 + template <> struct HipToNativeT { using type = Eigen::bfloat16; @@ -481,6 +533,23 @@ absl::Status BlasLt::MatmulPlan::ExecuteOnStream( profile_result); \ } +#if TF_ROCM_VERSION >= 60000 +
TYPED_MATMUL(float, HIP_R_8F_E4M3_FNUZ, HIP_R_8F_E4M3_FNUZ, HIP_R_16F, + HIP_R_16F) + TYPED_MATMUL(float, HIP_R_8F_E4M3_FNUZ, HIP_R_8F_E4M3_FNUZ, HIP_R_32F, + HIP_R_32F) + + TYPED_MATMUL(float, HIP_R_8F_E4M3_FNUZ, HIP_R_8F_E5M2_FNUZ, HIP_R_16F, + HIP_R_16F) + TYPED_MATMUL(float, HIP_R_8F_E4M3_FNUZ, HIP_R_8F_E5M2_FNUZ, HIP_R_32F, + HIP_R_32F) + + TYPED_MATMUL(float, HIP_R_8F_E5M2_FNUZ, HIP_R_8F_E4M3_FNUZ, HIP_R_16F, + HIP_R_16F) + TYPED_MATMUL(float, HIP_R_8F_E5M2_FNUZ, HIP_R_8F_E4M3_FNUZ, HIP_R_32F, + HIP_R_32F) +#endif + // Other data types: TYPED_MATMUL(float, HIP_R_16BF, HIP_R_16BF, HIP_R_16BF, HIP_R_16BF) TYPED_MATMUL(float, HIP_R_16F, HIP_R_16F, HIP_R_16F, HIP_R_16F) diff --git a/xla/stream_executor/rocm/hip_blas_utils.cc b/xla/stream_executor/rocm/hip_blas_utils.cc index 6f4325001a6d32..a59c935614cd8f 100644 --- a/xla/stream_executor/rocm/hip_blas_utils.cc +++ b/xla/stream_executor/rocm/hip_blas_utils.cc @@ -36,7 +36,17 @@ hipDataType AsHipblasDataType(blas::DataType type) { switch (type) { case blas::DataType::kF8E5M2: case blas::DataType::kF8E4M3FN: - LOG(FATAL) << "hipblaslt does not support F8 yet"; + LOG(FATAL) << "hipblaslt does not support F8E5M2 and F8E4M3FN"; +#if TF_ROCM_VERSION >= 60000 + case blas::DataType::kF8E5M2FNUZ: + return HIP_R_8F_E5M2_FNUZ; + case blas::DataType::kF8E4M3FNUZ: + return HIP_R_8F_E4M3_FNUZ; +#else + case blas::DataType::kF8E5M2FNUZ: + case blas::DataType::kF8E4M3FNUZ: + LOG(FATAL) << "hipblaslt only supports F8 in ROCm 6.0 and above"; +#endif case blas::DataType::kHalf: return HIP_R_16F; case blas::DataType::kBF16: diff --git a/xla/tests/BUILD b/xla/tests/BUILD index 76dafa933a4847..0202155e123287 100644 --- a/xla/tests/BUILD +++ b/xla/tests/BUILD @@ -331,6 +331,7 @@ cc_library( data = [ "@llvm-project//llvm:FileCheck", ], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured(["TENSORFLOW_USE_ROCM=1"]), deps = [ "//xla:statusor", "//xla:types", diff --git a/xla/tests/dot_operation_test.cc b/xla/tests/dot_operation_test.cc index 614639c222fef7..6ed8d128ee277f 100644 --- a/xla/tests/dot_operation_test.cc +++ b/xla/tests/dot_operation_test.cc @@ -70,7 +70,12 @@ using TypesF16F32F64CF64 = ::testing::Types< #endif float>; +#if GOOGLE_CUDA using TypesF8 = ::testing::Types; +#endif +#if TF_HIPBLASLT && TF_ROCM_VERSION >= 60000 +using TypesF8 = ::testing::Types; +#endif // Check that we can safely pass an input tuple's elements to a dot operation.
XLA_TEST_F(DotOperationTest, DotOfInputTupleElem) { @@ -731,7 +736,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMul) { {x_data.get(), y_data.get()}, this->error_spec_); } -#if GOOGLE_CUDA || TF_HIPBLASLT +#if GOOGLE_CUDA || (TF_HIPBLASLT && TF_ROCM_VERSION >= 60000) template class DotOperationTestWithCublasLt_F16F32F64CF64 : public DotOperationTest { public: @@ -787,7 +792,7 @@ XLA_TYPED_TEST(DotOperationTestWithCublasLt_F16F32F64CF64, } #endif // GOOGLE_CUDA || TF_HIPBLASLT -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TF_HIPBLASLT template class DotOperationTestWithCublasLt_F8 : public DotOperationTest { public: @@ -1109,7 +1114,7 @@ XLA_TYPED_TEST(DotOperationTestWithCublasLt_F8, ScaledABScaledDWithDAmaxF8) { b_scale_data.get(), d_scale_data.get()}, this->error_spec_); } -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TF_HIPBLASLT XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulR3LhsR2Rhs) { using T = TypeParam; diff --git a/xla/tests/filecheck.cc b/xla/tests/filecheck.cc index 9b8188a166520f..2031ec3d2b26ef 100644 --- a/xla/tests/filecheck.cc +++ b/xla/tests/filecheck.cc @@ -50,9 +50,23 @@ absl::StatusOr RunFileCheckWithPatternFile( : tsl::io::JoinPath("llvm", "llvm-project", "llvm", "FileCheck")); tsl::SubProcess file_check_process; +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + std::string file_check_prefixes; +#if GOOGLE_CUDA + file_check_prefixes = "--check-prefixes=CHECK,CHECK-PTX"; +#endif // GOOGLE_CUDA +#if TENSORFLOW_USE_ROCM + file_check_prefixes = "--check-prefixes=CHECK,CHECK-GCN"; +#endif // TENSORFLOW_USE_ROCM + file_check_process.SetProgram( + file_check_path, + {file_check_path, "-v", "-dump-input=fail", "--dump-input-filter=all", + file_check_prefixes, "--allow-unused-prefixes", pattern_file}); +#else // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM) file_check_process.SetProgram(file_check_path, {file_check_path, "-v", "-dump-input=fail", "--dump-input-filter=all", pattern_file}); +#endif file_check_process.SetChannelAction(tsl::CHAN_STDIN, tsl::ACTION_PIPE); file_check_process.SetChannelAction(tsl::CHAN_STDERR, tsl::ACTION_PIPE); if (!file_check_process.Start()) {