From b301c2a2a5b52180f9e9626173e6b67a78782960 Mon Sep 17 00:00:00 2001 From: Jeffrey Huynh Date: Mon, 9 Oct 2023 21:46:24 +0000 Subject: [PATCH] Change *WithToken tests to *WithTuple --- xla/client/xla_builder_test.cc | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/xla/client/xla_builder_test.cc b/xla/client/xla_builder_test.cc index a74e36962329f..60503f890be01 100644 --- a/xla/client/xla_builder_test.cc +++ b/xla/client/xla_builder_test.cc @@ -418,12 +418,11 @@ TEST_F(XlaBuilderTest, AllGatherR2) { ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {4, 64}))); } -TEST_F(XlaBuilderTest, AllGatherWithToken) { +TEST_F(XlaBuilderTest, AllGatherWithTuple) { XlaBuilder b(TestName()); auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4}), "x"); auto x2 = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {16, 4}), "x2"); - auto t = Parameter(&b, 2, ShapeUtil::MakeScalarShape(F32), "t"); - AllGather(Tuple(&b, {x, x2, t}), /*all_gather_dimension=*/0, + AllGather(Tuple(&b, {x, x2}), /*all_gather_dimension=*/0, /*shard_count=*/4); TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); auto root = module->entry_computation()->root_instruction(); @@ -432,8 +431,7 @@ TEST_F(XlaBuilderTest, AllGatherWithToken) { EXPECT_TRUE(ShapeUtil::Equal( root->shape(), ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {16}), - ShapeUtil::MakeShape(F32, {64, 4}), - ShapeUtil::MakeScalarShape(F32)}))); + ShapeUtil::MakeShape(F32, {64, 4})}))); } TEST_F(XlaBuilderTest, ReduceScatter) { @@ -462,7 +460,7 @@ TEST_F(XlaBuilderTest, ReduceScatter) { ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {4, 8}))); } -TEST_F(XlaBuilderTest, ReduceScatterWithToken) { +TEST_F(XlaBuilderTest, ReduceScatterWithTuple) { XlaBuilder b(TestName()); XlaComputation to_apply; { @@ -476,11 +474,10 @@ TEST_F(XlaBuilderTest, ReduceScatterWithToken) { } auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x"); auto x2 = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {16, 4}), "x2"); - auto t = Parameter(&b, 2, ShapeUtil::MakeScalarShape(F32), "t"); ReplicaGroup group; group.add_replica_ids(0); group.add_replica_ids(1); - ReduceScatter(Tuple(&b, {x, x2, t}), to_apply, /*scatter_dimension=*/1, + ReduceScatter(Tuple(&b, {x, x2}), to_apply, /*scatter_dimension=*/1, /*shard_count=*/2, /*replica_groups=*/{group}); TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); @@ -490,8 +487,7 @@ TEST_F(XlaBuilderTest, ReduceScatterWithToken) { EXPECT_TRUE(ShapeUtil::Equal( root->shape(), ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4, 8}), - ShapeUtil::MakeShape(F32, {16, 2}), - ShapeUtil::MakeScalarShape(F32)}))); + ShapeUtil::MakeShape(F32, {16, 2})}))); } TEST_F(XlaBuilderTest, AllToAll) {