Refactor bucketized all-gather/reduce-scatter functions; add bucket_cap_mb arg
jeffhataws authored and EC2 Default User committed Mar 20, 2024
1 parent 8586370 commit 675e7a1
Showing 4 changed files with 294 additions and 140 deletions.
36 changes: 36 additions & 0 deletions test/test_mp_all_gather.py
@@ -95,6 +95,42 @@ def _mp_fn(index):
file=sys.stderr)
print(f'[{index}] {cpu_result}', file=sys.stderr)
sys.exit(1)

# Testing with a single replica group and tensor list as input (Bucketized)
# TODO: add support for list input with pin_layout=True and output=None
result_list = xm.all_gather_bucketized(
ordinal_tensors, dim=0, pin_layout=False)

for i, result in enumerate(result_list):
cpu_result = result.cpu()
expected = i * 1000 + torch.arange(world_size, dtype=torch.float)
if not cpu_result.allclose(expected):
print(
f'xm.all_gather_bucketized() produced wrong reductions for item {i} in result list',
file=sys.stderr)
print(f'[{index}] {cpu_result}', file=sys.stderr)
sys.exit(1)

# Testing with a single replica group and tensor list as input and output!=None (out-of-place) (Bucketized)
# Reuse ordinal_tensors from previous test
output_tensors = [
torch.zeros([world_size], dtype=torch.float).to(device)
for i in range(input_list_size)
]
# TODO: add support for list input with pin_layout=True and output!=None
result_list = xm.all_gather_bucketized(
ordinal_tensors, dim=0, output=output_tensors, pin_layout=False)

for i, result in enumerate(result_list):
cpu_result = result.cpu()
expected = i * 1000 + torch.arange(world_size, dtype=torch.float)
if not cpu_result.allclose(expected):
print(
f'xm.all_gather_bucketized() produced wrong reductions for item {i} in result list',
file=sys.stderr)
print(f'[{index}] {cpu_result}', file=sys.stderr)
sys.exit(1)

# TODO: add test for torch.compile when support for list input is ready

else:
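A minimal usage sketch of the bucketized all-gather exercised above (not part of the diff). It assumes the signature xm.all_gather_bucketized(input_list, dim=0, groups=None, output=None, pin_layout=..., bucket_cap_mb=...), where bucket_cap_mb is the new argument from this commit; the 160 below is an arbitrary example value, not a documented default.

import torch
import torch_xla.core.xla_model as xm

def _mp_fn(index):
  device = xm.xla_device()

  # One small tensor per list entry; each replica contributes its ordinal.
  inputs = [
      (i * 1000 + torch.tensor([index], dtype=torch.float)).to(device)
      for i in range(4)
  ]

  # bucket_cap_mb coalesces the list into buckets of at most that many
  # megabytes, so several small tensors share one collective launch.
  gathered = xm.all_gather_bucketized(
      inputs, dim=0, pin_layout=False, bucket_cap_mb=160)
  # gathered[i] should equal i * 1000 + torch.arange(world_size) across replicas.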
56 changes: 56 additions & 0 deletions test/test_mp_reduce_scatter.py
@@ -55,6 +55,33 @@ def _mp_fn(index):

xm.rendezvous('test_reduce_scatter_list_input')

# Testing reduce-scatter with list input (bucketized)
rand_list = [
torch.rand((32, shard_size * world_size, 32))
for _ in range(input_list_size)
]
xrand_list = [rand.to(device) for rand in rand_list]

# TODO: fix the broken case with pin_layout=True
res_list = xm.reduce_scatter_bucketized(
xm.REDUCE_SUM,
xrand_list,
scale,
scatter_dim,
world_size,
pin_layout=False)

for i, res in enumerate(res_list):
expected_world = xm.all_reduce(xm.REDUCE_SUM, xrand_list[i], scale)
xm.mark_step()

slice_idx = torch.tensor(
list(range(index * shard_size, (index + 1) * shard_size)))
expected = expected_world.cpu().index_select(scatter_dim, slice_idx)
assert res.cpu().allclose(expected)

xm.rendezvous('test_reduce_scatter_list_input_bucketized')

# Testing reduce-scatter with list input and output
output_list = [
torch.rand((32, shard_size * world_size, 32))
@@ -83,6 +110,35 @@ def _mp_fn(index):
assert res.cpu().allclose(expected)

xm.rendezvous('test_reduce_scatter_list_input_output')

# Testing reduce-scatter with list input and output (bucketized)
output_list = [
torch.rand((32, shard_size * world_size, 32))
for _ in range(input_list_size)
]
xoutput_list = [output.to(device) for output in output_list]

# TODO: fix the broken case with pin_layout=True
res_list = xm.reduce_scatter_bucketized(
xm.REDUCE_SUM,
xrand_list,
scale,
scatter_dim,
world_size,
output=xoutput_list,
pin_layout=False)

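# reduce_scatter_bucketized is expected to return the provided output list.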
assert (xoutput_list == res_list)
for i, res in enumerate(xoutput_list):
expected_world = xm.all_reduce(xm.REDUCE_SUM, xrand_list[i], scale)
xm.mark_step()

slice_idx = torch.tensor(
list(range(index * shard_size, (index + 1) * shard_size)))
expected = expected_world.cpu().index_select(scatter_dim, slice_idx)
assert res.cpu().allclose(expected)

xm.rendezvous('test_reduce_scatter_list_input_output_bucketized')
else:
print(
'Default device {} is not a TPU device'.format(device), file=sys.stderr)
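For intuition, an illustrative sketch of the bucketing these tests exercise, not the actual torch_xla implementation: inputs are greedily grouped into buckets of at most bucket_cap_mb megabytes, and one collective is issued per bucket rather than per tensor. The helper name make_buckets is hypothetical.

import torch

def make_buckets(tensors, bucket_cap_mb):
  """Greedily group tensors into buckets of at most bucket_cap_mb MB."""
  cap_bytes = bucket_cap_mb * 1024 * 1024
  buckets, current, current_bytes = [], [], 0
  for t in tensors:
    nbytes = t.numel() * t.element_size()
    # Close the current bucket if this tensor would push it past the cap;
    # an oversized tensor still gets a bucket of its own.
    if current and current_bytes + nbytes > cap_bytes:
      buckets.append(current)
      current, current_bytes = [], 0
    current.append(t)
    current_bytes += nbytes
  if current:
    buckets.append(current)
  return buckets

# Example: three 4 MB float32 tensors with an 8 MB cap form two buckets.
ts = [torch.zeros(1024 * 1024) for _ in range(3)]
assert [len(b) for b in make_buckets(ts, 8)] == [2, 1]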