Skip to content

Commit

Permalink
add an option to disable all async collectives using a single op
Browse files Browse the repository at this point in the history
  • Loading branch information
Tixxx committed Feb 4, 2025
1 parent 8eb1817 commit d561e61
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 9 deletions.
9 changes: 9 additions & 0 deletions xla/debug_options_flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -699,6 +699,15 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
if (absl::c_all_of(values, is_collective_type)) {
debug_options->clear_xla_gpu_disable_async_collectives();
for (const absl::string_view value : values) {
auto parsed_op = parse_collective_type(value);
if (parsed_op == DebugOptions::ALLCOLLECTIVES) {
for (int i = (int)DebugOptions::ALLREDUCE;
i < (int)DebugOptions::ALLCOLLECTIVES; i++) {
debug_options->add_xla_gpu_disable_async_collectives(
(DebugOptions::CollectiveOpType)i);
}
return true;
}
debug_options->add_xla_gpu_disable_async_collectives(
parse_collective_type(value));
}
Expand Down
70 changes: 61 additions & 9 deletions xla/tests/collective_ops_e2e_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,13 @@ class CollectiveOpsTestE2E : public HloTestBase {
num_replicas, /*run_hlo_passes=*/false, &device_assignment);
}

bool IsAsync(const HloInstruction* inst) {
return !inst->backend_config<gpu::GpuBackendConfig>()
.value()
.collective_backend_config()
.is_sync();
}

protected:
absl::flat_hash_map<absl::string_view, absl::string_view> replacements_;

Expand Down Expand Up @@ -181,16 +188,7 @@ class AsyncCollectiveOps : public CollectiveOpsTestE2E,
return CreateExecutable(std::move(module),
/*run_hlo_passes=*/true);
}

using CollectiveOpsTestE2E::CreateExecutable;

bool IsAsync(const HloInstruction* inst) {
return !inst->backend_config<gpu::GpuBackendConfig>()
.value()
.collective_backend_config()
.is_sync();
}

const int64_t num_devices_;
};

Expand Down Expand Up @@ -974,6 +972,60 @@ TEST_F(CollectiveOpsTestE2E, NoAllToAllDecomposition) {
LiteralTestUtil::ExpectR1Equal<uint32_t>({20, 25, 21, 26}, results[1]);
}

// Verify that collectives won't be transformed into async ones.
// This needs to have:
// --test_env=XLA_FLAGS="--xla_gpu_disable_async_collectives=ALLCOLLECTIVES"
// set in the bazel command.
TEST_F(CollectiveOpsTestE2E, NoAsyncCollectives) {
const absl::string_view kModuleStr = R"(
HloModule test
apply_op {
x = u32[] parameter(0)
y = u32[] parameter(1)
ROOT apply_op = u32[] add(x, y)
}
ENTRY test_computation {
id = u32[] replica-id()
id2 = u32[2, 2] broadcast(id), dimensions={}
a0 = u32[2, 2] constant({{10, 15}, {20, 25}})
a1 = u32[2, 2] add(id2, a0)
all2all = u32[2, 2] all-to-all(a1), replica_groups={{0,1}}, dimensions={0}
ROOT ag = u32[2, 2] all-reduce(all2all), replica_groups={{0,1}}, to_apply=apply_op
}
)";
const int64_t kNumReplicas = 2;
if (test_runner().device_count() < kNumReplicas) {
GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices ("
<< test_runner().device_count() << " available)";
}

HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
config.mutable_debug_options().add_xla_disable_hlo_passes(
"gpu-convert-async-collectives-to-sync");

TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(kModuleStr, config));

TF_ASSERT_OK_AND_ASSIGN(
auto executable,
CreateExecutable(std::move(module), /*run_hlo_passes=*/true));
ASSERT_TRUE(executable->has_module());
HloModule* executable_module = &executable->module();

// Verify that the all-to-all is a sync collective.
const HloInstruction* all_to_all =
FindInstruction(executable_module, HloOpcode::kAsyncStart);
EXPECT_FALSE(IsAsync(all_to_all));

// Verify that the all-reduce is a sync collective.
const HloInstruction* all_reduce =
FindInstruction(executable_module, HloOpcode::kAllReduceStart);

EXPECT_FALSE(IsAsync(all_reduce));
}

// E2E tests comparing the results of windowed einsum and non-windowed cases.
class CollectiveOpsTestE2EWindowedNonWindowed : public CollectiveOpsTestE2E {
public:
Expand Down
2 changes: 2 additions & 0 deletions xla/xla.proto
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ message DebugOptions {
ALLTOALL = 5;
COLLECTIVEPERMUTE = 6;
RAGGEDALLTOALL = 7;
// Add more collectives before ALL
ALLCOLLECTIVES = 8;
}

// Commands are categorized into 5 types:
Expand Down

0 comments on commit d561e61

Please sign in to comment.