ZeRO1: Add bucketing logic to control the size of tensors for all-gather/reduce-scatter #6025

Merged · 15 commits · Mar 22, 2024
60 changes: 60 additions & 0 deletions test/test_mp_all_gather.py
@@ -95,6 +95,66 @@ def _mp_fn(index):
file=sys.stderr)
print(f'[{index}] {cpu_result}', file=sys.stderr)
sys.exit(1)

# Testing with a single replica group and tensor list as input (Bucketized)
# TODO: add support for list input with pin_layout=True and output=None
result_list = xm.all_gather_bucketized(
ordinal_tensors, dim=0, pin_layout=False)

for i, result in enumerate(result_list):
cpu_result = result.cpu()
expected = i * 1000 + torch.arange(world_size, dtype=torch.float)
if not cpu_result.allclose(expected):
print(
f'xm.all_gather_bucketized() produced wrong reductions for item {i} in result list',
file=sys.stderr)
print(f'[{index}] {cpu_result}', file=sys.stderr)
sys.exit(1)

# Testing with a single replica group and tensor list as input and output!=None (out-of-place) (Bucketized)
# Reuse ordinal_tensors from previous test
output_tensors = [
torch.zeros([world_size], dtype=torch.float).to(device)
for i in range(input_list_size)
]
# TODO: add support for list input with pin_layout=True and output!=None
result_list = xm.all_gather_bucketized(
ordinal_tensors, dim=0, output=output_tensors, pin_layout=False)

for i, result in enumerate(result_list):
cpu_result = result.cpu()
expected = i * 1000 + torch.arange(world_size, dtype=torch.float)
if not cpu_result.allclose(expected):
print(
f'xm.all_gather_bucketized() produced wrong reductions for item {i} in result list',
file=sys.stderr)
print(f'[{index}] {cpu_result}', file=sys.stderr)
sys.exit(1)

# Testing with a single replica group and tensor list as input and output!=None (out-of-place) (Bucketized, zero bucket size)
# Reuse ordinal_tensors from previous test
output_tensors = [
torch.zeros([world_size], dtype=torch.float).to(device)
for i in range(input_list_size)
]
# TODO: add support for list input with pin_layout=True and output!=None
result_list = xm.all_gather_bucketized(
ordinal_tensors,
dim=0,
output=output_tensors,
pin_layout=False,
bucket_cap_mb=0)

for i, result in enumerate(result_list):
cpu_result = result.cpu()
expected = i * 1000 + torch.arange(world_size, dtype=torch.float)
if not cpu_result.allclose(expected):
print(
f'xm.all_gather_bucketized() produced wrong reductions for item {i} in result list',
file=sys.stderr)
print(f'[{index}] {cpu_result}', file=sys.stderr)
sys.exit(1)

# TODO: add test for torch.compile when support for list input is ready

else:
87 changes: 87 additions & 0 deletions test/test_mp_reduce_scatter.py
@@ -55,6 +55,33 @@ def _mp_fn(index):

xm.rendezvous('test_reduce_scatter_list_input')

# Testing reduce-scatter with list input bucketized
rand_list = [
torch.rand((32, shard_size * world_size, 32))
for _ in range(input_list_size)
]
xrand_list = [rand.to(device) for rand in rand_list]

# TODO: fix the broken case with pin_layout=True
res_list = xm.reduce_scatter_bucketized(
xm.REDUCE_SUM,
xrand_list,
scale,
scatter_dim,
world_size,
pin_layout=False)

for i, res in enumerate(res_list):
expected_world = xm.all_reduce(xm.REDUCE_SUM, xrand_list[i], scale)
xm.mark_step()

slice_idx = torch.tensor(
list(range(index * shard_size, (index + 1) * shard_size)))
expected = expected_world.cpu().index_select(scatter_dim, slice_idx)
assert res.cpu().allclose(expected)

xm.rendezvous('test_reduce_scatter_list_input_bucketized')

# Testing reduce-scatter with list input and output
output_list = [
torch.rand((32, shard_size * world_size, 32))
@@ -83,6 +110,66 @@ def _mp_fn(index):
assert res.cpu().allclose(expected)

xm.rendezvous('test_reduce_scatter_list_input_output')

# Testing reduce-scatter with list input and output (bucketized)
output_list = [
torch.rand((32, shard_size * world_size, 32))
for _ in range(input_list_size)
]
xoutput_list = [output.to(device) for output in output_list]

# TODO: fix the broken case with pin_layout=True
res_list = xm.reduce_scatter_bucketized(
xm.REDUCE_SUM,
xrand_list,
scale,
scatter_dim,
world_size,
output=xoutput_list,
pin_layout=False)

assert (xoutput_list == res_list)
for i, res in enumerate(xoutput_list):
expected_world = xm.all_reduce(xm.REDUCE_SUM, xrand_list[i], scale)
xm.mark_step()

slice_idx = torch.tensor(
list(range(index * shard_size, (index + 1) * shard_size)))
expected = expected_world.cpu().index_select(scatter_dim, slice_idx)
assert res.cpu().allclose(expected)

xm.rendezvous('test_reduce_scatter_list_input_output_bucketized')

# Testing reduce-scatter with list input and output (bucketized, but zero bucket size)
output_list = [
torch.rand((32, shard_size * world_size, 32))
for _ in range(input_list_size)
]
xoutput_list = [output.to(device) for output in output_list]

# TODO: fix the broken case with pin_layout=True
res_list = xm.reduce_scatter_bucketized(
xm.REDUCE_SUM,
xrand_list,
scale,
scatter_dim,
world_size,
output=xoutput_list,
bucket_cap_mb=0,
pin_layout=False)

assert (xoutput_list == res_list)
for i, res in enumerate(xoutput_list):
expected_world = xm.all_reduce(xm.REDUCE_SUM, xrand_list[i], scale)
xm.mark_step()

slice_idx = torch.tensor(
list(range(index * shard_size, (index + 1) * shard_size)))
expected = expected_world.cpu().index_select(scatter_dim, slice_idx)
assert res.cpu().allclose(expected)

xm.rendezvous(
'test_reduce_scatter_list_input_output_bucketized, zero bucket size')
Collaborator (Author) commented:
Hi @JackCaoG, does rendezvous allow comma and space in the rendezvous key? How come this didn't error out?

Collaborator (Author) commented:
If this is not a concern, we can merge this PR.

Collaborator commented:
Looking at the implementation of xla_rendezvous, I think the tag gets ignored, so it doesn't really matter:

def xla_rendezvous(payload: bytes = b'',
                   ordinals: Optional[List[int]] = None,
                   tag: Optional[str] = None) -> List[bytes]:
  """Share `payload` with all replicas in `ordinals`.

  `tag` is ignored except for logging.

else:
print(
'Default device {} is not a TPU device'.format(device), file=sys.stderr)
148 changes: 148 additions & 0 deletions torch_xla/core/xla_model.py
100755 → 100644
@@ -633,6 +633,110 @@ def all_gather(value, dim=0, groups=None, output=None, pin_layout=True):
f"given {type(value)}.")


class CoalescingBuckets(object):

def __init__(self, func, input_list, output_list=None, bucket_cap_mb=160):
if not isinstance(input_list, list) or any(
not isinstance(v, torch.Tensor) for v in input_list):
raise TypeError(
f"`input_list` needs to be a list of Tensors, but given {type(input_list)}."
)
if output_list is not None:
if not isinstance(output_list, list) or any(
not isinstance(v, torch.Tensor) for v in output_list):
raise TypeError(
f"`output_list` needs to be a list of Tensors, but given {type(output_list)}."
)
if len(output_list) != len(input_list):
raise ValueError(
"`output_list` length doesn't match `input_list` length: "
f"{len(output_list)} vs {len(input_list)}.")
self._func = func
self._input_list = input_list
self._output_list = output_list
self._total = 0
self._tensor_bucket = []
self._output_bucket = [] if output_list else None
self._bucket_cap = bucket_cap_mb * 1024 * 1024
self._out_tensors = []

def flush(self):
if len(self._tensor_bucket) == 1:
# Use non-coalesced CCOp if it's just one tensor
output = self._output_bucket[0] if self._output_bucket else None
self._out_tensors.append(self._func(self._tensor_bucket[0], output))
elif len(self._tensor_bucket):
self._out_tensors.extend(
self._func(self._tensor_bucket, self._output_bucket))
self._total = 0
self._tensor_bucket = []
self._output_bucket = [] if self._output_list else None

def add(self, tensor, idx):
self._total += tensor.numel() * tensor.element_size()
self._tensor_bucket.append(tensor)
if self._output_list is not None:
self._output_bucket.append(self._output_list[idx])

def __call__(self):
for idx, tensor in enumerate(self._input_list):
tensor_bytes = tensor.numel() * tensor.element_size()

# Aim for the target bucket_cap_mb: if this tensor alone exceeds the cap but
# the current bucket is still small (< 1/2 cap) and the combined size stays
# within 2x the cap, send them together in one flush; otherwise flush the
# bucket whenever adding this tensor would spill over the cap.
total_new = self._total + tensor_bytes
if tensor_bytes > self._bucket_cap and self._total < 0.5 * self._bucket_cap and total_new <= 2 * self._bucket_cap:
self.add(tensor, idx)
self.flush()
else:
# Bucketize till the total spills over
if total_new > self._bucket_cap:
self.flush()
self.add(tensor, idx)

# Flush the last remaining bucket
self.flush()

assert len(self._out_tensors) == len(self._input_list)

return self._out_tensors
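
To make the heuristic in `__call__` concrete, here is a minimal sketch that drives `CoalescingBuckets` with a stand-in collective, so the grouping can be observed without an XLA device. `fake_coalesced_op`, the tensor sizes, and the 8 MB cap are illustrative assumptions, not part of the API; it assumes the class is importable from torch_xla.core.xla_model as added in this PR.

import torch
from torch_xla.core.xla_model import CoalescingBuckets

calls = []  # records how many tensors each flush handed to the collective

def fake_coalesced_op(inputs, output=None):
  # Stand-in for all_gather/reduce_scatter: a single tensor arrives on the
  # non-coalesced path, a list on the coalesced path.
  if isinstance(inputs, torch.Tensor):
    calls.append(1)
    return inputs
  calls.append(len(inputs))
  return list(inputs)

# Three ~4 MB float32 tensors with an 8 MB cap: the first two fill one
# bucket exactly, the third spills over and is flushed on its own.
tensors = [torch.zeros(1024 * 1024) for _ in range(3)]
out = CoalescingBuckets(fake_coalesced_op, tensors, bucket_cap_mb=8)()
assert len(out) == len(tensors)
print(calls)  # expected: [2, 1]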


def all_gather_bucketized(input_list,
dim=0,
groups=None,
output=None,
pin_layout=False,
bucket_cap_mb=160):
"""Performs an all-gather operation along a given dimension, with bucketization.

Args:
See all_gather for the args: dim, groups, output, pin_layout
input_list: List of input tensors
bucket_cap_mb: Number of MegaBytes of the tensor bucket to fill before doing all-gather.

Returns:
A list of tensors each of which has, in the ``dim`` dimension, all the values from the
participating replicas.
"""
# sanity checks
if pin_layout:
raise RuntimeError(
"For xm.all_gather_bucketized, pin_layout=True is not yet supported.")

def _all_gather_coalesced(_input_list, _output_list=None):
return all_gather(
value=_input_list,
dim=dim,
groups=groups,
output=_output_list,
pin_layout=pin_layout)

buckets = CoalescingBuckets(
_all_gather_coalesced, input_list, output, bucket_cap_mb=bucket_cap_mb)
return buckets()
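
A hedged usage sketch for `all_gather_bucketized` (the shapes and the multiprocess launch below are illustrative, and it assumes a working torch_xla runtime with more than one device):

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def _demo(index):
  device = xm.xla_device()
  # Each replica contributes three small tensors tagged with its ordinal.
  inputs = [
      torch.full((4,), float(index) + 100.0 * i).to(device) for i in range(3)
  ]
  # List entries are coalesced into ~160 MB buckets (the default), so fewer
  # all-gather collectives are issued than with a per-tensor loop.
  gathered = xm.all_gather_bucketized(inputs, dim=0, pin_layout=False)
  xm.mark_step()
  print(index, [t.shape for t in gathered])  # each shape: (4 * world_size,)

if __name__ == '__main__':
  xmp.spawn(_demo, args=())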


def all_to_all(value,
split_dimension,
concat_dimension,
@@ -846,6 +950,50 @@ def reduce_scatter(reduce_type,
f"given {type(input)}.")


def reduce_scatter_bucketized(reduce_type,
input_list,
scale,
scatter_dim,
shard_count,
groups=None,
output=None,
pin_layout=False,
bucket_cap_mb=160):
"""Performs a XLA `ReduceScatter()` operation on a list of tensors (bucketized).

See: https://www.tensorflow.org/xla/operation_semantics#reducescatter

Args:
see reduce_scatter for reduce_type, scale, scatter_dim, shard_count, groups, pin_layout
input_list: List of input tensors
output: Optional list of output torch.Tensor
bucket_cap_mb: Number of MegaBytes of the tensor bucket to fill before doing reduce-scatter.

Returns:
A list of `torch.Tensors` with all the values reduced across replicas. Each process
gets a shard split along the `scatter_dim`. All other dimensions are
the same as the input.
"""

def _reduce_scatter_coalesced(_input_list, _output_list=None):
return reduce_scatter(
reduce_type=reduce_type,
input=_input_list,
scale=scale,
scatter_dim=scatter_dim,
shard_count=shard_count,
groups=groups,
output=_output_list,
pin_layout=pin_layout)

buckets = CoalescingBuckets(
_reduce_scatter_coalesced,
input_list,
output,
bucket_cap_mb=bucket_cap_mb)
return buckets()
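
A hedged sketch of calling `reduce_scatter_bucketized` inside an `xmp.spawn`-ed worker (the shapes, scale, and 64 MB cap are illustrative assumptions):

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
world_size = xm.xrt_world_size()
# Five gradient-like tensors; the scatter dimension is sized so that it
# divides evenly by shard_count.
grads = [torch.randn(8, 4 * world_size).to(device) for _ in range(5)]
shards = xm.reduce_scatter_bucketized(
    xm.REDUCE_SUM,
    grads,
    scale=1.0 / world_size,
    scatter_dim=1,
    shard_count=world_size,
    pin_layout=False,
    bucket_cap_mb=64)  # a smaller cap means more, smaller coalesced collectives
xm.mark_step()
# Each shards[i] is grads[i] summed across replicas and split along dim 1,
# i.e. shape (8, 4) on every replica.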


def add_step_closure(closure, args=(), run_async=False):
"""Adds a closure to the list of the ones to be run at the end of the step.
