Skip to content

Commit

Permalink
Add test for out-of-place reduce-scatter coalesced
Browse files Browse the repository at this point in the history
  • Loading branch information
Arjunbala committed Dec 10, 2023
1 parent af410c0 commit 87fd463
Showing 1 changed file with 28 additions and 0 deletions.
28 changes: 28 additions & 0 deletions test/test_mp_reduce_scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,34 @@ def _mp_fn(index):

xm.rendezvous('test_reduce_scatter_list_input')

# Testing reduce-scatter with list input and output
output_list = [
torch.rand((32, shard_size * world_size, 32))
for _ in range(input_list_size)
]
xoutput_list = [output.to(device) for output in output_list]

# TODO: fix the broken case with pin_layout=True
res_list = xm.reduce_scatter(
xm.REDUCE_SUM,
xrand_list,
scale,
scatter_dim,
world_size,
output=xoutput_list,
pin_layout=False)

assert (xoutput_list == res_list)
for i, res in enumerate(xoutput_list):
expected_world = xm.all_reduce(xm.REDUCE_SUM, xrand_list[i], scale)
xm.mark_step()

slice_idx = torch.tensor(
list(range(index * shard_size, (index + 1) * shard_size)))
expected = expected_world.cpu().index_select(scatter_dim, slice_idx)
assert res.cpu().allclose(expected)

xm.rendezvous('test_reduce_scatter_list_input_output')
else:
print(
'Default device {} is not a TPU device'.format(device), file=sys.stderr)
Expand Down

0 comments on commit 87fd463

Please sign in to comment.