Skip to content

Commit

Permalink
[XLA] Ensure that the operands of rng bit generator are replicated si…
Browse files Browse the repository at this point in the history
…nce the

spmd partitioner will replicate it anyway.

PiperOrigin-RevId: 678712639
  • Loading branch information
blakehechtman authored and Google-ML-Automation committed Sep 28, 2024
1 parent 3c5c920 commit 50c225d
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 0 deletions.
4 changes: 4 additions & 0 deletions xla/service/sharding_propagation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2714,6 +2714,10 @@ bool ShardingPropagation::InferShardingFromUsers(
bool improved_sharding = false;
const bool may_combine_partial_sharding = is_spmd && aggressiveness > 0;
for (const HloInstruction* user : instruction->users()) {
if (user->opcode() == HloOpcode::kRngBitGenerator) {
instruction->set_sharding(HloSharding::Replicate());
return true;
}
std::optional<HloSharding> user_sharding =
ShardingPropagation::GetShardingFromUser(*instruction, *user,
aggressiveness, is_spmd,
Expand Down
33 changes: 33 additions & 0 deletions xla/service/sharding_propagation_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12195,5 +12195,38 @@ ENTRY main {
op::Sharding("{devices=[2,1,2]<=[4] last_tile_dim_replicate}"));
}

TEST_F(ShardingPropagationTest, ReplicateRngBitGeneratorSeed) {
const char* const hlo_string = R"(
HloModule module
apply_or {
x = u64[] parameter(0)
y = u64[] parameter(1)
ROOT x_or_y = or(x, y)
}
ENTRY main {
p = s32[2,2]{1,0} parameter(0), sharding={devices=[2,2]<=[4]}
up = u64[2,2] convert(p)
i = u64[] constant(0)
seed = u64[2] reduce(up, i), dimensions={1}, to_apply=apply_or
rbg = u32[2048,4096] rng-bit-generator(seed), algorithm=rng_default
ROOT s = u32[2048,4096]{1,0} custom-call(rbg), custom_call_target="Sharding", sharding={devices=[2,2]<=[4]}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(
bool changed,
ShardingPropagation(
/*is_spmd=*/true, /*propagate_metadata=*/true,
/*allow_spmd_sharding_propagation_to_output=*/{true},
/*allow_spmd_sharding_propagation_to_parameters=*/{true})
.Run(module.get()));
EXPECT_TRUE(changed);

XLA_VLOG_LINES(1, module->ToString());
auto* instruction = FindInstruction(module.get(), "seed");
// Check sharding is correctly propagated.
EXPECT_TRUE(instruction->sharding().IsReplicated());
}

} // namespace
} // namespace xla

0 comments on commit 50c225d

Please sign in to comment.