Skip to content

Commit

Permalink
[XLA:GPU] Remove xla_gpu_enable_triton_gemm_int4 flag which is on by default.
Browse files Browse the repository at this point in the history

This flag has been enabled by default for a while now, and there is no reason to keep it around.

PiperOrigin-RevId: 678620682
  • Loading branch information
loislo authored and Google-ML-Automation committed Sep 28, 2024
1 parent 0f30f33 commit 2f7ce18
Show file tree
Hide file tree
Showing 6 changed files with 6 additions and 50 deletions.
10 changes: 3 additions & 7 deletions xla/debug_options_flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -285,8 +285,6 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {

opts.set_xla_gpu_cudnn_gemm_max_plans(5);

opts.set_xla_gpu_enable_triton_gemm_int4(true);

opts.set_xla_gpu_enable_pgle_accuracy_checker(false);

opts.set_xla_gpu_executable_warn_stuck_timeout_seconds(10);
Expand Down Expand Up @@ -1923,11 +1921,9 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
"Limit for the number of kernel configurations (plans) to use during "
"autotuning of cuDNN GEMM fusions."));

flag_list->push_back(tsl::Flag(
"xla_gpu_enable_triton_gemm_int4",
bool_setter_for(&DebugOptions::set_xla_gpu_enable_triton_gemm_int4),
debug_options->xla_gpu_enable_triton_gemm_int4(),
"Experimental: Enable Triton gemm for int4 inputs."));
flag_list->push_back(tsl::Flag("xla_gpu_enable_triton_gemm_int4",
noop_flag_setter<bool>, true,
"[Deprecated, do not use]"));
flag_list->push_back(
tsl::Flag("xla_gpu_async_dot",
bool_setter_for(&DebugOptions::set_xla_gpu_async_dot),
Expand Down
18 changes: 0 additions & 18 deletions xla/service/gpu/fusions/triton/triton_fusion_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1216,14 +1216,6 @@ absl::StatusOr<Value> EmitScope(
Value result;
if (hlo->opcode() == HloOpcode::kConvert &&
hlo->operand(0)->shape().element_type() == S4) {
if (!hlo->GetModule()
->config()
.debug_options()
.xla_gpu_enable_triton_gemm_int4()) {
return absl::UnimplementedError(
"Int4 support is not enabled in the debug options.");
}

TF_ASSIGN_OR_RETURN(
auto unpacked, EmitUnpackInt4(b, hlo, side, values[hlo->operand(0)]));
std::vector<Value> operands({unpacked});
Expand Down Expand Up @@ -3058,15 +3050,6 @@ absl::Status CreateInternalError(std::string_view message,
return absl::InternalError(err);
}

// Checks whether a value of element type `type` is permitted under
// `debug_options`.
//
// Returns OkStatus for every type other than S4. For S4, returns OkStatus only
// when `xla_gpu_enable_triton_gemm_int4` is set; otherwise returns a
// FailedPrecondition error.
absl::Status DoSupportType(const DebugOptions& debug_options,
                           PrimitiveType type) {
  // Only int4 is gated; all other element types pass unconditionally.
  if (type != S4) {
    return absl::OkStatus();
  }
  if (debug_options.xla_gpu_enable_triton_gemm_int4()) {
    return absl::OkStatus();
  }
  return absl::FailedPreconditionError(
      "Int4 support is not enabled in the debug options.");
}

absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> CreateTritonModule(
absl::string_view fn_name, const HloFusionInstruction* fusion,
const se::DeviceDescription& device_info,
Expand All @@ -3088,7 +3071,6 @@ absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> CreateTritonModule(
SmallVector<Type> fn_arg_types;
for (HloInstruction* p : hlo_computation->parameter_instructions()) {
PrimitiveType type = p->shape().element_type();
TF_RETURN_IF_ERROR(DoSupportType(debug_options, type));
Type ir_type;
if (type == U16) {
ir_type = b.getI16Type();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,6 @@ class TritonGemmTest : public TritonTest {
debug_options.set_xla_gpu_enable_split_k_autotuning(false);
// Always rewrite Gemms with Triton regardless of size.
debug_options.set_xla_gpu_gemm_rewrite_size_threshold(0);
debug_options.set_xla_gpu_enable_triton_gemm_int4(true);
return debug_options;
}

Expand Down
9 changes: 1 addition & 8 deletions xla/service/gpu/fusions/triton/triton_support_legacy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,14 +122,7 @@ CodegenDecision IsInstructionSupportsDataTypes(
const auto operand_type = operand->shape().element_type();
switch (instr.opcode()) {
case HloOpcode::kConvert:
// TODO(b/358580281): remove DebugOptions from this function after
// enabling int4 in Triton GEMM.
if (operand_type == S4 && instr.GetModule()
->config()
.debug_options()
.xla_gpu_enable_triton_gemm_int4()) {
continue;
}
if (operand_type == S4) continue;
[[fallthrough]];
default:
if (!IsTritonSupportedDataType(operand_type, gpu_version)) {
Expand Down
12 changes: 0 additions & 12 deletions xla/service/gpu/transforms/gemm_fusion_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1364,9 +1364,6 @@ TEST_F(SmallDotGemmFusionTest, Int4DotIsRewritten) {
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
ParseAndReturnVerifiedModule(kInt4Dot));
module->mutable_config()
.mutable_debug_options()
.set_xla_gpu_enable_triton_gemm_int4(true);
EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
}

Expand All @@ -1384,9 +1381,6 @@ TEST_F(SmallDotGemmFusionTest, Int4ConcatPlusConvertIsRewritten) {
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
ParseAndReturnVerifiedModule(kInt4Dot));
module->mutable_config()
.mutable_debug_options()
.set_xla_gpu_enable_triton_gemm_int4(true);
EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());

// Check that the fusion is present and that the lhs is not converted.
Expand All @@ -1411,9 +1405,6 @@ TEST_F(SmallDotGemmFusionTest, Int4ConvertPlusNegateIsRewritten) {
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
ParseAndReturnVerifiedModule(kInt4Dot));
module->mutable_config()
.mutable_debug_options()
.set_xla_gpu_enable_triton_gemm_int4(true);
EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
// Check that the fusion is present and that convert and negation is fused in
// it.
Expand All @@ -1440,9 +1431,6 @@ TEST_F(SmallDotGemmFusionTest, Int4WithMinorBatchDimIsNotRewritten) {
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
ParseAndReturnVerifiedModule(kInt4Dot));
module->mutable_config()
.mutable_debug_options()
.set_xla_gpu_enable_triton_gemm_int4(true);
TF_ASSERT_OK_AND_ASSIGN(auto result,
GemmFusion(gpu_version_).Run(module.get()));
EXPECT_FALSE(result);
Expand Down
6 changes: 2 additions & 4 deletions xla/xla.proto
Original file line number Diff line number Diff line change
Expand Up @@ -949,9 +949,6 @@ message DebugOptions {
// If enabled, uses the libnvjitlink library for PTX compilation and linking
bool xla_gpu_enable_libnvjitlink = 319;

// If enabled, generates triton gemm kernels for int4 inputs.
bool xla_gpu_enable_triton_gemm_int4 = 320;

// If true, XLA will wrap `dot` operations into async computations in an
// effort to parallelize matrix operations.
bool xla_gpu_async_dot = 321;
Expand Down Expand Up @@ -1005,7 +1002,8 @@ message DebugOptions {
// xla_gpu_graph_level
// xla_gpu_single_wave_autotuning
// xla_gpu_enable_persistent_temp_buffers
reserved 5, 117, 133, 139, 176, 178, 180, 193, 214, 194, 242, 206;
// xla_gpu_enable_triton_gemm_int4
reserved 5, 117, 133, 139, 176, 178, 180, 193, 214, 194, 242, 206, 320;
}

// Contains flags which affects the GPU compilation result.
Expand Down

0 comments on commit 2f7ce18

Please sign in to comment.