From 50c225d573f4d67c624bf13973c4e0d35bb6546b Mon Sep 17 00:00:00 2001 From: Blake Hechtman Date: Wed, 25 Sep 2024 08:31:16 -0700 Subject: [PATCH] [XLA] Ensure that the operands of rng bit generator are replicated since the spmd partitioner will replicate it anyway. PiperOrigin-RevId: 678712639 --- xla/service/sharding_propagation.cc | 4 +++ xla/service/sharding_propagation_test.cc | 33 ++++++++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/xla/service/sharding_propagation.cc b/xla/service/sharding_propagation.cc index bf44306a1cccd0..2c58d92a9d2a07 100644 --- a/xla/service/sharding_propagation.cc +++ b/xla/service/sharding_propagation.cc @@ -2714,6 +2714,10 @@ bool ShardingPropagation::InferShardingFromUsers( bool improved_sharding = false; const bool may_combine_partial_sharding = is_spmd && aggressiveness > 0; for (const HloInstruction* user : instruction->users()) { + if (user->opcode() == HloOpcode::kRngBitGenerator) { + instruction->set_sharding(HloSharding::Replicate()); + return true; + } std::optional user_sharding = ShardingPropagation::GetShardingFromUser(*instruction, *user, aggressiveness, is_spmd, diff --git a/xla/service/sharding_propagation_test.cc b/xla/service/sharding_propagation_test.cc index 5ca4b47d8ea15c..96303a8d1b6880 100644 --- a/xla/service/sharding_propagation_test.cc +++ b/xla/service/sharding_propagation_test.cc @@ -12195,5 +12195,38 @@ ENTRY main { op::Sharding("{devices=[2,1,2]<=[4] last_tile_dim_replicate}")); } +TEST_F(ShardingPropagationTest, ReplicateRngBitGeneratorSeed) { + const char* const hlo_string = R"( +HloModule module +apply_or { + x = u64[] parameter(0) + y = u64[] parameter(1) + ROOT x_or_y = or(x, y) +} +ENTRY main { + p = s32[2,2]{1,0} parameter(0), sharding={devices=[2,2]<=[4]} + up = u64[2,2] convert(p) + i = u64[] constant(0) + seed = u64[2] reduce(up, i), dimensions={1}, to_apply=apply_or + rbg = u32[2048,4096] rng-bit-generator(seed), algorithm=rng_default + ROOT s = u32[2048,4096]{1,0} custom-call(rbg), custom_call_target="Sharding", sharding={devices=[2,2]<=[4]} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + ShardingPropagation( + /*is_spmd=*/true, /*propagate_metadata=*/true, + /*allow_spmd_sharding_propagation_to_output=*/{true}, + /*allow_spmd_sharding_propagation_to_parameters=*/{true}) + .Run(module.get())); + EXPECT_TRUE(changed); + + XLA_VLOG_LINES(1, module->ToString()); + auto* instruction = FindInstruction(module.get(), "seed"); + // Check sharding is correctly propagated. + EXPECT_TRUE(instruction->sharding().IsReplicated()); +} + } // namespace } // namespace xla