From ec4b1e05bb7150f1e975aad5872f9a0dfd51a131 Mon Sep 17 00:00:00 2001 From: Jeffrey Huynh Date: Wed, 20 Mar 2024 05:48:30 +0000 Subject: [PATCH] Refactor bucketized all-gather/reduce-scatter functions; add bucket_cap_mb arg --- test/test_mp_all_gather.py | 36 ++ test/test_mp_reduce_scatter.py | 56 +++ torch_xla/core/xla_model.py | 327 ++++++++++-------- .../distributed/zero_redundancy_optimizer.py | 14 +- 4 files changed, 294 insertions(+), 139 deletions(-) diff --git a/test/test_mp_all_gather.py b/test/test_mp_all_gather.py index 5d38158d2871..be83b74631e1 100644 --- a/test/test_mp_all_gather.py +++ b/test/test_mp_all_gather.py @@ -95,6 +95,42 @@ def _mp_fn(index): file=sys.stderr) print(f'[{index}] {cpu_result}', file=sys.stderr) sys.exit(1) + + # Testing with a single replica group and tensor list as input (Bucketized) + # TODO: add support for list input with pin_layout=True and output=None + result_list = xm.all_gather_bucketized( + ordinal_tensors, dim=0, pin_layout=False) + + for i, result in enumerate(result_list): + cpu_result = result.cpu() + expected = i * 1000 + torch.arange(world_size, dtype=torch.float) + if not cpu_result.allclose(expected): + print( + 'xm.all_gather_bucketized() produced wrong reductions for item {i} in result list', + file=sys.stderr) + print(f'[{index}] {cpu_result}', file=sys.stderr) + sys.exit(1) + + # Testing with a single replica group and tensor list as input and output!=None (out-of-place) (Bucketized) + # Reuse ordinal_tensors from previous test + output_tensors = [ + torch.zeros([world_size], dtype=torch.float).to(device) + for i in range(input_list_size) + ] + # TODO: add support for list input with pin_layout=True and output!=None + result_list = xm.all_gather_bucketized( + ordinal_tensors, dim=0, output=output_tensors, pin_layout=False) + + for i, result in enumerate(result_list): + cpu_result = result.cpu() + expected = i * 1000 + torch.arange(world_size, dtype=torch.float) + if not cpu_result.allclose(expected): + print( + 'xm.all_gather() produced wrong reductions for item {i} in result list', + file=sys.stderr) + print(f'[{index}] {cpu_result}', file=sys.stderr) + sys.exit(1) + # TODO: add test for torch.compile when support for list input is ready else: diff --git a/test/test_mp_reduce_scatter.py b/test/test_mp_reduce_scatter.py index 1ef61d3aa794..94363be1a02a 100644 --- a/test/test_mp_reduce_scatter.py +++ b/test/test_mp_reduce_scatter.py @@ -55,6 +55,33 @@ def _mp_fn(index): xm.rendezvous('test_reduce_scatter_list_input') + # Testing reduce-scatter with list input bucketized + rand_list = [ + torch.rand((32, shard_size * world_size, 32)) + for _ in range(input_list_size) + ] + xrand_list = [rand.to(device) for rand in rand_list] + + # TODO: fix the broken case with pin_layout=True + res_list = xm.reduce_scatter_bucketized( + xm.REDUCE_SUM, + xrand_list, + scale, + scatter_dim, + world_size, + pin_layout=False) + + for i, res in enumerate(res_list): + expected_world = xm.all_reduce(xm.REDUCE_SUM, xrand_list[i], scale) + xm.mark_step() + + slice_idx = torch.tensor( + list(range(index * shard_size, (index + 1) * shard_size))) + expected = expected_world.cpu().index_select(scatter_dim, slice_idx) + assert res.cpu().allclose(expected) + + xm.rendezvous('test_reduce_scatter_list_input_bucketized') + # Testing reduce-scatter with list input and output output_list = [ torch.rand((32, shard_size * world_size, 32)) @@ -83,6 +110,35 @@ def _mp_fn(index): assert res.cpu().allclose(expected) xm.rendezvous('test_reduce_scatter_list_input_output') + + # Testing reduce-scatter with list input and output + output_list = [ + torch.rand((32, shard_size * world_size, 32)) + for _ in range(input_list_size) + ] + xoutput_list = [output.to(device) for output in output_list] + + # TODO: fix the broken case with pin_layout=True + res_list = xm.reduce_scatter_bucketized( + xm.REDUCE_SUM, + xrand_list, + scale, + scatter_dim, + world_size, + output=xoutput_list, + pin_layout=False) + + assert (xoutput_list == res_list) + for i, res in enumerate(xoutput_list): + expected_world = xm.all_reduce(xm.REDUCE_SUM, xrand_list[i], scale) + xm.mark_step() + + slice_idx = torch.tensor( + list(range(index * shard_size, (index + 1) * shard_size))) + expected = expected_world.cpu().index_select(scatter_dim, slice_idx) + assert res.cpu().allclose(expected) + + xm.rendezvous('test_reduce_scatter_list_input_output_bucketized') else: print( 'Default device {} is not a TPU device'.format(device), file=sys.stderr) diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py index f1f7051dc46b..8b1f1c2ce259 100755 --- a/torch_xla/core/xla_model.py +++ b/torch_xla/core/xla_model.py @@ -608,7 +608,6 @@ def all_gather(value, dim=0, groups=None, output=None, pin_layout=True): # Now the input should be a list of Tensors. elif isinstance(value, list) and all( isinstance(v, torch.Tensor) for v in value): - # sanity checks if pin_layout: raise RuntimeError( "For xm.all_gather with list of tensors input, pin_layout=True is not yet supported." @@ -620,89 +619,113 @@ def all_gather(value, dim=0, groups=None, output=None, pin_layout=True): f"`output` needs to be a list of Tensors, but given {type(output)}." ) if len(output) != len(value): - raise ValueError("`output` length doesn't match `value` length: " - f"{len(output)} vs {len(value)}.") - - # helper function for bucketing - def _all_gather_coalesced(tensor_list, output_list=[]): - if output_list != []: - if len(output_list) != len(tensor_list): - raise ValueError( - "_all_gather_coalesced: `output_list` length doesn't match `tensor_list` length: " - f"{len(output_list)} vs {len(tensor_list)}.") - # Call the out of place version of the reduce_scatter - new_token = torch_xla._XLAC._xla_all_gather_coalesced_out( - output_list, tensor_list, token, dim, shard_count, groups or [], - pin_layout) - torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token) - return output_list - - result = torch_xla._XLAC._xla_all_gather_coalesced( - tensor_list, token, dim, shard_count, groups or [], pin_layout) - torch_xla._XLAC._set_all_reduce_token(devctx.device, result[-1]) - return result[:-1] - - #return _all_gather_coalesced(value, output if output else []) - - total = 0 - tensor_bucket = [] - output_bucket = [] - out_tensors = [] - bucket_cap = int( - os.getenv("ALL_GATHER_REDUCE_SCATTER_BUCKET_CAP_MB", - _ALL_GATHER_REDUCE_SCATTER_BUCKET_CAP_MB)) * 1024 * 1024 - if groups: - divisor = len(groups[0]) if type(groups[0]) == list else len(groups) - else: - divisor = xrt_world_size() - bucket_cap = bucket_cap / divisor - - for idx, tensor in enumerate(value): - tensor_bytes = tensor.numel() * tensor.element_size() - - # Tensor is larger than bucket_cap, don't bucketize - if tensor_bytes > bucket_cap: - # Flush out previous buckets even if they don't fill up - if total >= 0.5 * bucket_cap or (total + tensor_bytes) > 2 * bucket_cap: - if len(tensor_bucket): - out_tensors.extend( - _all_gather_coalesced(tensor_bucket, output_bucket)) - out_tensors.extend( - _all_gather_coalesced([tensor], [output[idx]] if output else [])) - else: - tensor_bucket.append(tensor) - if output != None: - output_bucket.append(output[idx]) - out_tensors.extend( - _all_gather_coalesced(tensor_bucket, output_bucket)) - total = 0 - tensor_bucket = [] - output_bucket = [] - continue + raise ValueError("`output` length doesn't match `input` length: " + f"{len(output)} vs {len(input)}.") + # Call the out of place version of the reduce_scatter + new_token = torch_xla._XLAC._xla_all_gather_coalesced_out( + output, value, token, dim, shard_count, groups or [], pin_layout) + torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token) + return output + result = torch_xla._XLAC._xla_all_gather_coalesced(value, token, dim, + shard_count, groups or + [], pin_layout) + torch_xla._XLAC._set_all_reduce_token(devctx.device, result[-1]) + return result[:-1] + else: + raise TypeError("`value` needs to be a Tensor or a list of Tensors, but " + f"given {type(value)}.") + + +def all_gather_bucketized(input_list, + dim=0, + groups=None, + output=None, + pin_layout=True, + bucket_cap_mb=160): + """Performs an all-gather operation along a given dimension, with bucketization. + + Args: + See all_gather for the args: dim, groups, output, pin_layout + input_list: List of input tensors + bucket_cap_mb: Number of MegaBytes of the tensor bucket to fill before doing all-gather. + + Returns: + A list of tensors each of which has, in the ``dim`` dimension, all the values from the + participating replicas. + """ + # sanity checks + if pin_layout: + raise RuntimeError( + "For xm.all_gather_bucketized, pin_layout=True is not yet supported.") + if not isinstance(input_list, list) or any( + not isinstance(v, torch.Tensor) for v in input_list): + raise TypeError( + f"`input_list` needs to be a list of Tensors, but given {type(input_list)}." + ) + if output != None: + if not isinstance(output, list) or any( + not isinstance(v, torch.Tensor) for v in output): + raise TypeError( + f"`output` needs to be a list of Tensors, but given {type(output)}.") + if len(output) != len(input_list): + raise ValueError("`output` length doesn't match `input_list` length: " + f"{len(output)} vs {len(input_list)}.") + + def _all_gather_coalesced(_input_list, _output_list=None): + return all_gather( + value=_input_list, + dim=dim, + groups=groups, + output=_output_list, + pin_layout=pin_layout) + + total = 0 + tensor_bucket = [] + output_bucket = [] if output else None + out_tensors = [] + bucket_cap = bucket_cap_mb * 1024 * 1024 + if groups: + divisor = len(groups[0]) if type(groups[0]) == list else len(groups) + else: + divisor = xrt_world_size() + bucket_cap = bucket_cap / divisor + + for idx, tensor in enumerate(input_list): + tensor_bytes = tensor.numel() * tensor.element_size() + + # Aim for target bucket_cap_mb: flush new tensor with bucket if bucket content + # is small (1/2 cap) but don't combine if combined total is over 2x cap + total_new = total + tensor_bytes + if tensor_bytes > bucket_cap and total < 0.5 * bucket_cap and total_new <= 2 * bucket_cap: + tensor_bucket.append(tensor) + if output != None: + output_bucket.append(output[idx]) + out_tensors.extend(_all_gather_coalesced(tensor_bucket, output_bucket)) + total = 0 + tensor_bucket = [] + output_bucket = [] if output else None + else: # Bucketize till the total spills over - total += tensor_bytes - if total > bucket_cap: + if total_new > bucket_cap: if len(tensor_bucket): out_tensors.extend( _all_gather_coalesced(tensor_bucket, output_bucket)) - total = tensor_bytes + total = 0 tensor_bucket = [] - output_bucket = [] + output_bucket = [] if output else None + total = total_new tensor_bucket.append(tensor) if output != None: output_bucket.append(output[idx]) - # Flush the last remaining bucket - if len(tensor_bucket): - out_tensors.extend(_all_gather_coalesced(tensor_bucket, output_bucket)) + # Flush the last remaining bucket + if len(tensor_bucket): + out_tensors.extend(_all_gather_coalesced(tensor_bucket, output_bucket)) - assert len(out_tensors) == len(value) + assert len(out_tensors) == len(input_list) - return out_tensors - else: - raise TypeError("`value` needs to be a Tensor or a list of Tensors, but " - f"given {type(value)}.") + return out_tensors def all_to_all(value, @@ -892,7 +915,6 @@ def reduce_scatter(reduce_type, # Now the input should be a list of Tensors. elif isinstance(input, list) and all( isinstance(v, torch.Tensor) for v in input): - # sanity checks if output != None: if not isinstance(output, list) or any( not isinstance(v, torch.Tensor) for v in output): @@ -902,81 +924,116 @@ def reduce_scatter(reduce_type, if len(output) != len(input): raise ValueError("`output` length doesn't match `input` length: " f"{len(output)} vs {len(input)}.") - # helper function for bucketing - def _reduce_scatter_coalesced(tensor_list, output_list=[]): - if output_list != []: - if len(output_list) != len(tensor_list): - raise ValueError( - "_reduce_scatter_coalesced: `output_list` length doesn't match `tensor_list` length: " - f"{len(output_list)} vs {len(tensor_list)}.") - # Call the out of place version of the reduce_scatter - new_token = torch_xla._XLAC._xla_reduce_scatter_coalesced_out( - reduce_type, output_list, tensor_list, token, scale, scatter_dim, - shard_count, groups or [], pin_layout) - torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token) - return output_list - - result = torch_xla._XLAC._xla_reduce_scatter_coalesced( - reduce_type, tensor_list, token, scale, scatter_dim, shard_count, + # Call the out of place version of the reduce_scatter + new_token = torch_xla._XLAC._xla_reduce_scatter_coalesced_out( + reduce_type, output, input, token, scale, scatter_dim, shard_count, groups or [], pin_layout) - torch_xla._XLAC._set_all_reduce_token(devctx.device, result[-1]) - return result[:-1] - - total = 0 - tensor_bucket = [] - output_bucket = [] - out_tensors = [] - bucket_cap = int( - os.getenv("ALL_GATHER_REDUCE_SCATTER_BUCKET_CAP_MB", - _ALL_GATHER_REDUCE_SCATTER_BUCKET_CAP_MB)) * 1024 * 1024 - for idx, tensor in enumerate(input): - tensor_bytes = tensor.numel() * tensor.element_size() - - # Tensor is larger than bucket_cap, don't bucketize - if tensor_bytes > bucket_cap: - # Flush out previous buckets even if they don't fill up - if total >= 0.5 * bucket_cap or (total + tensor_bytes) > 2 * bucket_cap: - if len(tensor_bucket): - out_tensors.extend( - _reduce_scatter_coalesced(tensor_bucket, output_bucket)) - out_tensors.extend( - _reduce_scatter_coalesced([tensor], - [output[idx]] if output else [])) - else: - tensor_bucket.append(tensor) - if output != None: - output_bucket.append(output[idx]) - out_tensors.extend( - _reduce_scatter_coalesced(tensor_bucket, output_bucket)) - total = 0 - tensor_bucket = [] - output_bucket = [] - continue + torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token) + return output + result = torch_xla._XLAC._xla_reduce_scatter_coalesced( + reduce_type, input, token, scale, scatter_dim, shard_count, groups or + [], pin_layout) + torch_xla._XLAC._set_all_reduce_token(devctx.device, result[-1]) + return result[:-1] + else: + raise TypeError("`input` needs to be a Tensor or a list of Tensors, but " + f"given {type(input)}.") + + +def reduce_scatter_bucketized(reduce_type, + input_list, + scale, + scatter_dim, + shard_count, + groups=None, + output=None, + pin_layout=True, + bucket_cap_mb=160): + """Performs a XLA `ReduceScatter()` operation on a list of tensors (bucketized). + + See: https://www.tensorflow.org/xla/operation_semantics#reducescatter + + Args: + see reduce_scatter for reduce_type, scale, scatter_dim, shard_count, groups, pin_layout + input_list: List of input tensors + output: Optional list of output torch.Tensor + bucket_cap_mb: Number of MegaBytes of the tensor bucket to fill before doing all-gather. + + Returns: + A list of `torch.Tensors` with all the values reduced across replicas. Each process + gets a shard split along the `scatter_dim`. All other dimensions are + the same as the input. + """ + token, devctx = _get_all_reduce_token() + + if not isinstance(input_list, list) or any( + not isinstance(v, torch.Tensor) for v in input_list): + raise TypeError( + f"`input_list` needs to be a list of Tensors, but given {type(input_list)}." + ) + if output != None: + if not isinstance(output, list) or any( + not isinstance(v, torch.Tensor) for v in output): + raise TypeError( + f"`output` needs to be a list of Tensors, but given {type(output)}.") + if len(output) != len(input_list): + raise ValueError("`output` length doesn't match `input_list` length: " + f"{len(output)} vs {len(input_list)}.") + + def _reduce_scatter_coalesced(_input_list, _output_list=None): + return reduce_scatter( + reduce_type=reduce_type, + input=_input_list, + scale=scale, + scatter_dim=scatter_dim, + shard_count=shard_count, + groups=groups, + output=_output_list, + pin_layout=pin_layout) + + total = 0 + tensor_bucket = [] + output_bucket = [] if output else None + out_tensors = [] + bucket_cap = bucket_cap_mb * 1024 * 1024 + + for idx, tensor in enumerate(input_list): + tensor_bytes = tensor.numel() * tensor.element_size() + + # Aim for target bucket_cap_mb: flush new tensor with bucket if bucket content + # is small (1/2 cap) but don't combine if combined total is over 2x cap + total_new = total + tensor_bytes + if tensor_bytes > bucket_cap and total < 0.5 * bucket_cap and total_new <= 2 * bucket_cap: + tensor_bucket.append(tensor) + if output != None: + output_bucket.append(output[idx]) + out_tensors.extend( + _reduce_scatter_coalesced(tensor_bucket, output_bucket)) + total = 0 + tensor_bucket = [] + output_bucket = [] if output else None + else: # Bucketize till the total spills over - total += tensor_bytes - if total > bucket_cap: + if total_new > bucket_cap: if len(tensor_bucket): out_tensors.extend( _reduce_scatter_coalesced(tensor_bucket, output_bucket)) - total = tensor_bytes + total = 0 tensor_bucket = [] - output_bucket = [] + output_bucket = [] if output else None + total = total_new tensor_bucket.append(tensor) if output != None: output_bucket.append(output[idx]) - # Flush the last remaining bucket - if len(tensor_bucket): - out_tensors.extend( - _reduce_scatter_coalesced(tensor_bucket, output_bucket)) + # Flush the last remaining bucket + if len(tensor_bucket): + out_tensors.extend(_reduce_scatter_coalesced(tensor_bucket, output_bucket)) - assert len(out_tensors) == len(input) + assert len(out_tensors) == len(input_list) - return out_tensors - else: - raise TypeError("`input` needs to be a Tensor or a list of Tensors, but " - f"given {type(input)}.") + return out_tensors def add_step_closure(closure, args=(), run_async=False): diff --git a/torch_xla/distributed/zero_redundancy_optimizer.py b/torch_xla/distributed/zero_redundancy_optimizer.py index c46642fa448b..e522125d3350 100644 --- a/torch_xla/distributed/zero_redundancy_optimizer.py +++ b/torch_xla/distributed/zero_redundancy_optimizer.py @@ -40,6 +40,9 @@ class ZeroRedundancyOptimizer(Optimizer): If specified, ZeRO-1 will use this ``grad_norm_groups`` for the EXTRA all-reduce op in grad norm calculation. This can be model parallel groups when mixing ZeRO-1 with model parallelism such as Megatron. + bucket_cap_mb: + If non-zero, specifies the maximum number of megabytes to combine tensors + before doing the all-gather/reduce-scatter operations. **defaults: any trailing arguments, which are forwarded to the local optimizer. @@ -60,7 +63,7 @@ def __init__( sharding_groups: Optional[Any] = None, grad_norm_groups: Optional[Any] = None, lazy_init: bool = False, - coalesce_cc: bool = False, + bucket_cap_mb: int = 160, **defaults: Any, ): super().__init__(params, defaults) @@ -77,7 +80,8 @@ def __init__( self.grad_clipping = grad_clipping self.max_norm = max_norm if max_norm is not None else 1.0 self.pin_layout = pin_layout - self.coalesce_cc = coalesce_cc + self.bucket_cap_mb = bucket_cap_mb + self.coalesce_cc = bucket_cap_mb > 0 self._grad_norm = None @@ -295,7 +299,7 @@ def step(self, closure=None, **kwargs): shard.grad = grad_shard if self.coalesce_cc: - grad_shards = xm.reduce_scatter( + grad_shards = xm.reduce_scatter_bucketized( xm.REDUCE_SUM, padded_grads, scale=1.0 / self.local_world_size, @@ -303,6 +307,7 @@ def step(self, closure=None, **kwargs): shard_count=self.local_world_size, pin_layout=self.pin_layout, groups=self.sharding_groups, + bucket_cap_mb=self.bucket_cap_mb, ) index = 0 for param_group, sharded_param_group in zip( @@ -349,11 +354,12 @@ def step(self, closure=None, **kwargs): param.data.copy_(padded_param.data[:param.size(0)]) if self.coalesce_cc: - padded_params = xm.all_gather( + padded_params = xm.all_gather_bucketized( sharded_data, dim=0, pin_layout=self.pin_layout, groups=self.sharding_groups, + bucket_cap_mb=self.bucket_cap_mb, ) index = 0 for param_group, sharded_param_group in zip(