From 2f7ce18e23045defdd83d4ca0b508760ad75aede Mon Sep 17 00:00:00 2001
From: Ilya Tikhonovskiy
Date: Wed, 25 Sep 2024 03:14:23 -0700
Subject: [PATCH] [XLA:GPU] Remove xla_gpu_enable_triton_gemm_int4 flag which
 is on by default.

This flag has been enabled by default for a while now, and there is no
reason to keep it around.

PiperOrigin-RevId: 678620682
---
 xla/debug_options_flags.cc                      | 10 +++-------
 .../fusions/triton/triton_fusion_emitter.cc     | 18 ------------------
 ...triton_fusion_emitter_device_legacy_test.cc  |  1 -
 .../fusions/triton/triton_support_legacy.cc     |  9 +--------
 xla/service/gpu/transforms/gemm_fusion_test.cc  | 12 ------------
 xla/xla.proto                                   |  6 ++----
 6 files changed, 6 insertions(+), 50 deletions(-)

diff --git a/xla/debug_options_flags.cc b/xla/debug_options_flags.cc
index abedfc370dd83..267029c063067 100644
--- a/xla/debug_options_flags.cc
+++ b/xla/debug_options_flags.cc
@@ -285,8 +285,6 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
 
   opts.set_xla_gpu_cudnn_gemm_max_plans(5);
 
-  opts.set_xla_gpu_enable_triton_gemm_int4(true);
-
   opts.set_xla_gpu_enable_pgle_accuracy_checker(false);
 
   opts.set_xla_gpu_executable_warn_stuck_timeout_seconds(10);
@@ -1923,11 +1921,9 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
       debug_options->xla_gpu_cudnn_gemm_max_plans(),
       "Limit for the number of kernel configurations (plans) to use during "
       "autotuning of cuDNN GEMM fusions."));
-  flag_list->push_back(tsl::Flag(
-      "xla_gpu_enable_triton_gemm_int4",
-      bool_setter_for(&DebugOptions::set_xla_gpu_enable_triton_gemm_int4),
-      debug_options->xla_gpu_enable_triton_gemm_int4(),
-      "Experimental: Enable Triton gemm for int4 inputs."));
+  flag_list->push_back(tsl::Flag("xla_gpu_enable_triton_gemm_int4",
+                                 noop_flag_setter<bool>, true,
+                                 "[Deprecated, do not use]"));
   flag_list->push_back(
       tsl::Flag("xla_gpu_async_dot",
                 bool_setter_for(&DebugOptions::set_xla_gpu_async_dot),
diff --git a/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc b/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc
index a105789fa86f0..ccaad403bd954 100644
--- a/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc
+++ b/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc
@@ -1216,14 +1216,6 @@ absl::StatusOr<Value> EmitScope(
     Value result;
     if (hlo->opcode() == HloOpcode::kConvert &&
         hlo->operand(0)->shape().element_type() == S4) {
-      if (!hlo->GetModule()
-               ->config()
-               .debug_options()
-               .xla_gpu_enable_triton_gemm_int4()) {
-        return absl::UnimplementedError(
-            "Int4 support is not enabled in the debug options.");
-      }
-
       TF_ASSIGN_OR_RETURN(
           auto unpacked, EmitUnpackInt4(b, hlo, side, values[hlo->operand(0)]));
       std::vector<Value> operands({unpacked});
@@ -3058,15 +3050,6 @@ absl::Status CreateInternalError(std::string_view message,
   return absl::InternalError(err);
 }
 
-absl::Status DoSupportType(const DebugOptions& debug_options,
-                           PrimitiveType type) {
-  if (type == S4 && !debug_options.xla_gpu_enable_triton_gemm_int4()) {
-    return absl::FailedPreconditionError(
-        "Int4 support is not enabled in the debug options.");
-  }
-  return absl::OkStatus();
-}
-
 absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> CreateTritonModule(
     absl::string_view fn_name, const HloFusionInstruction* fusion,
     const se::DeviceDescription& device_info,
@@ -3088,7 +3071,6 @@ absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> CreateTritonModule(
   SmallVector<Type> fn_arg_types;
   for (HloInstruction* p : hlo_computation->parameter_instructions()) {
     PrimitiveType type = p->shape().element_type();
-    TF_RETURN_IF_ERROR(DoSupportType(debug_options, type));
     Type ir_type;
     if (type == U16) {
       ir_type = b.getI16Type();
diff --git a/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc b/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc
index 2d5c423203889..036e3221c71bc 100644
--- a/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc
+++ b/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc
@@ -119,7 +119,6 @@ class TritonGemmTest : public TritonTest {
     debug_options.set_xla_gpu_enable_split_k_autotuning(false);
     // Always rewrite Gemms with Triton regardless of size.
    debug_options.set_xla_gpu_gemm_rewrite_size_threshold(0);
-    debug_options.set_xla_gpu_enable_triton_gemm_int4(true);
     return debug_options;
   }
 
diff --git a/xla/service/gpu/fusions/triton/triton_support_legacy.cc b/xla/service/gpu/fusions/triton/triton_support_legacy.cc
index 802fed51f4d20..97c4891441d98 100644
--- a/xla/service/gpu/fusions/triton/triton_support_legacy.cc
+++ b/xla/service/gpu/fusions/triton/triton_support_legacy.cc
@@ -122,14 +122,7 @@ CodegenDecision IsInstructionSupportsDataTypes(
     const auto operand_type = operand->shape().element_type();
     switch (instr.opcode()) {
       case HloOpcode::kConvert:
-        // TODO(b/358580281): remove DebugOptions from this function after
-        // enabling int4 in Triton GEMM.
-        if (operand_type == S4 && instr.GetModule()
-                                      ->config()
-                                      .debug_options()
-                                      .xla_gpu_enable_triton_gemm_int4()) {
-          continue;
-        }
+        if (operand_type == S4) continue;
         [[fallthrough]];
       default:
         if (!IsTritonSupportedDataType(operand_type, gpu_version)) {
diff --git a/xla/service/gpu/transforms/gemm_fusion_test.cc b/xla/service/gpu/transforms/gemm_fusion_test.cc
index af97932ddf3a8..54985658367d4 100644
--- a/xla/service/gpu/transforms/gemm_fusion_test.cc
+++ b/xla/service/gpu/transforms/gemm_fusion_test.cc
@@ -1364,9 +1364,6 @@ TEST_F(SmallDotGemmFusionTest, Int4DotIsRewritten) {
 )";
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                           ParseAndReturnVerifiedModule(kInt4Dot));
-  module->mutable_config()
-      .mutable_debug_options()
-      .set_xla_gpu_enable_triton_gemm_int4(true);
   EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
 }
 
@@ -1384,9 +1381,6 @@ TEST_F(SmallDotGemmFusionTest, Int4ConcatPlusConvertIsRewritten) {
 )";
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                           ParseAndReturnVerifiedModule(kInt4Dot));
-  module->mutable_config()
-      .mutable_debug_options()
-      .set_xla_gpu_enable_triton_gemm_int4(true);
   EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
 
   // Check that the fusion is present and that the lhs is not converted.
@@ -1411,9 +1405,6 @@ TEST_F(SmallDotGemmFusionTest, Int4ConvertPlusNegateIsRewritten) {
 )";
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                           ParseAndReturnVerifiedModule(kInt4Dot));
-  module->mutable_config()
-      .mutable_debug_options()
-      .set_xla_gpu_enable_triton_gemm_int4(true);
   EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
   // Check that the fusion is present and that convert and negation is fused in
   // it.
@@ -1440,9 +1431,6 @@ TEST_F(SmallDotGemmFusionTest, Int4WithMinorBatchDimIsNotRewritten) {
 )";
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                           ParseAndReturnVerifiedModule(kInt4Dot));
-  module->mutable_config()
-      .mutable_debug_options()
-      .set_xla_gpu_enable_triton_gemm_int4(true);
   TF_ASSERT_OK_AND_ASSIGN(auto result,
                           GemmFusion(gpu_version_).Run(module.get()));
   EXPECT_FALSE(result);
diff --git a/xla/xla.proto b/xla/xla.proto
index c8900a65753b3..f4ef137fce057 100644
--- a/xla/xla.proto
+++ b/xla/xla.proto
@@ -949,9 +949,6 @@ message DebugOptions {
   // If enabled, uses the libnvjitlink library for PTX compilation and linking
   bool xla_gpu_enable_libnvjitlink = 319;
 
-  // If enabled, generates triton gemm kernels for int4 inputs.
-  bool xla_gpu_enable_triton_gemm_int4 = 320;
-
   // If true, XLA will wrap `dot` operations into async computations in an
   // effort to parallelize matrix operations.
   bool xla_gpu_async_dot = 321;
@@ -1005,7 +1002,8 @@ message DebugOptions {
   // xla_gpu_graph_level
   // xla_gpu_single_wave_autotuning
   // xla_gpu_enable_persistent_temp_buffers
-  reserved 5, 117, 133, 139, 176, 178, 180, 193, 214, 194, 242, 206;
+  // xla_gpu_enable_triton_gemm_int4
+  reserved 5, 117, 133, 139, 176, 178, 180, 193, 214, 194, 242, 206, 320;
 }
 
 // Contains flags which affects the GPU compilation result.
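
Note on the flag registration kept above: rather than dropping the flag entirely, the patch re-registers it with noop_flag_setter<bool> and the help text "[Deprecated, do not use]", so existing command lines that still pass --xla_gpu_enable_triton_gemm_int4 keep parsing while the value is ignored. The following is a minimal, self-contained C++ sketch of that general pattern; the Flag struct, the lambdas, and the parsing loop are invented for illustration and are not the XLA/TSL flag API.

```cpp
// Hypothetical illustration of keeping a deprecated boolean flag as a no-op.
#include <functional>
#include <iostream>
#include <string>
#include <vector>

struct Flag {
  std::string name;
  std::function<bool(bool)> setter;  // Returns true if the value was accepted.
  std::string help;
};

int main() {
  bool async_dot = false;
  std::vector<Flag> flags;

  // A live flag writes through to its backing option.
  flags.push_back({"xla_gpu_async_dot",
                   [&async_dot](bool v) {
                     async_dot = v;
                     return true;
                   },
                   "If true, wrap dot operations into async computations."});

  // A deprecated flag is accepted but has no effect, mirroring the
  // noop_flag_setter<bool> registration in the patch.
  flags.push_back({"xla_gpu_enable_triton_gemm_int4",
                   [](bool) { return true; },
                   "[Deprecated, do not use]"});

  // Simulate parsing "--xla_gpu_enable_triton_gemm_int4=true": the old
  // spelling still parses instead of failing with "unknown flag".
  for (const Flag& f : flags) {
    if (f.name == "xla_gpu_enable_triton_gemm_int4") {
      std::cout << f.name << " accepted: " << f.setter(true) << " (" << f.help
                << ")\n";
    }
  }
  return 0;
}
```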