From d247af5cd33fe42698bb55ef1c18f32df8a02a21 Mon Sep 17 00:00:00 2001 From: scxfjiang Date: Mon, 9 Sep 2024 06:15:24 -0500 Subject: [PATCH] refactor tests for collective comm ops --- xla/tests/collective_ops_test.cc | 249 +++++++++++++------------------ 1 file changed, 102 insertions(+), 147 deletions(-) diff --git a/xla/tests/collective_ops_test.cc b/xla/tests/collective_ops_test.cc index 265d2a2ce5cfe..4524cf9b2efb9 100644 --- a/xla/tests/collective_ops_test.cc +++ b/xla/tests/collective_ops_test.cc @@ -1753,153 +1753,6 @@ XLA_TEST_F(CollectiveOpsTest, AllReduceBFloat16Min) { } } -XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllGather_8BitFloat)) { - const char* const kModuleStr = R"( - HloModule test - ENTRY test_computation { - a0 = f8e4m3fn[1,2] constant({{1,2}}) - allgather = f8e4m3fn[2, 2] all-gather(a0), dimensions={0} - p = f8e4m3fn[4] reshape(allgather) - ROOT out = f32[4] convert(p) - } - )"; - const int64_t kNumReplicas = 2; - HloModuleConfig config = - GetModuleConfigForTest(/*replica_count=*/kNumReplicas); - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(kModuleStr, config)); - TF_ASSERT_OK_AND_ASSIGN( - std::vector results, - ExecuteReplicated(std::move(module), absl::Span{}, - kNumReplicas, - /*use_threads=*/true, /*run_hlo_passes=*/true)); - ASSERT_EQ(results.size(), kNumReplicas); - for (const Literal& result : results) { - LiteralTestUtil::ExpectR1Equal({1, 2, 1, 2}, result); - } -} - -XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllToAll_8BitFloat)) { - const char* const kModuleStr = R"( - HloModule test - ENTRY test_computation { - a0 = f8e4m3fn[2] constant({1,2}) - a2a = f8e4m3fn[2] all-to-all(a0), dimensions={0} - ROOT out = f32[2] convert(a2a) - } - )"; - const int64_t kNumReplicas = 2; - HloModuleConfig config = - GetModuleConfigForTest(/*replica_count=*/kNumReplicas); - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(kModuleStr, config)); - TF_ASSERT_OK_AND_ASSIGN( - std::vector results, - ExecuteReplicated(std::move(module), absl::Span{}, - kNumReplicas, - /*use_threads=*/true, /*run_hlo_passes=*/true)); - ASSERT_EQ(results.size(), kNumReplicas); - LiteralTestUtil::ExpectR1Equal({1, 1}, results[0]); - LiteralTestUtil::ExpectR1Equal({2, 2}, results[1]); -} - -XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(CollectivePermute_8BitFloat)) { - const char* const kModuleStr = R"( - HloModule test - ENTRY test_computation { - a0 = f8e5m2[2] constant({1,2}) - a1 = f8e5m2[2] collective-permute(a0), source_target_pairs={{0,1}, {1,0}} - ROOT out = f32[2] convert(a1) - } - )"; - const int64_t kNumReplicas = 2; - HloModuleConfig config = - GetModuleConfigForTest(/*replica_count=*/kNumReplicas); - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(kModuleStr, config)); - TF_ASSERT_OK_AND_ASSIGN( - std::vector results, - ExecuteReplicated(std::move(module), absl::Span{}, - kNumReplicas, - /*use_threads=*/true, /*run_hlo_passes=*/true)); - ASSERT_EQ(results.size(), kNumReplicas); - LiteralTestUtil::ExpectR1Equal({1, 2}, results[0]); - LiteralTestUtil::ExpectR1Equal({1, 2}, results[1]); -} - -// TODO: Refactor the test to reduce the duplicate code for OCP fp8 and Nanoo fp8 -XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllGather_8BitFloat_Nanoo)) { - const char* const kModuleStr = R"( - HloModule test - ENTRY test_computation { - a0 = f8e4m3fnuz[1,2] constant({{1,2}}) - allgather = f8e4m3fnuz[2, 2] all-gather(a0), dimensions={0} - p = f8e4m3fnuz[4] reshape(allgather) - ROOT out = f32[4] convert(p) - } - )"; - const int64_t kNumReplicas = 2; - HloModuleConfig config = - GetModuleConfigForTest(/*replica_count=*/kNumReplicas); - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(kModuleStr, config)); - TF_ASSERT_OK_AND_ASSIGN( - std::vector results, - ExecuteReplicated(std::move(module), {}, kNumReplicas, - /*use_threads=*/true, /*run_hlo_passes=*/true)); - ASSERT_EQ(results.size(), kNumReplicas); - for (const Literal& result : results) { - LiteralTestUtil::ExpectR1Equal({1, 2, 1, 2}, result); - } -} - -XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllToAll_8BitFloat_Nanoo)) { - const char* const kModuleStr = R"( - HloModule test - ENTRY test_computation { - a0 = f8e4m3fnuz[2] constant({1,2}) - a2a = f8e4m3fnuz[2] all-to-all(a0), dimensions={0} - ROOT out = f32[2] convert(a2a) - } - )"; - const int64_t kNumReplicas = 2; - HloModuleConfig config = - GetModuleConfigForTest(/*replica_count=*/kNumReplicas); - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(kModuleStr, config)); - TF_ASSERT_OK_AND_ASSIGN( - std::vector results, - ExecuteReplicated(std::move(module), {}, kNumReplicas, - /*use_threads=*/true, /*run_hlo_passes=*/true)); - ASSERT_EQ(results.size(), kNumReplicas); - LiteralTestUtil::ExpectR1Equal({1, 1}, results[0]); - LiteralTestUtil::ExpectR1Equal({2, 2}, results[1]); -} - -XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(CollectivePermute_8BitFloat_Nanoo)) { - const char* const kModuleStr = R"( - HloModule test - ENTRY test_computation { - a0 = f8e5m2fnuz[2] constant({1,2}) - a1 = f8e5m2fnuz[2] collective-permute(a0), source_target_pairs={{0,1}, {1,0}} - ROOT out = f32[2] convert(a1) - } - )"; - const int64_t kNumReplicas = 2; - HloModuleConfig config = - GetModuleConfigForTest(/*replica_count=*/kNumReplicas); - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(kModuleStr, config)); - TF_ASSERT_OK_AND_ASSIGN( - std::vector results, - ExecuteReplicated(std::move(module), {}, kNumReplicas, - /*use_threads=*/true, /*run_hlo_passes=*/true)); - ASSERT_EQ(results.size(), kNumReplicas); - LiteralTestUtil::ExpectR1Equal({1, 2}, results[0]); - LiteralTestUtil::ExpectR1Equal({1, 2}, results[1]); -} - - XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AsyncAllGather)) { const char* const kModuleStr = R"( HloModule test @@ -2346,5 +2199,107 @@ body { results[1])); } +class Fp8CollectiveOpsTest : public CollectiveOpsTest { + public: + Fp8CollectiveOpsTest() { + replacements_[kF8E4M3DatatypePlaceholder] = +#if GOOGLE_CUDA + "f8e4m3fn"; +#else + "f8e4m3fnuz"; +#endif + replacements_[kF8E5M2DatatypePlaceholder] = +#if GOOGLE_CUDA + "f8e5m2"; +#else + "f8e5m2fnuz"; +#endif + } + + protected: + absl::flat_hash_map replacements_; + + private: + static constexpr const char* kF8E4M3DatatypePlaceholder{"<>"}; + static constexpr const char* kF8E5M2DatatypePlaceholder{"<>"}; +}; + +XLA_TEST_F(Fp8CollectiveOpsTest, DISABLED_ON_CPU(AllGather_8BitFloat)) { + const char* const kModuleStr = R"( + HloModule test + ENTRY test_computation { + a0 = <>[1,2] constant({{1,2}}) + allgather = <>[2, 2] all-gather(a0), dimensions={0} + p = <>[4] reshape(allgather) + ROOT out = f32[4] convert(p) + } + )"; + const int64_t kNumReplicas = 2; + HloModuleConfig config = + GetModuleConfigForTest(/*replica_count=*/kNumReplicas); + TF_ASSERT_OK_AND_ASSIGN( + auto module, ParseAndReturnVerifiedModule( + absl::StrReplaceAll(kModuleStr, replacements_), config)); + TF_ASSERT_OK_AND_ASSIGN( + std::vector results, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, + /*use_threads=*/true, /*run_hlo_passes=*/true)); + ASSERT_EQ(results.size(), kNumReplicas); + for (const Literal& result : results) { + LiteralTestUtil::ExpectR1Equal({1, 2, 1, 2}, result); + } +} + +XLA_TEST_F(Fp8CollectiveOpsTest, DISABLED_ON_CPU(AllToAll_8BitFloat)) { + const char* const kModuleStr = R"( + HloModule test + ENTRY test_computation { + a0 = <>[2] constant({1,2}) + a2a = <>[2] all-to-all(a0), dimensions={0} + ROOT out = f32[2] convert(a2a) + } + )"; + const int64_t kNumReplicas = 2; + HloModuleConfig config = + GetModuleConfigForTest(/*replica_count=*/kNumReplicas); + TF_ASSERT_OK_AND_ASSIGN( + auto module, ParseAndReturnVerifiedModule( + absl::StrReplaceAll(kModuleStr, replacements_), config)); + TF_ASSERT_OK_AND_ASSIGN( + std::vector results, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, + /*use_threads=*/true, /*run_hlo_passes=*/true)); + ASSERT_EQ(results.size(), kNumReplicas); + LiteralTestUtil::ExpectR1Equal({1, 1}, results[0]); + LiteralTestUtil::ExpectR1Equal({2, 2}, results[1]); +} + +XLA_TEST_F(Fp8CollectiveOpsTest, DISABLED_ON_CPU(CollectivePermute_8BitFloat)) { + const char* const kModuleStr = R"( + HloModule test + ENTRY test_computation { + a0 = <>[2] constant({1,2}) + a1 = <>[2] collective-permute(a0), source_target_pairs={{0,1}, {1,0}} + ROOT out = f32[2] convert(a1) + } + )"; + const int64_t kNumReplicas = 2; + HloModuleConfig config = + GetModuleConfigForTest(/*replica_count=*/kNumReplicas); + TF_ASSERT_OK_AND_ASSIGN( + auto module, ParseAndReturnVerifiedModule( + absl::StrReplaceAll(kModuleStr, replacements_), config)); + TF_ASSERT_OK_AND_ASSIGN( + std::vector results, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, + /*use_threads=*/true, /*run_hlo_passes=*/true)); + ASSERT_EQ(results.size(), kNumReplicas); + LiteralTestUtil::ExpectR1Equal({1, 2}, results[0]); + LiteralTestUtil::ExpectR1Equal({1, 2}, results[1]); +} + } // namespace } // namespace xla