Skip to content

Commit

Permalink
Change *WithToken tests to *WithTuple
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffhataws committed Oct 9, 2023
1 parent 32e8145 commit b301c2a
Showing 1 changed file with 6 additions and 10 deletions.
16 changes: 6 additions & 10 deletions xla/client/xla_builder_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -418,12 +418,11 @@ TEST_F(XlaBuilderTest, AllGatherR2) {
ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {4, 64})));
}

TEST_F(XlaBuilderTest, AllGatherWithToken) {
TEST_F(XlaBuilderTest, AllGatherWithTuple) {
XlaBuilder b(TestName());
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4}), "x");
auto x2 = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {16, 4}), "x2");
auto t = Parameter(&b, 2, ShapeUtil::MakeScalarShape(F32), "t");
AllGather(Tuple(&b, {x, x2, t}), /*all_gather_dimension=*/0,
AllGather(Tuple(&b, {x, x2}), /*all_gather_dimension=*/0,
/*shard_count=*/4);
TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
auto root = module->entry_computation()->root_instruction();
Expand All @@ -432,8 +431,7 @@ TEST_F(XlaBuilderTest, AllGatherWithToken) {
EXPECT_TRUE(ShapeUtil::Equal(
root->shape(),
ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {16}),
ShapeUtil::MakeShape(F32, {64, 4}),
ShapeUtil::MakeScalarShape(F32)})));
ShapeUtil::MakeShape(F32, {64, 4})})));
}

TEST_F(XlaBuilderTest, ReduceScatter) {
Expand Down Expand Up @@ -462,7 +460,7 @@ TEST_F(XlaBuilderTest, ReduceScatter) {
ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {4, 8})));
}

TEST_F(XlaBuilderTest, ReduceScatterWithToken) {
TEST_F(XlaBuilderTest, ReduceScatterWithTuple) {
XlaBuilder b(TestName());
XlaComputation to_apply;
{
Expand All @@ -476,11 +474,10 @@ TEST_F(XlaBuilderTest, ReduceScatterWithToken) {
}
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x");
auto x2 = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {16, 4}), "x2");
auto t = Parameter(&b, 2, ShapeUtil::MakeScalarShape(F32), "t");
ReplicaGroup group;
group.add_replica_ids(0);
group.add_replica_ids(1);
ReduceScatter(Tuple(&b, {x, x2, t}), to_apply, /*scatter_dimension=*/1,
ReduceScatter(Tuple(&b, {x, x2}), to_apply, /*scatter_dimension=*/1,
/*shard_count=*/2,
/*replica_groups=*/{group});
TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
Expand All @@ -490,8 +487,7 @@ TEST_F(XlaBuilderTest, ReduceScatterWithToken) {
EXPECT_TRUE(ShapeUtil::Equal(
root->shape(),
ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4, 8}),
ShapeUtil::MakeShape(F32, {16, 2}),
ShapeUtil::MakeScalarShape(F32)})));
ShapeUtil::MakeShape(F32, {16, 2})})));
}

TEST_F(XlaBuilderTest, AllToAll) {
Expand Down

0 comments on commit b301c2a

Please sign in to comment.