Skip to content

Commit

Permalink
Add reduce-scatter test with list of tensors as input/output
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffhataws committed Dec 6, 2023
1 parent 3d70fc0 commit 69968e5
Showing 1 changed file with 29 additions and 1 deletion.
30 changes: 29 additions & 1 deletion test/test_mp_reduce_scatter.py
Original file line number Diff line number Diff line change
@@ -10,6 +10,7 @@ def _mp_fn(index):
scale = 1 / world_size
scatter_dim = 1
shard_size = 2
input_list_size = 5

if xm.xla_device_hw(device) in ['TPU', 'CUDA']:
rand = torch.rand((32, shard_size * world_size, 32))
@@ -25,8 +26,35 @@ def _mp_fn(index):
expected = expected_world.cpu().index_select(scatter_dim, slice_idx)

assert res.cpu().allclose(expected)

xm.rendezvous('test_reduce_scatter')

# --- reduce-scatter with a list of input tensors ---
# Build `input_list_size` random CPU tensors and move each to the XLA
# device. Dim 1 has size shard_size * world_size, so after scattering
# along scatter_dim each replica should hold a shard_size-wide slice.
rand_list = [
    torch.rand((32, shard_size * world_size, 32))
    for _ in range(input_list_size)
]
xrand_list = [rand.to(device) for rand in rand_list]

# TODO: fix the broken case with pin_layout=True
# List-input form: xm.reduce_scatter is expected to reduce-scatter each
# tensor in xrand_list and return a list of per-replica result shards
# (presumably in the same order as the inputs — confirm against the
# xm.reduce_scatter API).
res_list = xm.reduce_scatter(
    xm.REDUCE_SUM,
    xrand_list,
    scale,
    scatter_dim,
    world_size,
    pin_layout=False)

for i, res in enumerate(res_list):
  # Reference value: all-reduce the i-th full input with the same scale,
  # then slice out this replica's shard along scatter_dim.
  expected_world = xm.all_reduce(xm.REDUCE_SUM, xrand_list[i], scale)
  # Force execution of the pending XLA graph before comparing on CPU.
  xm.mark_step()

  # Rows [index*shard_size, (index+1)*shard_size) belong to this replica.
  slice_idx = torch.tensor(
      list(range(index * shard_size, (index + 1) * shard_size)))
  expected = expected_world.cpu().index_select(scatter_dim, slice_idx)
  assert res.cpu().allclose(expected)

# Barrier so all replicas finish this sub-test before proceeding.
xm.rendezvous('test_reduce_scatter_list_input')

else:
print(
'Default device {} is not a TPU device'.format(device), file=sys.stderr)

0 comments on commit 69968e5

Please sign in to comment.