From 90eda15190948747e56207714333ace943ffa9d7 Mon Sep 17 00:00:00 2001 From: Rahul Solanki Date: Sun, 19 Nov 2023 01:19:43 +0000 Subject: [PATCH 01/15] add bucketting logic to control the size of tensors for all-gather and reduce-scatter --- torch_xla/core/xla_model.py | 150 +++++++++++++++--- .../distributed/zero_redundancy_optimizer.py | 87 +++++++--- 2 files changed, 197 insertions(+), 40 deletions(-) diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py index 28622fdafc2..0667ef74b7b 100755 --- a/torch_xla/core/xla_model.py +++ b/torch_xla/core/xla_model.py @@ -1,6 +1,7 @@ import io import itertools import logging +import os import sys import re import threading @@ -38,6 +39,9 @@ XLA_LIB = Library("xla", "DEF") +# Default bucket size for all-gather and reduce-scatter +_ALL_GATHER_REDUCE_SCATTER_BUCKET_CAP_MB = 160 + def _init_world_size_ordinal(): global _WORLD_SIZE, _ORDINAL @@ -608,31 +612,84 @@ def all_gather(value, dim=0, groups=None, output=None, pin_layout=True): raise RuntimeError( "For xm.all_gather with list of tensors input, pin_layout=True is not yet supported." ) - if output != None: - if not isinstance(output, list) or any( - not isinstance(v, torch.Tensor) for v in output): - raise TypeError( - f"`output` needs to be a list of Tensors, but given {type(output)}." - ) - if len(output) != len(value): - raise ValueError("`output` length doesn't match `input` length: " - f"{len(output)} vs {len(input)}.") - # Call the out of place version of the reduce_scatter - new_token = torch_xla._XLAC._xla_all_gather_coalesced_out( - output, value, token, dim, shard_count, groups or [], pin_layout) - torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token) - return output - - result = torch_xla._XLAC._xla_all_gather_coalesced(value, token, dim, - shard_count, groups or - [], pin_layout) - torch_xla._XLAC._set_all_reduce_token(devctx.device, result[-1]) - return result[:-1] + def _all_gather_coalesced(tensor_list, output_list=None): + if output_list != None: + if not isinstance(output_list, list) or any( + not isinstance(v, torch.Tensor) for v in output_list): + raise TypeError( + f"`output` needs to be a list of Tensors, but given {type(output_list)}." 
+ ) + if len(output_list) != len(tensor_list): + raise ValueError("`output` length doesn't match `input` length: " + f"{len(output_list)} vs {len(tensor_list)}.") + # Call the out of place version of the reduce_scatter + new_token = torch_xla._XLAC._xla_all_gather_coalesced_out( + output_list, tensor_list, token, dim, shard_count, groups or [], pin_layout) + torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token) + return output_list + + result = torch_xla._XLAC._xla_all_gather_coalesced(tensor_list, token, dim, + shard_count, groups or + [], pin_layout) + torch_xla._XLAC._set_all_reduce_token(devctx.device, result[-1]) + return result[:-1] + + total = 0 + tensor_bucket = [] + output_bucket = [] + out_tensors = [] + bucket_cap = int(os.getenv( + "ALL_GATHER_REDUCE_SCATTER_BUCKET_CAP_MB", + _ALL_GATHER_REDUCE_SCATTER_BUCKET_CAP_MB + )) * 1024 * 1024 + divisor = len(groups[0]) if type(groups[0]) == list else len(groups) + bucket_cap = bucket_cap / divisor + for idx, tensor in enumerate(value): + + tensor_bytes = tensor.numel() * tensor.element_size() + output_selected = None + if output != None: + output_selected = [output[idx]] + if tensor.numel() != output_selected.numel(): + raise ValueError(f"`output` tensor size doesn't match `input` tensor size for list index {idx}: " + f"{output_selected.numel() vs tensor.numel().") + + # Tensor is larger than bucket_cap, don't bucketize + if tensor_bytes > bucket_cap: + # Flush out previous buckets even if they don't fill up + if total >= 0.5*bucket_cap or (total + tensor_bytes) > 2*bucket_cap: + out_tensors.extend(_all_gather_coalesced(tensor_bucket, output_bucket)) + out_tensors.extend(_all_gather_coalesced([tensor], output_selected)) + else: + tensor_bucket.append(tensor) + output_bucket.append(output_selected) + out_tensors.extend(_all_gather_coalesced(tensor_bucket, output_bucket)) + total = 0 + tensor_bucket = [] + output_bucket = [] + continue + + # Bucketize till the total spills over + total += tensor_bytes + if total > bucket_cap: + out_tensors.extend(_all_gather_coalesced(tensor_bucket, output_bucket)) + total = tensor_bytes + tensor_bucket = [] + output_bucket = [] + tensor_bucket.append(tensor) + output_bucket.append(output_selected) + + # Flush the last remaining bucket + if len(tensor_bucket): + out_tensors.extend(_all_gather_coalesced(tensor_bucket, output_bucket)) + + assert len(out_tensors) == len(value) + + return out_tensors else: raise TypeError("`value` needs to be a Tensor or a list of Tensors, but " f"given {type(value)}.") - def all_to_all(value, split_dimension, concat_dimension, @@ -836,11 +893,62 @@ def reduce_scatter(reduce_type, torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token) return output +<<<<<<< HEAD result = torch_xla._XLAC._xla_reduce_scatter_coalesced( reduce_type, input, token, scale, scatter_dim, shard_count, groups or [], pin_layout) torch_xla._XLAC._set_all_reduce_token(devctx.device, result[-1]) return result[:-1] +======= + def _reduce_scatter_coalesced(tensor_list, out_tensor_bucket): + result = torch_xla._XLAC._xla_reduce_scatter_coalesced( + reduce_type, out_tensor_bucket, tensor_list, token, scale, scatter_dim, shard_count, + groups or [], pin_layout) + torch_xla._XLAC._set_all_reduce_token(devctx.device, result[-1]) + return result[:-1] + + total = 0 + tensor_bucket = [] + out_tensor_bucket = [] + out_tensors = [] + bucket_cap = int(os.getenv( + "ALL_GATHER_REDUCE_SCATTER_BUCKET_CAP_MB", + _ALL_GATHER_REDUCE_SCATTER_BUCKET_CAP_MB + )) * 1024 * 1024 + for i, tensor in 
enumerate(input): + tensor_bytes = tensor.numel() * tensor.element_size() + + # Tensor is larger than bucket_cap, don't bucketize + if tensor_bytes > bucket_cap: + # Flush out previous buckets even if they don't fill up + if total >= 0.5*bucket_cap or (total + tensor_bytes) > 2*bucket_cap: + out_tensors.extend(_reduce_scatter_coalesced(tensor_bucket, out_tensor_bucket)) + out_tensors.extend(_reduce_scatter_coalesced([tensor], [output[i]] if output else [])) + else: + tensor_bucket.append(tensor) + if output != None: + out_tensor_bucket.append(output[i]) + out_tensors.extend(_reduce_scatter_coalesced(tensor_bucket, out_tensor_bucket)) + total = 0 + tensor_bucket = [] + continue + + # Bucketize till the total spills over + total += tensor_bytes + if total > bucket_cap: + out_tensors.extend(_reduce_scatter_coalesced(tensor_bucket, out_tensor_bucket)) + total = tensor_bytes + tensor_bucket = [] + out_tensor_bucket = [] + tensor_bucket.append(tensor) + if output != None: + out_tensor_bucket.append(output[i]) + + # Flush the last remaining bucket + if len(tensor_bucket): + out_tensors.extend(_reduce_scatter_coalesced(tensor_bucket, out_tensor_bucket)) + return out_tensors +>>>>>>> add bucketting logic to control the size of tensors for all-gather and reduce-scatter else: raise TypeError("`input` needs to be a Tensor or a list of Tensors, but " f"given {type(input)}.") diff --git a/torch_xla/distributed/zero_redundancy_optimizer.py b/torch_xla/distributed/zero_redundancy_optimizer.py index f00929eeb86..7127c923c75 100644 --- a/torch_xla/distributed/zero_redundancy_optimizer.py +++ b/torch_xla/distributed/zero_redundancy_optimizer.py @@ -60,6 +60,7 @@ def __init__( sharding_groups: Optional[Any] = None, grad_norm_groups: Optional[Any] = None, lazy_init: bool = False, + coalesce_cc: bool = False, **defaults: Any, ): super().__init__(params, defaults) @@ -76,6 +77,7 @@ def __init__( self.grad_clipping = grad_clipping self.max_norm = max_norm if max_norm is not None else 1.0 self.pin_layout = pin_layout + self.coalesce_cc = coalesce_cc self.inited = False if not lazy_init: @@ -256,6 +258,7 @@ def step(self, closure=None, **kwargs): # Reduce full gradients across ranks # Assign gradient shards to the respective parameter shards + padded_grads = [] for param_group, sharded_param_group in zip( self.param_groups, self.base_optimizer.param_groups): for param, shard in zip(param_group['params'], @@ -263,19 +266,44 @@ def step(self, closure=None, **kwargs): if param.grad is not None: padded_grad = self._pad_to_world_size(param.grad, self.local_world_size) - grad_shard = xm.reduce_scatter( - xm.REDUCE_SUM, - padded_grad, - scale=1.0 / self.local_world_size, - scatter_dim=0, - shard_count=self.local_world_size, - pin_layout=self.pin_layout, - groups=self.sharding_groups, - ) - - if grad_shard.dtype != self.optimizer_dtype: - grad_shard = grad_shard.to(dtype=self.optimizer_dtype) + if self.coalesce_cc: + padded_grads.append(padded_grad) + else: + grad_shard = xm.reduce_scatter( + xm.REDUCE_SUM, + padded_grad, + scale=1.0 / self.local_world_size, + scatter_dim=0, + shard_count=self.local_world_size, + pin_layout=self.pin_layout, + groups=self.sharding_groups, + ) + if grad_shard.dtype != self.optimizer_dtype: + grad_shard = grad_shard.to(dtype=self.optimizer_dtype) + shard.grad = grad_shard + + if self.coalesce_cc: + grad_shard = xm.reduce_scatter( + xm.REDUCE_SUM, + padded_grads, + scale=1.0 / self.local_world_size, + scatter_dim=0, + shard_count=self.local_world_size, + pin_layout=self.pin_layout, + 
groups=self.sharding_groups, + ) + index = 0 + for param_group, sharded_param_group in zip( + self.param_groups, self.base_optimizer.param_groups): + for param, shard in zip(param_group['params'], + sharded_param_group['params']): + if param.grad is not None: + grad_shard = grad_shards[index] + + if grad_shard.dtype != self.optimizer_dtype: + grad_shard = grad_shard.to(dtype=self.optimizer_dtype) shard.grad = grad_shard + index += 1 if self.grad_clipping: # Update unscale/clip with sub partitions @@ -288,6 +316,7 @@ def step(self, closure=None, **kwargs): self.base_optimizer.zero_grad(set_to_none=True) # All gather the new weights across the ranks and assign them to the full parameters + sharded_data = [] for param_group, sharded_param_group in zip( self.param_groups, self.base_optimizer.param_groups): for param, shard in zip(param_group['params'], @@ -296,13 +325,33 @@ def step(self, closure=None, **kwargs): shard_data = shard.data if param.dtype != self.optimizer_dtype: shard_data = shard_data.to(dtype=param.dtype) - padded_param = xm.all_gather( - shard_data, - dim=0, - pin_layout=self.pin_layout, - groups=self.sharding_groups, - ) - param.data.copy_(padded_param.data[:param.size(0)]) + if self.coalesce_cc: + sharded_data.append(shard_data) + else: + padded_param = xm.all_gather( + shard_data, + dim=0, + pin_layout=self.pin_layout, + groups=self.sharding_groups, + ) + param.data.copy_(padded_param.data[:param.size(0)]) + + if self.coalesce_cc: + padded_params = xm.all_gather( + sharded_data, + dim=0, + pin_layout=self.pin_layout, + groups=self.sharding_groups, + ) + index = 0 + for param_group, sharded_param_group in zip( + self.param_groups, self.base_optimizer.param_groups): + for param, shard in zip(param_group['params'], + sharded_param_group['params']): + if param.grad is not None: + padded_param = padded_params[index] + param.data.copy_(padded_param.data[:param.size(0)]) + index += 1 # sync back self._sync_param_groups(self.base_optimizer.param_groups, self.param_groups) From 46a069afc949b3c03d6eac85a6691e83ab7b7ed0 Mon Sep 17 00:00:00 2001 From: Jeffrey Huynh Date: Thu, 7 Dec 2023 04:56:22 +0000 Subject: [PATCH 02/15] Yapf lint fixes --- torch_xla/core/xla_model.py | 120 ++++++++++-------- .../distributed/zero_redundancy_optimizer.py | 28 ++-- 2 files changed, 82 insertions(+), 66 deletions(-) diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py index 0667ef74b7b..0aaadc05865 100755 --- a/torch_xla/core/xla_model.py +++ b/torch_xla/core/xla_model.py @@ -608,29 +608,32 @@ def all_gather(value, dim=0, groups=None, output=None, pin_layout=True): # Now the input should be a list of Tensors. elif isinstance(value, list) and all( isinstance(v, torch.Tensor) for v in value): + # sanity checks if pin_layout: raise RuntimeError( "For xm.all_gather with list of tensors input, pin_layout=True is not yet supported." ) + if output != None: + if not isinstance(output, list) or any( + not isinstance(v, torch.Tensor) for v in output): + raise TypeError( + f"`output` needs to be a list of Tensors, but given {type(output)}." + ) + if len(output) != len(input): + raise ValueError("`output` length doesn't match `input` length: " + f"{len(output)} vs {len(input)}.") + + # helper function for bucketing def _all_gather_coalesced(tensor_list, output_list=None): if output_list != None: - if not isinstance(output_list, list) or any( - not isinstance(v, torch.Tensor) for v in output_list): - raise TypeError( - f"`output` needs to be a list of Tensors, but given {type(output_list)}." 
- ) - if len(output_list) != len(tensor_list): - raise ValueError("`output` length doesn't match `input` length: " - f"{len(output_list)} vs {len(tensor_list)}.") # Call the out of place version of the reduce_scatter new_token = torch_xla._XLAC._xla_all_gather_coalesced_out( output_list, tensor_list, token, dim, shard_count, groups or [], pin_layout) torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token) return output_list - result = torch_xla._XLAC._xla_all_gather_coalesced(tensor_list, token, dim, - shard_count, groups or - [], pin_layout) + result = torch_xla._XLAC._xla_all_gather_coalesced( + tensor_list, token, dim, shard_count, groups or [], pin_layout) torch_xla._XLAC._set_all_reduce_token(devctx.device, result[-1]) return result[:-1] @@ -638,10 +641,9 @@ def _all_gather_coalesced(tensor_list, output_list=None): tensor_bucket = [] output_bucket = [] out_tensors = [] - bucket_cap = int(os.getenv( - "ALL_GATHER_REDUCE_SCATTER_BUCKET_CAP_MB", - _ALL_GATHER_REDUCE_SCATTER_BUCKET_CAP_MB - )) * 1024 * 1024 + bucket_cap = int( + os.getenv("ALL_GATHER_REDUCE_SCATTER_BUCKET_CAP_MB", + _ALL_GATHER_REDUCE_SCATTER_BUCKET_CAP_MB)) * 1024 * 1024 divisor = len(groups[0]) if type(groups[0]) == list else len(groups) bucket_cap = bucket_cap / divisor for idx, tensor in enumerate(value): @@ -649,20 +651,21 @@ def _all_gather_coalesced(tensor_list, output_list=None): tensor_bytes = tensor.numel() * tensor.element_size() output_selected = None if output != None: - output_selected = [output[idx]] + output_selected = output[idx] if tensor.numel() != output_selected.numel(): - raise ValueError(f"`output` tensor size doesn't match `input` tensor size for list index {idx}: " + raise ValueError(f"`output` tensor size doesn't match `input` tensor size for tensor list index {idx}: " f"{output_selected.numel() vs tensor.numel().") # Tensor is larger than bucket_cap, don't bucketize if tensor_bytes > bucket_cap: # Flush out previous buckets even if they don't fill up - if total >= 0.5*bucket_cap or (total + tensor_bytes) > 2*bucket_cap: + if total >= 0.5 * bucket_cap or (total + tensor_bytes) > 2 * bucket_cap: out_tensors.extend(_all_gather_coalesced(tensor_bucket, output_bucket)) - out_tensors.extend(_all_gather_coalesced([tensor], output_selected)) + out_tensors.extend(_all_gather_coalesced([tensor], [output_selected] if output else [])) else: tensor_bucket.append(tensor) - output_bucket.append(output_selected) + if output != None: + output_bucket.append(output[i]) out_tensors.extend(_all_gather_coalesced(tensor_bucket, output_bucket)) total = 0 tensor_bucket = [] @@ -677,7 +680,8 @@ def _all_gather_coalesced(tensor_list, output_list=None): tensor_bucket = [] output_bucket = [] tensor_bucket.append(tensor) - output_bucket.append(output_selected) + if output != None: + output_bucket.append(output_selected) # Flush the last remaining bucket if len(tensor_bucket): @@ -690,6 +694,7 @@ def _all_gather_coalesced(tensor_list, output_list=None): raise TypeError("`value` needs to be a Tensor or a list of Tensors, but " f"given {type(value)}.") + def all_to_all(value, split_dimension, concat_dimension, @@ -877,6 +882,7 @@ def reduce_scatter(reduce_type, # Now the input should be a list of Tensors. 
elif isinstance(input, list) and all( isinstance(v, torch.Tensor) for v in input): + # sanity checks if output != None: if not isinstance(output, list) or any( not isinstance(v, torch.Tensor) for v in output): @@ -886,69 +892,79 @@ def reduce_scatter(reduce_type, if len(output) != len(input): raise ValueError("`output` length doesn't match `input` length: " f"{len(output)} vs {len(input)}.") - # Call the out of place version of the reduce_scatter - new_token = torch_xla._XLAC._xla_reduce_scatter_coalesced_out( - reduce_type, output, input, token, scale, scatter_dim, shard_count, - groups or [], pin_layout) - torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token) - return output -<<<<<<< HEAD - result = torch_xla._XLAC._xla_reduce_scatter_coalesced( - reduce_type, input, token, scale, scatter_dim, shard_count, groups or - [], pin_layout) - torch_xla._XLAC._set_all_reduce_token(devctx.device, result[-1]) - return result[:-1] -======= - def _reduce_scatter_coalesced(tensor_list, out_tensor_bucket): + # helper function for bucketing + def _reduce_scatter_coalesced(tensor_list, output_list=None): + if output_list != None: + # Call the out of place version of the reduce_scatter + new_token = torch_xla._XLAC._xla_reduce_scatter_coalesced_out( + reduce_type, output_list, tensor_list, token, scale, scatter_dim, shard_count, + groups or [], pin_layout) + torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token) + return output_list + result = torch_xla._XLAC._xla_reduce_scatter_coalesced( - reduce_type, out_tensor_bucket, tensor_list, token, scale, scatter_dim, shard_count, - groups or [], pin_layout) + reduce_type, tensor_list, token, scale, + scatter_dim, shard_count, groups or [], pin_layout) torch_xla._XLAC._set_all_reduce_token(devctx.device, result[-1]) return result[:-1] total = 0 tensor_bucket = [] - out_tensor_bucket = [] + output_bucket = [] out_tensors = [] - bucket_cap = int(os.getenv( - "ALL_GATHER_REDUCE_SCATTER_BUCKET_CAP_MB", - _ALL_GATHER_REDUCE_SCATTER_BUCKET_CAP_MB - )) * 1024 * 1024 + bucket_cap = int( + os.getenv("ALL_GATHER_REDUCE_SCATTER_BUCKET_CAP_MB", + _ALL_GATHER_REDUCE_SCATTER_BUCKET_CAP_MB)) * 1024 * 1024 for i, tensor in enumerate(input): tensor_bytes = tensor.numel() * tensor.element_size() + output_selected = None + if output != None: + output_selected = output[idx] + if tensor.numel() != output_selected.numel(): + raise ValueError(f"`output` tensor size doesn't match `input` tensor size for tensor list index {idx}: " + f"{output_selected.numel() vs tensor.numel().") # Tensor is larger than bucket_cap, don't bucketize if tensor_bytes > bucket_cap: # Flush out previous buckets even if they don't fill up - if total >= 0.5*bucket_cap or (total + tensor_bytes) > 2*bucket_cap: - out_tensors.extend(_reduce_scatter_coalesced(tensor_bucket, out_tensor_bucket)) - out_tensors.extend(_reduce_scatter_coalesced([tensor], [output[i]] if output else [])) + if total >= 0.5 * bucket_cap or (total + tensor_bytes) > 2 * bucket_cap: + out_tensors.extend( + _reduce_scatter_coalesced(tensor_bucket, output_bucket)) + out_tensors.extend( + _reduce_scatter_coalesced([tensor], + [output_selected] if output else [])) else: tensor_bucket.append(tensor) if output != None: - out_tensor_bucket.append(output[i]) - out_tensors.extend(_reduce_scatter_coalesced(tensor_bucket, out_tensor_bucket)) + output_bucket.append(output[i]) + out_tensors.extend( + _reduce_scatter_coalesced(tensor_bucket, output_bucket)) total = 0 tensor_bucket = [] + output_bucket = [] continue # Bucketize till the 
total spills over total += tensor_bytes if total > bucket_cap: - out_tensors.extend(_reduce_scatter_coalesced(tensor_bucket, out_tensor_bucket)) + out_tensors.extend( + _reduce_scatter_coalesced(tensor_bucket, output_bucket)) total = tensor_bytes tensor_bucket = [] - out_tensor_bucket = [] + output_bucket = [] tensor_bucket.append(tensor) if output != None: - out_tensor_bucket.append(output[i]) + output_bucket.append(output_selected) # Flush the last remaining bucket if len(tensor_bucket): - out_tensors.extend(_reduce_scatter_coalesced(tensor_bucket, out_tensor_bucket)) + out_tensors.extend( + _reduce_scatter_coalesced(tensor_bucket, output_bucket)) + + assert len(out_tensors) == len(value) + return out_tensors ->>>>>>> add bucketting logic to control the size of tensors for all-gather and reduce-scatter else: raise TypeError("`input` needs to be a Tensor or a list of Tensors, but " f"given {type(input)}.") diff --git a/torch_xla/distributed/zero_redundancy_optimizer.py b/torch_xla/distributed/zero_redundancy_optimizer.py index 7127c923c75..e55f1a84039 100644 --- a/torch_xla/distributed/zero_redundancy_optimizer.py +++ b/torch_xla/distributed/zero_redundancy_optimizer.py @@ -284,14 +284,14 @@ def step(self, closure=None, **kwargs): if self.coalesce_cc: grad_shard = xm.reduce_scatter( - xm.REDUCE_SUM, - padded_grads, - scale=1.0 / self.local_world_size, - scatter_dim=0, - shard_count=self.local_world_size, - pin_layout=self.pin_layout, - groups=self.sharding_groups, - ) + xm.REDUCE_SUM, + padded_grads, + scale=1.0 / self.local_world_size, + scatter_dim=0, + shard_count=self.local_world_size, + pin_layout=self.pin_layout, + groups=self.sharding_groups, + ) index = 0 for param_group, sharded_param_group in zip( self.param_groups, self.base_optimizer.param_groups): @@ -335,14 +335,14 @@ def step(self, closure=None, **kwargs): groups=self.sharding_groups, ) param.data.copy_(padded_param.data[:param.size(0)]) - + if self.coalesce_cc: padded_params = xm.all_gather( - sharded_data, - dim=0, - pin_layout=self.pin_layout, - groups=self.sharding_groups, - ) + sharded_data, + dim=0, + pin_layout=self.pin_layout, + groups=self.sharding_groups, + ) index = 0 for param_group, sharded_param_group in zip( self.param_groups, self.base_optimizer.param_groups): From 8e79997b01794223c048d5454f3addbd3dc77a22 Mon Sep 17 00:00:00 2001 From: Rahul Solanki Date: Tue, 21 Nov 2023 04:00:46 +0000 Subject: [PATCH 03/15] handle the case when groups is none --- torch_xla/core/xla_model.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py index 0aaadc05865..8a629d53aa9 100755 --- a/torch_xla/core/xla_model.py +++ b/torch_xla/core/xla_model.py @@ -644,7 +644,10 @@ def _all_gather_coalesced(tensor_list, output_list=None): bucket_cap = int( os.getenv("ALL_GATHER_REDUCE_SCATTER_BUCKET_CAP_MB", _ALL_GATHER_REDUCE_SCATTER_BUCKET_CAP_MB)) * 1024 * 1024 - divisor = len(groups[0]) if type(groups[0]) == list else len(groups) + if groups: + divisor = len(groups[0]) if type(groups[0]) == list else len(groups) + else: + divisor = xrt_world_size() bucket_cap = bucket_cap / divisor for idx, tensor in enumerate(value): From 5a87467ee0bfd74046e31ed33ca7459e01f704a9 Mon Sep 17 00:00:00 2001 From: guangtai Date: Mon, 20 Nov 2023 13:18:31 -0800 Subject: [PATCH 04/15] update zero1 --- .../distributed/zero_redundancy_optimizer.py | 29 ++++++++++++++++--- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/torch_xla/distributed/zero_redundancy_optimizer.py 
b/torch_xla/distributed/zero_redundancy_optimizer.py index e55f1a84039..b3ae95a4132 100644 --- a/torch_xla/distributed/zero_redundancy_optimizer.py +++ b/torch_xla/distributed/zero_redundancy_optimizer.py @@ -79,6 +79,8 @@ def __init__( self.pin_layout = pin_layout self.coalesce_cc = coalesce_cc + self._grad_norm = None + self.inited = False if not lazy_init: self.init_zero() @@ -104,6 +106,10 @@ def init_zero(self): self._sync_param_groups(self.param_groups, self.base_optimizer.param_groups) self.inited = True + @property + def grad_norm(self): + return self._grad_norm + @property def sharding_groups(self): return self._sharding_groups @@ -160,12 +166,17 @@ def _shard_parameters(self): """ Shard all parameters. """ + self.device = None all_params = [] for param_group in self.param_groups: for param in param_group['params']: all_params.append(param) + if self.device is None: + self.device = param.device + else: + assert self.device == param.device, "Params should on the same device." + assert self.device.type == 'xla' - self.device = all_params[0].device xm.unlazy(all_params) sharded_params_groups = [] @@ -229,11 +240,11 @@ def _clip_grad_norm( """ max_norm = float(max_norm) norm_type = float(norm_type) - total_norm = self._calc_grad_norm(norm_type) + self._grad_norm = self._calc_grad_norm(norm_type) clip_coeff = torch.tensor( max_norm, device=self.device) / ( - total_norm + 1e-6) + self._grad_norm + 1e-6) clip_value = torch.where(clip_coeff < 1, clip_coeff, torch.tensor(1., device=self.device)) for param_group in self.base_optimizer.param_groups: @@ -283,7 +294,7 @@ def step(self, closure=None, **kwargs): shard.grad = grad_shard if self.coalesce_cc: - grad_shard = xm.reduce_scatter( + grad_shards = xm.reduce_scatter( xm.REDUCE_SUM, padded_grads, scale=1.0 / self.local_world_size, @@ -362,6 +373,7 @@ def state_dict(self): state_dict = super().state_dict() base_state = self.base_optimizer.state_dict()['state'] state_dict['base_state'] = base_state + state_dict['shape_info'] = self.get_shape_info() return state_dict def load_state_dict(self, state_dict): @@ -375,3 +387,12 @@ def load_state_dict(self, state_dict): tmp = self.base_optimizer.state_dict() tmp['state'] = base_state self.base_optimizer.load_state_dict(tmp) + + def get_shape_info(self): + shape_info = {} + idx = 0 + for param_group in self.param_groups: + for param in param_group['params']: + shape_info[idx] = param.shape + idx += 1 + return shape_info From b354c27784be9dc473ae4bf415ff392ff3c5f2cf Mon Sep 17 00:00:00 2001 From: Jeffrey Huynh Date: Wed, 6 Mar 2024 16:45:49 +0000 Subject: [PATCH 05/15] yapf lint fixes --- torch_xla/core/xla_model.py | 33 ++++++++++++++++++++------------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py index 8a629d53aa9..e18de98a16a 100755 --- a/torch_xla/core/xla_model.py +++ b/torch_xla/core/xla_model.py @@ -628,7 +628,8 @@ def _all_gather_coalesced(tensor_list, output_list=None): if output_list != None: # Call the out of place version of the reduce_scatter new_token = torch_xla._XLAC._xla_all_gather_coalesced_out( - output_list, tensor_list, token, dim, shard_count, groups or [], pin_layout) + output_list, tensor_list, token, dim, shard_count, groups or [], + pin_layout) torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token) return output_list @@ -650,26 +651,31 @@ def _all_gather_coalesced(tensor_list, output_list=None): divisor = xrt_world_size() bucket_cap = bucket_cap / divisor for idx, tensor in enumerate(value): 
- + tensor_bytes = tensor.numel() * tensor.element_size() output_selected = None if output != None: output_selected = output[idx] if tensor.numel() != output_selected.numel(): - raise ValueError(f"`output` tensor size doesn't match `input` tensor size for tensor list index {idx}: " - f"{output_selected.numel() vs tensor.numel().") + raise ValueError( + f"`output` tensor size doesn't match `input` tensor size for tensor list index {idx}: " + f"{output_selected.numel() vs tensor.numel().") # Tensor is larger than bucket_cap, don't bucketize if tensor_bytes > bucket_cap: # Flush out previous buckets even if they don't fill up if total >= 0.5 * bucket_cap or (total + tensor_bytes) > 2 * bucket_cap: - out_tensors.extend(_all_gather_coalesced(tensor_bucket, output_bucket)) - out_tensors.extend(_all_gather_coalesced([tensor], [output_selected] if output else [])) + out_tensors.extend( + _all_gather_coalesced(tensor_bucket, output_bucket)) + out_tensors.extend( + _all_gather_coalesced([tensor], + [output_selected] if output else [])) else: tensor_bucket.append(tensor) if output != None: output_bucket.append(output[i]) - out_tensors.extend(_all_gather_coalesced(tensor_bucket, output_bucket)) + out_tensors.extend( + _all_gather_coalesced(tensor_bucket, output_bucket)) total = 0 tensor_bucket = [] output_bucket = [] @@ -901,14 +907,14 @@ def _reduce_scatter_coalesced(tensor_list, output_list=None): if output_list != None: # Call the out of place version of the reduce_scatter new_token = torch_xla._XLAC._xla_reduce_scatter_coalesced_out( - reduce_type, output_list, tensor_list, token, scale, scatter_dim, shard_count, - groups or [], pin_layout) + reduce_type, output_list, tensor_list, token, scale, scatter_dim, + shard_count, groups or [], pin_layout) torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token) return output_list result = torch_xla._XLAC._xla_reduce_scatter_coalesced( - reduce_type, tensor_list, token, scale, - scatter_dim, shard_count, groups or [], pin_layout) + reduce_type, tensor_list, token, scale, scatter_dim, shard_count, + groups or [], pin_layout) torch_xla._XLAC._set_all_reduce_token(devctx.device, result[-1]) return result[:-1] @@ -925,8 +931,9 @@ def _reduce_scatter_coalesced(tensor_list, output_list=None): if output != None: output_selected = output[idx] if tensor.numel() != output_selected.numel(): - raise ValueError(f"`output` tensor size doesn't match `input` tensor size for tensor list index {idx}: " - f"{output_selected.numel() vs tensor.numel().") + raise ValueError( + f"`output` tensor size doesn't match `input` tensor size for tensor list index {idx}: " + f"{output_selected.numel() vs tensor.numel().") # Tensor is larger than bucket_cap, don't bucketize if tensor_bytes > bucket_cap: From 22e29d3719be24bf0cb720dd904dafa959ee5a62 Mon Sep 17 00:00:00 2001 From: Jeffrey Huynh Date: Mon, 11 Mar 2024 16:30:29 +0000 Subject: [PATCH 06/15] Fix missing curly brackets in assertion msg --- torch_xla/core/xla_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py index e18de98a16a..e2f7282f044 100755 --- a/torch_xla/core/xla_model.py +++ b/torch_xla/core/xla_model.py @@ -659,7 +659,7 @@ def _all_gather_coalesced(tensor_list, output_list=None): if tensor.numel() != output_selected.numel(): raise ValueError( f"`output` tensor size doesn't match `input` tensor size for tensor list index {idx}: " - f"{output_selected.numel() vs tensor.numel().") + f"{output_selected.numel()} vs 
{tensor.numel()}.") # Tensor is larger than bucket_cap, don't bucketize if tensor_bytes > bucket_cap: @@ -933,7 +933,7 @@ def _reduce_scatter_coalesced(tensor_list, output_list=None): if tensor.numel() != output_selected.numel(): raise ValueError( f"`output` tensor size doesn't match `input` tensor size for tensor list index {idx}: " - f"{output_selected.numel() vs tensor.numel().") + f"{output_selected.numel()} vs {tensor.numel()}.") # Tensor is larger than bucket_cap, don't bucketize if tensor_bytes > bucket_cap: From 96c61cd61ac5bf747635d3a5b49f90ce642eb023 Mon Sep 17 00:00:00 2001 From: Amithrajith Mamidala Date: Tue, 30 Jan 2024 18:04:21 +0000 Subject: [PATCH 07/15] Fixing FAL issue when sharded params are initialized with torch.double cr: https://code.amazon.com/reviews/CR-112545987 --- torch_xla/distributed/zero_redundancy_optimizer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch_xla/distributed/zero_redundancy_optimizer.py b/torch_xla/distributed/zero_redundancy_optimizer.py index b3ae95a4132..31c2d9e58f2 100644 --- a/torch_xla/distributed/zero_redundancy_optimizer.py +++ b/torch_xla/distributed/zero_redundancy_optimizer.py @@ -243,10 +243,10 @@ def _clip_grad_norm( self._grad_norm = self._calc_grad_norm(norm_type) clip_coeff = torch.tensor( - max_norm, device=self.device) / ( + max_norm, device=self.device, dtype=self.optimizer_dtype) / ( self._grad_norm + 1e-6) - clip_value = torch.where(clip_coeff < 1, clip_coeff, - torch.tensor(1., device=self.device)) + clip_value = torch.where(clip_coeff < 1, clip_coeff, + torch.tensor(1., device=self.device, dtype=self.optimizer_dtype)) for param_group in self.base_optimizer.param_groups: for p in param_group['params']: if p.grad is not None: From 6b7ce8fab62eec832374faeac3ac130f8a126efb Mon Sep 17 00:00:00 2001 From: Jeffrey Huynh Date: Tue, 12 Mar 2024 04:18:27 +0000 Subject: [PATCH 08/15] Yapf fixes --- torch_xla/distributed/zero_redundancy_optimizer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torch_xla/distributed/zero_redundancy_optimizer.py b/torch_xla/distributed/zero_redundancy_optimizer.py index 31c2d9e58f2..2d345081e74 100644 --- a/torch_xla/distributed/zero_redundancy_optimizer.py +++ b/torch_xla/distributed/zero_redundancy_optimizer.py @@ -245,8 +245,9 @@ def _clip_grad_norm( clip_coeff = torch.tensor( max_norm, device=self.device, dtype=self.optimizer_dtype) / ( self._grad_norm + 1e-6) - clip_value = torch.where(clip_coeff < 1, clip_coeff, - torch.tensor(1., device=self.device, dtype=self.optimizer_dtype)) + clip_value = torch.where( + clip_coeff < 1, clip_coeff, + torch.tensor(1., device=self.device, dtype=self.optimizer_dtype)) for param_group in self.base_optimizer.param_groups: for p in param_group['params']: if p.grad is not None: From a5de71af73af4022cec500bdfc88b38b6aafbcbf Mon Sep 17 00:00:00 2001 From: Jeffrey Huynh Date: Wed, 13 Mar 2024 04:44:52 +0000 Subject: [PATCH 09/15] Fix indices and variable names --- torch_xla/core/xla_model.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py index e2f7282f044..fb006bd64d3 100755 --- a/torch_xla/core/xla_model.py +++ b/torch_xla/core/xla_model.py @@ -619,9 +619,9 @@ def all_gather(value, dim=0, groups=None, output=None, pin_layout=True): raise TypeError( f"`output` needs to be a list of Tensors, but given {type(output)}." 
) - if len(output) != len(input): - raise ValueError("`output` length doesn't match `input` length: " - f"{len(output)} vs {len(input)}.") + if len(output) != len(value): + raise ValueError("`output` length doesn't match `value` length: " + f"{len(output)} vs {len(value)}.") # helper function for bucketing def _all_gather_coalesced(tensor_list, output_list=None): @@ -650,8 +650,8 @@ def _all_gather_coalesced(tensor_list, output_list=None): else: divisor = xrt_world_size() bucket_cap = bucket_cap / divisor - for idx, tensor in enumerate(value): + for idx, tensor in enumerate(value): tensor_bytes = tensor.numel() * tensor.element_size() output_selected = None if output != None: @@ -673,7 +673,7 @@ def _all_gather_coalesced(tensor_list, output_list=None): else: tensor_bucket.append(tensor) if output != None: - output_bucket.append(output[i]) + output_bucket.append(output[idx]) out_tensors.extend( _all_gather_coalesced(tensor_bucket, output_bucket)) total = 0 @@ -925,7 +925,7 @@ def _reduce_scatter_coalesced(tensor_list, output_list=None): bucket_cap = int( os.getenv("ALL_GATHER_REDUCE_SCATTER_BUCKET_CAP_MB", _ALL_GATHER_REDUCE_SCATTER_BUCKET_CAP_MB)) * 1024 * 1024 - for i, tensor in enumerate(input): + for idx, tensor in enumerate(input): tensor_bytes = tensor.numel() * tensor.element_size() output_selected = None if output != None: @@ -947,7 +947,7 @@ def _reduce_scatter_coalesced(tensor_list, output_list=None): else: tensor_bucket.append(tensor) if output != None: - output_bucket.append(output[i]) + output_bucket.append(output[idx]) out_tensors.extend( _reduce_scatter_coalesced(tensor_bucket, output_bucket)) total = 0 From 77b2ad17e3ee0019078deb68a48cae1f891c346b Mon Sep 17 00:00:00 2001 From: Jeffrey Huynh Date: Thu, 14 Mar 2024 21:45:12 +0000 Subject: [PATCH 10/15] Checking of .numel for output tensors cause error in GPU runtime --- torch_xla/core/xla_model.py | 45 +++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 25 deletions(-) diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py index fb006bd64d3..093fb636d52 100755 --- a/torch_xla/core/xla_model.py +++ b/torch_xla/core/xla_model.py @@ -624,8 +624,12 @@ def all_gather(value, dim=0, groups=None, output=None, pin_layout=True): f"{len(output)} vs {len(value)}.") # helper function for bucketing - def _all_gather_coalesced(tensor_list, output_list=None): - if output_list != None: + def _all_gather_coalesced(tensor_list, output_list=[]): + if output_list != []: + if len(output_list) != len(tensor_list): + raise ValueError( + "_all_gather_coalesced: `output_list` length doesn't match `tensor_list` length: " + f"{len(output_list)} vs {len(tensor_list)}.") # Call the out of place version of the reduce_scatter new_token = torch_xla._XLAC._xla_all_gather_coalesced_out( output_list, tensor_list, token, dim, shard_count, groups or [], @@ -638,6 +642,8 @@ def _all_gather_coalesced(tensor_list, output_list=None): torch_xla._XLAC._set_all_reduce_token(devctx.device, result[-1]) return result[:-1] + #return _all_gather_coalesced(value, output if output else []) + total = 0 tensor_bucket = [] output_bucket = [] @@ -653,13 +659,6 @@ def _all_gather_coalesced(tensor_list, output_list=None): for idx, tensor in enumerate(value): tensor_bytes = tensor.numel() * tensor.element_size() - output_selected = None - if output != None: - output_selected = output[idx] - if tensor.numel() != output_selected.numel(): - raise ValueError( - f"`output` tensor size doesn't match `input` tensor size for tensor list index {idx}: " - 
f"{output_selected.numel()} vs {tensor.numel()}.") # Tensor is larger than bucket_cap, don't bucketize if tensor_bytes > bucket_cap: @@ -668,8 +667,7 @@ def _all_gather_coalesced(tensor_list, output_list=None): out_tensors.extend( _all_gather_coalesced(tensor_bucket, output_bucket)) out_tensors.extend( - _all_gather_coalesced([tensor], - [output_selected] if output else [])) + _all_gather_coalesced([tensor], [output[idx]] if output else [])) else: tensor_bucket.append(tensor) if output != None: @@ -690,7 +688,7 @@ def _all_gather_coalesced(tensor_list, output_list=None): output_bucket = [] tensor_bucket.append(tensor) if output != None: - output_bucket.append(output_selected) + output_bucket.append(output[idx]) # Flush the last remaining bucket if len(tensor_bucket): @@ -901,10 +899,13 @@ def reduce_scatter(reduce_type, if len(output) != len(input): raise ValueError("`output` length doesn't match `input` length: " f"{len(output)} vs {len(input)}.") - # helper function for bucketing - def _reduce_scatter_coalesced(tensor_list, output_list=None): - if output_list != None: + def _reduce_scatter_coalesced(tensor_list, output_list=[]): + if output_list != []: + if len(output_list) != len(tensor_list): + raise ValueError( + "_reduce_scatter_coalesced: `output_list` length doesn't match `tensor_list` length: " + f"{len(output_list)} vs {len(tensor_list)}.") # Call the out of place version of the reduce_scatter new_token = torch_xla._XLAC._xla_reduce_scatter_coalesced_out( reduce_type, output_list, tensor_list, token, scale, scatter_dim, @@ -927,13 +928,6 @@ def _reduce_scatter_coalesced(tensor_list, output_list=None): _ALL_GATHER_REDUCE_SCATTER_BUCKET_CAP_MB)) * 1024 * 1024 for idx, tensor in enumerate(input): tensor_bytes = tensor.numel() * tensor.element_size() - output_selected = None - if output != None: - output_selected = output[idx] - if tensor.numel() != output_selected.numel(): - raise ValueError( - f"`output` tensor size doesn't match `input` tensor size for tensor list index {idx}: " - f"{output_selected.numel()} vs {tensor.numel()}.") # Tensor is larger than bucket_cap, don't bucketize if tensor_bytes > bucket_cap: @@ -943,10 +937,11 @@ def _reduce_scatter_coalesced(tensor_list, output_list=None): _reduce_scatter_coalesced(tensor_bucket, output_bucket)) out_tensors.extend( _reduce_scatter_coalesced([tensor], - [output_selected] if output else [])) + [output[idx]] if output else [])) else: tensor_bucket.append(tensor) if output != None: + assert (output[idx] != None) output_bucket.append(output[idx]) out_tensors.extend( _reduce_scatter_coalesced(tensor_bucket, output_bucket)) @@ -965,14 +960,14 @@ def _reduce_scatter_coalesced(tensor_list, output_list=None): output_bucket = [] tensor_bucket.append(tensor) if output != None: - output_bucket.append(output_selected) + output_bucket.append(output[idx]) # Flush the last remaining bucket if len(tensor_bucket): out_tensors.extend( _reduce_scatter_coalesced(tensor_bucket, output_bucket)) - assert len(out_tensors) == len(value) + assert len(out_tensors) == len(input) return out_tensors else: From ae348b24ffb7152145f564e0f7dd9346a4e6fa58 Mon Sep 17 00:00:00 2001 From: Jeffrey Huynh Date: Tue, 19 Mar 2024 05:46:08 +0000 Subject: [PATCH 11/15] Avoid passing empty input buckets --- torch_xla/core/xla_model.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py index 093fb636d52..f1f7051dc46 100755 --- a/torch_xla/core/xla_model.py +++ 
b/torch_xla/core/xla_model.py @@ -664,8 +664,9 @@ def _all_gather_coalesced(tensor_list, output_list=[]): if tensor_bytes > bucket_cap: # Flush out previous buckets even if they don't fill up if total >= 0.5 * bucket_cap or (total + tensor_bytes) > 2 * bucket_cap: - out_tensors.extend( - _all_gather_coalesced(tensor_bucket, output_bucket)) + if len(tensor_bucket): + out_tensors.extend( + _all_gather_coalesced(tensor_bucket, output_bucket)) out_tensors.extend( _all_gather_coalesced([tensor], [output[idx]] if output else [])) else: @@ -682,7 +683,9 @@ def _all_gather_coalesced(tensor_list, output_list=[]): # Bucketize till the total spills over total += tensor_bytes if total > bucket_cap: - out_tensors.extend(_all_gather_coalesced(tensor_bucket, output_bucket)) + if len(tensor_bucket): + out_tensors.extend( + _all_gather_coalesced(tensor_bucket, output_bucket)) total = tensor_bytes tensor_bucket = [] output_bucket = [] @@ -933,15 +936,15 @@ def _reduce_scatter_coalesced(tensor_list, output_list=[]): if tensor_bytes > bucket_cap: # Flush out previous buckets even if they don't fill up if total >= 0.5 * bucket_cap or (total + tensor_bytes) > 2 * bucket_cap: - out_tensors.extend( - _reduce_scatter_coalesced(tensor_bucket, output_bucket)) + if len(tensor_bucket): + out_tensors.extend( + _reduce_scatter_coalesced(tensor_bucket, output_bucket)) out_tensors.extend( _reduce_scatter_coalesced([tensor], [output[idx]] if output else [])) else: tensor_bucket.append(tensor) if output != None: - assert (output[idx] != None) output_bucket.append(output[idx]) out_tensors.extend( _reduce_scatter_coalesced(tensor_bucket, output_bucket)) @@ -953,8 +956,9 @@ def _reduce_scatter_coalesced(tensor_list, output_list=[]): # Bucketize till the total spills over total += tensor_bytes if total > bucket_cap: - out_tensors.extend( - _reduce_scatter_coalesced(tensor_bucket, output_bucket)) + if len(tensor_bucket): + out_tensors.extend( + _reduce_scatter_coalesced(tensor_bucket, output_bucket)) total = tensor_bytes tensor_bucket = [] output_bucket = [] From 85863703544b8efc7ff4027c8e94e0c132c1ce91 Mon Sep 17 00:00:00 2001 From: Jeffrey Huynh Date: Tue, 19 Mar 2024 20:15:33 +0000 Subject: [PATCH 12/15] Fix indent for 2 lines in ZeRO1 (shard.grad = grad_shard, index += 1) --- torch_xla/distributed/zero_redundancy_optimizer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_xla/distributed/zero_redundancy_optimizer.py b/torch_xla/distributed/zero_redundancy_optimizer.py index 2d345081e74..c46642fa448 100644 --- a/torch_xla/distributed/zero_redundancy_optimizer.py +++ b/torch_xla/distributed/zero_redundancy_optimizer.py @@ -314,8 +314,8 @@ def step(self, closure=None, **kwargs): if grad_shard.dtype != self.optimizer_dtype: grad_shard = grad_shard.to(dtype=self.optimizer_dtype) - shard.grad = grad_shard - index += 1 + shard.grad = grad_shard + index += 1 if self.grad_clipping: # Update unscale/clip with sub partitions From 675e7a112567ea95e977a9991334042c08177539 Mon Sep 17 00:00:00 2001 From: Jeffrey Huynh Date: Wed, 20 Mar 2024 05:48:30 +0000 Subject: [PATCH 13/15] Refactor bucketized all-gather/reduce-scatter functions; add bucket_cap_mb arg --- test/test_mp_all_gather.py | 36 ++ test/test_mp_reduce_scatter.py | 56 +++ torch_xla/core/xla_model.py | 328 ++++++++++-------- .../distributed/zero_redundancy_optimizer.py | 14 +- 4 files changed, 294 insertions(+), 140 deletions(-) diff --git a/test/test_mp_all_gather.py b/test/test_mp_all_gather.py index 5d38158d287..be83b74631e 100644 --- 
a/test/test_mp_all_gather.py +++ b/test/test_mp_all_gather.py @@ -95,6 +95,42 @@ def _mp_fn(index): file=sys.stderr) print(f'[{index}] {cpu_result}', file=sys.stderr) sys.exit(1) + + # Testing with a single replica group and tensor list as input (Bucketized) + # TODO: add support for list input with pin_layout=True and output=None + result_list = xm.all_gather_bucketized( + ordinal_tensors, dim=0, pin_layout=False) + + for i, result in enumerate(result_list): + cpu_result = result.cpu() + expected = i * 1000 + torch.arange(world_size, dtype=torch.float) + if not cpu_result.allclose(expected): + print( + 'xm.all_gather_bucketized() produced wrong reductions for item {i} in result list', + file=sys.stderr) + print(f'[{index}] {cpu_result}', file=sys.stderr) + sys.exit(1) + + # Testing with a single replica group and tensor list as input and output!=None (out-of-place) (Bucketized) + # Reuse ordinal_tensors from previous test + output_tensors = [ + torch.zeros([world_size], dtype=torch.float).to(device) + for i in range(input_list_size) + ] + # TODO: add support for list input with pin_layout=True and output!=None + result_list = xm.all_gather_bucketized( + ordinal_tensors, dim=0, output=output_tensors, pin_layout=False) + + for i, result in enumerate(result_list): + cpu_result = result.cpu() + expected = i * 1000 + torch.arange(world_size, dtype=torch.float) + if not cpu_result.allclose(expected): + print( + 'xm.all_gather() produced wrong reductions for item {i} in result list', + file=sys.stderr) + print(f'[{index}] {cpu_result}', file=sys.stderr) + sys.exit(1) + # TODO: add test for torch.compile when support for list input is ready else: diff --git a/test/test_mp_reduce_scatter.py b/test/test_mp_reduce_scatter.py index 1ef61d3aa79..94363be1a02 100644 --- a/test/test_mp_reduce_scatter.py +++ b/test/test_mp_reduce_scatter.py @@ -55,6 +55,33 @@ def _mp_fn(index): xm.rendezvous('test_reduce_scatter_list_input') + # Testing reduce-scatter with list input bucketized + rand_list = [ + torch.rand((32, shard_size * world_size, 32)) + for _ in range(input_list_size) + ] + xrand_list = [rand.to(device) for rand in rand_list] + + # TODO: fix the broken case with pin_layout=True + res_list = xm.reduce_scatter_bucketized( + xm.REDUCE_SUM, + xrand_list, + scale, + scatter_dim, + world_size, + pin_layout=False) + + for i, res in enumerate(res_list): + expected_world = xm.all_reduce(xm.REDUCE_SUM, xrand_list[i], scale) + xm.mark_step() + + slice_idx = torch.tensor( + list(range(index * shard_size, (index + 1) * shard_size))) + expected = expected_world.cpu().index_select(scatter_dim, slice_idx) + assert res.cpu().allclose(expected) + + xm.rendezvous('test_reduce_scatter_list_input_bucketized') + # Testing reduce-scatter with list input and output output_list = [ torch.rand((32, shard_size * world_size, 32)) @@ -83,6 +110,35 @@ def _mp_fn(index): assert res.cpu().allclose(expected) xm.rendezvous('test_reduce_scatter_list_input_output') + + # Testing reduce-scatter with list input and output + output_list = [ + torch.rand((32, shard_size * world_size, 32)) + for _ in range(input_list_size) + ] + xoutput_list = [output.to(device) for output in output_list] + + # TODO: fix the broken case with pin_layout=True + res_list = xm.reduce_scatter_bucketized( + xm.REDUCE_SUM, + xrand_list, + scale, + scatter_dim, + world_size, + output=xoutput_list, + pin_layout=False) + + assert (xoutput_list == res_list) + for i, res in enumerate(xoutput_list): + expected_world = xm.all_reduce(xm.REDUCE_SUM, xrand_list[i], scale) 
+ xm.mark_step() + + slice_idx = torch.tensor( + list(range(index * shard_size, (index + 1) * shard_size))) + expected = expected_world.cpu().index_select(scatter_dim, slice_idx) + assert res.cpu().allclose(expected) + + xm.rendezvous('test_reduce_scatter_list_input_output_bucketized') else: print( 'Default device {} is not a TPU device'.format(device), file=sys.stderr) diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py index f1f7051dc46..f5982468f2b 100755 --- a/torch_xla/core/xla_model.py +++ b/torch_xla/core/xla_model.py @@ -1,7 +1,6 @@ import io import itertools import logging -import os import sys import re import threading @@ -608,7 +607,6 @@ def all_gather(value, dim=0, groups=None, output=None, pin_layout=True): # Now the input should be a list of Tensors. elif isinstance(value, list) and all( isinstance(v, torch.Tensor) for v in value): - # sanity checks if pin_layout: raise RuntimeError( "For xm.all_gather with list of tensors input, pin_layout=True is not yet supported." @@ -620,89 +618,113 @@ def all_gather(value, dim=0, groups=None, output=None, pin_layout=True): f"`output` needs to be a list of Tensors, but given {type(output)}." ) if len(output) != len(value): - raise ValueError("`output` length doesn't match `value` length: " - f"{len(output)} vs {len(value)}.") - - # helper function for bucketing - def _all_gather_coalesced(tensor_list, output_list=[]): - if output_list != []: - if len(output_list) != len(tensor_list): - raise ValueError( - "_all_gather_coalesced: `output_list` length doesn't match `tensor_list` length: " - f"{len(output_list)} vs {len(tensor_list)}.") - # Call the out of place version of the reduce_scatter - new_token = torch_xla._XLAC._xla_all_gather_coalesced_out( - output_list, tensor_list, token, dim, shard_count, groups or [], - pin_layout) - torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token) - return output_list - - result = torch_xla._XLAC._xla_all_gather_coalesced( - tensor_list, token, dim, shard_count, groups or [], pin_layout) - torch_xla._XLAC._set_all_reduce_token(devctx.device, result[-1]) - return result[:-1] - - #return _all_gather_coalesced(value, output if output else []) - - total = 0 - tensor_bucket = [] - output_bucket = [] - out_tensors = [] - bucket_cap = int( - os.getenv("ALL_GATHER_REDUCE_SCATTER_BUCKET_CAP_MB", - _ALL_GATHER_REDUCE_SCATTER_BUCKET_CAP_MB)) * 1024 * 1024 - if groups: - divisor = len(groups[0]) if type(groups[0]) == list else len(groups) - else: - divisor = xrt_world_size() - bucket_cap = bucket_cap / divisor - - for idx, tensor in enumerate(value): - tensor_bytes = tensor.numel() * tensor.element_size() - - # Tensor is larger than bucket_cap, don't bucketize - if tensor_bytes > bucket_cap: - # Flush out previous buckets even if they don't fill up - if total >= 0.5 * bucket_cap or (total + tensor_bytes) > 2 * bucket_cap: - if len(tensor_bucket): - out_tensors.extend( - _all_gather_coalesced(tensor_bucket, output_bucket)) - out_tensors.extend( - _all_gather_coalesced([tensor], [output[idx]] if output else [])) - else: - tensor_bucket.append(tensor) - if output != None: - output_bucket.append(output[idx]) - out_tensors.extend( - _all_gather_coalesced(tensor_bucket, output_bucket)) - total = 0 - tensor_bucket = [] - output_bucket = [] - continue + raise ValueError("`output` length doesn't match `input` length: " + f"{len(output)} vs {len(input)}.") + # Call the out of place version of the reduce_scatter + new_token = torch_xla._XLAC._xla_all_gather_coalesced_out( + output, value, token, 
dim, shard_count, groups or [], pin_layout) + torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token) + return output + result = torch_xla._XLAC._xla_all_gather_coalesced(value, token, dim, + shard_count, groups or + [], pin_layout) + torch_xla._XLAC._set_all_reduce_token(devctx.device, result[-1]) + return result[:-1] + else: + raise TypeError("`value` needs to be a Tensor or a list of Tensors, but " + f"given {type(value)}.") + + +def all_gather_bucketized(input_list, + dim=0, + groups=None, + output=None, + pin_layout=True, + bucket_cap_mb=160): + """Performs an all-gather operation along a given dimension, with bucketization. + + Args: + See all_gather for the args: dim, groups, output, pin_layout + input_list: List of input tensors + bucket_cap_mb: Number of MegaBytes of the tensor bucket to fill before doing all-gather. + + Returns: + A list of tensors each of which has, in the ``dim`` dimension, all the values from the + participating replicas. + """ + # sanity checks + if pin_layout: + raise RuntimeError( + "For xm.all_gather_bucketized, pin_layout=True is not yet supported.") + if not isinstance(input_list, list) or any( + not isinstance(v, torch.Tensor) for v in input_list): + raise TypeError( + f"`input_list` needs to be a list of Tensors, but given {type(input_list)}." + ) + if output != None: + if not isinstance(output, list) or any( + not isinstance(v, torch.Tensor) for v in output): + raise TypeError( + f"`output` needs to be a list of Tensors, but given {type(output)}.") + if len(output) != len(input_list): + raise ValueError("`output` length doesn't match `input_list` length: " + f"{len(output)} vs {len(input_list)}.") + + def _all_gather_coalesced(_input_list, _output_list=None): + return all_gather( + value=_input_list, + dim=dim, + groups=groups, + output=_output_list, + pin_layout=pin_layout) + + total = 0 + tensor_bucket = [] + output_bucket = [] if output else None + out_tensors = [] + bucket_cap = bucket_cap_mb * 1024 * 1024 + if groups: + divisor = len(groups[0]) if type(groups[0]) == list else len(groups) + else: + divisor = xrt_world_size() + bucket_cap = bucket_cap / divisor + + for idx, tensor in enumerate(input_list): + tensor_bytes = tensor.numel() * tensor.element_size() + + # Aim for target bucket_cap_mb: flush new tensor with bucket if bucket content + # is small (1/2 cap) but don't combine if combined total is over 2x cap + total_new = total + tensor_bytes + if tensor_bytes > bucket_cap and total < 0.5 * bucket_cap and total_new <= 2 * bucket_cap: + tensor_bucket.append(tensor) + if output != None: + output_bucket.append(output[idx]) + out_tensors.extend(_all_gather_coalesced(tensor_bucket, output_bucket)) + total = 0 + tensor_bucket = [] + output_bucket = [] if output else None + else: # Bucketize till the total spills over - total += tensor_bytes - if total > bucket_cap: + if total_new > bucket_cap: if len(tensor_bucket): out_tensors.extend( _all_gather_coalesced(tensor_bucket, output_bucket)) - total = tensor_bytes + total = 0 tensor_bucket = [] - output_bucket = [] + output_bucket = [] if output else None + total = total_new tensor_bucket.append(tensor) if output != None: output_bucket.append(output[idx]) - # Flush the last remaining bucket - if len(tensor_bucket): - out_tensors.extend(_all_gather_coalesced(tensor_bucket, output_bucket)) + # Flush the last remaining bucket + if len(tensor_bucket): + out_tensors.extend(_all_gather_coalesced(tensor_bucket, output_bucket)) - assert len(out_tensors) == len(value) + assert len(out_tensors) == 
len(input_list) - return out_tensors - else: - raise TypeError("`value` needs to be a Tensor or a list of Tensors, but " - f"given {type(value)}.") + return out_tensors def all_to_all(value, @@ -892,7 +914,6 @@ def reduce_scatter(reduce_type, # Now the input should be a list of Tensors. elif isinstance(input, list) and all( isinstance(v, torch.Tensor) for v in input): - # sanity checks if output != None: if not isinstance(output, list) or any( not isinstance(v, torch.Tensor) for v in output): @@ -902,81 +923,116 @@ def reduce_scatter(reduce_type, if len(output) != len(input): raise ValueError("`output` length doesn't match `input` length: " f"{len(output)} vs {len(input)}.") - # helper function for bucketing - def _reduce_scatter_coalesced(tensor_list, output_list=[]): - if output_list != []: - if len(output_list) != len(tensor_list): - raise ValueError( - "_reduce_scatter_coalesced: `output_list` length doesn't match `tensor_list` length: " - f"{len(output_list)} vs {len(tensor_list)}.") - # Call the out of place version of the reduce_scatter - new_token = torch_xla._XLAC._xla_reduce_scatter_coalesced_out( - reduce_type, output_list, tensor_list, token, scale, scatter_dim, - shard_count, groups or [], pin_layout) - torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token) - return output_list - - result = torch_xla._XLAC._xla_reduce_scatter_coalesced( - reduce_type, tensor_list, token, scale, scatter_dim, shard_count, + # Call the out of place version of the reduce_scatter + new_token = torch_xla._XLAC._xla_reduce_scatter_coalesced_out( + reduce_type, output, input, token, scale, scatter_dim, shard_count, groups or [], pin_layout) - torch_xla._XLAC._set_all_reduce_token(devctx.device, result[-1]) - return result[:-1] - - total = 0 - tensor_bucket = [] - output_bucket = [] - out_tensors = [] - bucket_cap = int( - os.getenv("ALL_GATHER_REDUCE_SCATTER_BUCKET_CAP_MB", - _ALL_GATHER_REDUCE_SCATTER_BUCKET_CAP_MB)) * 1024 * 1024 - for idx, tensor in enumerate(input): - tensor_bytes = tensor.numel() * tensor.element_size() - - # Tensor is larger than bucket_cap, don't bucketize - if tensor_bytes > bucket_cap: - # Flush out previous buckets even if they don't fill up - if total >= 0.5 * bucket_cap or (total + tensor_bytes) > 2 * bucket_cap: - if len(tensor_bucket): - out_tensors.extend( - _reduce_scatter_coalesced(tensor_bucket, output_bucket)) - out_tensors.extend( - _reduce_scatter_coalesced([tensor], - [output[idx]] if output else [])) - else: - tensor_bucket.append(tensor) - if output != None: - output_bucket.append(output[idx]) - out_tensors.extend( - _reduce_scatter_coalesced(tensor_bucket, output_bucket)) - total = 0 - tensor_bucket = [] - output_bucket = [] - continue + torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token) + return output + result = torch_xla._XLAC._xla_reduce_scatter_coalesced( + reduce_type, input, token, scale, scatter_dim, shard_count, groups or + [], pin_layout) + torch_xla._XLAC._set_all_reduce_token(devctx.device, result[-1]) + return result[:-1] + else: + raise TypeError("`input` needs to be a Tensor or a list of Tensors, but " + f"given {type(input)}.") + + +def reduce_scatter_bucketized(reduce_type, + input_list, + scale, + scatter_dim, + shard_count, + groups=None, + output=None, + pin_layout=True, + bucket_cap_mb=160): + """Performs a XLA `ReduceScatter()` operation on a list of tensors (bucketized). 
+ + See: https://www.tensorflow.org/xla/operation_semantics#reducescatter + + Args: + see reduce_scatter for reduce_type, scale, scatter_dim, shard_count, groups, pin_layout + input_list: List of input tensors + output: Optional list of output torch.Tensor + bucket_cap_mb: Number of MegaBytes of the tensor bucket to fill before doing all-gather. + + Returns: + A list of `torch.Tensors` with all the values reduced across replicas. Each process + gets a shard split along the `scatter_dim`. All other dimensions are + the same as the input. + """ + token, devctx = _get_all_reduce_token() + + if not isinstance(input_list, list) or any( + not isinstance(v, torch.Tensor) for v in input_list): + raise TypeError( + f"`input_list` needs to be a list of Tensors, but given {type(input_list)}." + ) + if output != None: + if not isinstance(output, list) or any( + not isinstance(v, torch.Tensor) for v in output): + raise TypeError( + f"`output` needs to be a list of Tensors, but given {type(output)}.") + if len(output) != len(input_list): + raise ValueError("`output` length doesn't match `input_list` length: " + f"{len(output)} vs {len(input_list)}.") + + def _reduce_scatter_coalesced(_input_list, _output_list=None): + return reduce_scatter( + reduce_type=reduce_type, + input=_input_list, + scale=scale, + scatter_dim=scatter_dim, + shard_count=shard_count, + groups=groups, + output=_output_list, + pin_layout=pin_layout) + + total = 0 + tensor_bucket = [] + output_bucket = [] if output else None + out_tensors = [] + bucket_cap = bucket_cap_mb * 1024 * 1024 + + for idx, tensor in enumerate(input_list): + tensor_bytes = tensor.numel() * tensor.element_size() + + # Aim for target bucket_cap_mb: flush new tensor with bucket if bucket content + # is small (1/2 cap) but don't combine if combined total is over 2x cap + total_new = total + tensor_bytes + if tensor_bytes > bucket_cap and total < 0.5 * bucket_cap and total_new <= 2 * bucket_cap: + tensor_bucket.append(tensor) + if output != None: + output_bucket.append(output[idx]) + out_tensors.extend( + _reduce_scatter_coalesced(tensor_bucket, output_bucket)) + total = 0 + tensor_bucket = [] + output_bucket = [] if output else None + else: # Bucketize till the total spills over - total += tensor_bytes - if total > bucket_cap: + if total_new > bucket_cap: if len(tensor_bucket): out_tensors.extend( _reduce_scatter_coalesced(tensor_bucket, output_bucket)) - total = tensor_bytes + total = 0 tensor_bucket = [] - output_bucket = [] + output_bucket = [] if output else None + total = total_new tensor_bucket.append(tensor) if output != None: output_bucket.append(output[idx]) - # Flush the last remaining bucket - if len(tensor_bucket): - out_tensors.extend( - _reduce_scatter_coalesced(tensor_bucket, output_bucket)) + # Flush the last remaining bucket + if len(tensor_bucket): + out_tensors.extend(_reduce_scatter_coalesced(tensor_bucket, output_bucket)) - assert len(out_tensors) == len(input) + assert len(out_tensors) == len(input_list) - return out_tensors - else: - raise TypeError("`input` needs to be a Tensor or a list of Tensors, but " - f"given {type(input)}.") + return out_tensors def add_step_closure(closure, args=(), run_async=False): diff --git a/torch_xla/distributed/zero_redundancy_optimizer.py b/torch_xla/distributed/zero_redundancy_optimizer.py index c46642fa448..71a8742de65 100644 --- a/torch_xla/distributed/zero_redundancy_optimizer.py +++ b/torch_xla/distributed/zero_redundancy_optimizer.py @@ -40,6 +40,9 @@ class ZeroRedundancyOptimizer(Optimizer): If 
specified, ZeRO-1 will use this ``grad_norm_groups`` for the EXTRA all-reduce op in grad norm calculation. This can be model parallel groups when mixing ZeRO-1 with model parallelism such as Megatron. + bucket_cap_mb: + If non-zero, specifies the maximum number of megabytes to combine tensors + before doing the all-gather/reduce-scatter operations. **defaults: any trailing arguments, which are forwarded to the local optimizer. @@ -60,7 +63,7 @@ def __init__( sharding_groups: Optional[Any] = None, grad_norm_groups: Optional[Any] = None, lazy_init: bool = False, - coalesce_cc: bool = False, + bucket_cap_mb: int = 0, **defaults: Any, ): super().__init__(params, defaults) @@ -77,7 +80,8 @@ def __init__( self.grad_clipping = grad_clipping self.max_norm = max_norm if max_norm is not None else 1.0 self.pin_layout = pin_layout - self.coalesce_cc = coalesce_cc + self.bucket_cap_mb = bucket_cap_mb + self.coalesce_cc = bucket_cap_mb > 0 self._grad_norm = None @@ -295,7 +299,7 @@ def step(self, closure=None, **kwargs): shard.grad = grad_shard if self.coalesce_cc: - grad_shards = xm.reduce_scatter( + grad_shards = xm.reduce_scatter_bucketized( xm.REDUCE_SUM, padded_grads, scale=1.0 / self.local_world_size, @@ -303,6 +307,7 @@ def step(self, closure=None, **kwargs): shard_count=self.local_world_size, pin_layout=self.pin_layout, groups=self.sharding_groups, + bucket_cap_mb=self.bucket_cap_mb, ) index = 0 for param_group, sharded_param_group in zip( @@ -349,11 +354,12 @@ def step(self, closure=None, **kwargs): param.data.copy_(padded_param.data[:param.size(0)]) if self.coalesce_cc: - padded_params = xm.all_gather( + padded_params = xm.all_gather_bucketized( sharded_data, dim=0, pin_layout=self.pin_layout, groups=self.sharding_groups, + bucket_cap_mb=self.bucket_cap_mb, ) index = 0 for param_group, sharded_param_group in zip( From d7c995882feee2a578d23423c72c94c9e95cebfb Mon Sep 17 00:00:00 2001 From: Jeffrey Huynh Date: Thu, 21 Mar 2024 05:49:00 +0000 Subject: [PATCH 14/15] Refactor bucketing logic into a class, shared by all-gather/reduce-scatter --- test/test_mp_reduce_scatter.py | 2 +- torch_xla/core/xla_model.py | 208 ++++++++---------- .../distributed/zero_redundancy_optimizer.py | 1 - 3 files changed, 90 insertions(+), 121 deletions(-) mode change 100755 => 100644 torch_xla/core/xla_model.py diff --git a/test/test_mp_reduce_scatter.py b/test/test_mp_reduce_scatter.py index 94363be1a02..ac02b6d2e7c 100644 --- a/test/test_mp_reduce_scatter.py +++ b/test/test_mp_reduce_scatter.py @@ -111,7 +111,7 @@ def _mp_fn(index): xm.rendezvous('test_reduce_scatter_list_input_output') - # Testing reduce-scatter with list input and output + # Testing reduce-scatter with list input and output (buckettized) output_list = [ torch.rand((32, shard_size * world_size, 32)) for _ in range(input_list_size) diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py old mode 100755 new mode 100644 index f5982468f2b..dc549db45bf --- a/torch_xla/core/xla_model.py +++ b/torch_xla/core/xla_model.py @@ -38,9 +38,6 @@ XLA_LIB = Library("xla", "DEF") -# Default bucket size for all-gather and reduce-scatter -_ALL_GATHER_REDUCE_SCATTER_BUCKET_CAP_MB = 160 - def _init_world_size_ordinal(): global _WORLD_SIZE, _ORDINAL @@ -636,6 +633,81 @@ def all_gather(value, dim=0, groups=None, output=None, pin_layout=True): f"given {type(value)}.") +class CoalescingBuckets(object): + + def __init__(self, + func, + input_list, + output_list=None, + groups=None, + bucket_cap_mb=160): + if not isinstance(input_list, list) or any( + not 
isinstance(v, torch.Tensor) for v in input_list): + raise TypeError( + f"`input_list` needs to be a list of Tensors, but given {type(input_list)}." + ) + if output_list != None: + if not isinstance(output_list, list) or any( + not isinstance(v, torch.Tensor) for v in output_list): + raise TypeError( + f"`output_list` needs to be a list of Tensors, but given {type(output_list)}." + ) + if len(output_list) != len(input_list): + raise ValueError( + "`output_list` length doesn't match `input_list` length: " + f"{len(output_list)} vs {len(input_list)}.") + self._func = func + self._input_list = input_list + self._output_list = output_list + self._total = 0 + self._tensor_bucket = [] + self._output_bucket = [] if output_list else None + self._bucket_cap = bucket_cap_mb * 1024 * 1024 + if groups: + divisor = len(groups[0]) if type(groups[0]) == list else len(groups) + else: + divisor = xrt_world_size() + self._bucket_cap = self._bucket_cap / divisor + self._out_tensors = [] + + def flush(self): + if len(self._tensor_bucket): + self._out_tensors.extend( + self._func(self._tensor_bucket, self._output_bucket)) + self._total = 0 + self._tensor_bucket = [] + self._output_bucket = [] if self._output_list else None + + def add(self, tensor, idx): + self._total += tensor.numel() * tensor.element_size() + self._tensor_bucket.append(tensor) + if self._output_list != None: + self._output_bucket.append(self._output_list[idx]) + + def __call__(self): + for idx, tensor in enumerate(self._input_list): + tensor_bytes = tensor.numel() * tensor.element_size() + + # Aim for target bucket_cap_mb: flush new tensor with bucket if bucket content + # is small (1/2 cap) but don't combine if combined total is over 2x cap + total_new = self._total + tensor_bytes + if tensor_bytes > self._bucket_cap and self._total < 0.5 * self._bucket_cap and total_new <= 2 * self._bucket_cap: + self.add(tensor, idx) + self.flush() + else: + # Bucketize till the total spills over + if total_new > self._bucket_cap: + self.flush() + self.add(tensor, idx) + + # Flush the last remaining bucket + self.flush() + + assert len(self._out_tensors) == len(self._input_list) + + return self._out_tensors + + def all_gather_bucketized(input_list, dim=0, groups=None, @@ -657,19 +729,6 @@ def all_gather_bucketized(input_list, if pin_layout: raise RuntimeError( "For xm.all_gather_bucketized, pin_layout=True is not yet supported.") - if not isinstance(input_list, list) or any( - not isinstance(v, torch.Tensor) for v in input_list): - raise TypeError( - f"`input_list` needs to be a list of Tensors, but given {type(input_list)}." 
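For reference, the grouping rule implemented by CoalescingBuckets.__call__ above can be checked in isolation with the following standalone sketch; simulate_buckets is a hypothetical helper written only for illustration and is not part of this patch.

# simulate_buckets is a hypothetical helper, shown only to illustrate the
# grouping rule of CoalescingBuckets.__call__; sizes are raw byte counts.
def simulate_buckets(tensor_bytes_list, bucket_cap_mb=160):
  bucket_cap = bucket_cap_mb * 1024 * 1024
  buckets, current, total = [], [], 0
  for idx, nbytes in enumerate(tensor_bytes_list):
    total_new = total + nbytes
    if nbytes > bucket_cap and total < 0.5 * bucket_cap and total_new <= 2 * bucket_cap:
      # Oversized tensor: let it ride along with the small pending bucket
      # instead of paying for an extra collective, then flush immediately.
      current.append(idx)
      buckets.append(current)
      current, total = [], 0
    else:
      if total_new > bucket_cap:
        # The pending bucket would spill over the cap: flush it first.
        if current:
          buckets.append(current)
        current, total = [], 0
      total += nbytes
      current.append(idx)
  if current:
    buckets.append(current)
  return buckets

# With the default 160 MB cap:
#   simulate_buckets([200 * 2**20, 50 * 2**20, 120 * 2**20, 10 * 2**20]) == [[0], [1], [2, 3]]
# The oversized 200 MB tensor goes out alone, the 50 MB tensor is flushed before
# the 120 MB tensor would overflow the cap, and the tail bucket is flushed at the end.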
- ) - if output != None: - if not isinstance(output, list) or any( - not isinstance(v, torch.Tensor) for v in output): - raise TypeError( - f"`output` needs to be a list of Tensors, but given {type(output)}.") - if len(output) != len(input_list): - raise ValueError("`output` length doesn't match `input_list` length: " - f"{len(output)} vs {len(input_list)}.") def _all_gather_coalesced(_input_list, _output_list=None): return all_gather( @@ -679,52 +738,13 @@ def _all_gather_coalesced(_input_list, _output_list=None): output=_output_list, pin_layout=pin_layout) - total = 0 - tensor_bucket = [] - output_bucket = [] if output else None - out_tensors = [] - bucket_cap = bucket_cap_mb * 1024 * 1024 - if groups: - divisor = len(groups[0]) if type(groups[0]) == list else len(groups) - else: - divisor = xrt_world_size() - bucket_cap = bucket_cap / divisor - - for idx, tensor in enumerate(input_list): - tensor_bytes = tensor.numel() * tensor.element_size() - - # Aim for target bucket_cap_mb: flush new tensor with bucket if bucket content - # is small (1/2 cap) but don't combine if combined total is over 2x cap - total_new = total + tensor_bytes - if tensor_bytes > bucket_cap and total < 0.5 * bucket_cap and total_new <= 2 * bucket_cap: - tensor_bucket.append(tensor) - if output != None: - output_bucket.append(output[idx]) - out_tensors.extend(_all_gather_coalesced(tensor_bucket, output_bucket)) - total = 0 - tensor_bucket = [] - output_bucket = [] if output else None - else: - # Bucketize till the total spills over - if total_new > bucket_cap: - if len(tensor_bucket): - out_tensors.extend( - _all_gather_coalesced(tensor_bucket, output_bucket)) - total = 0 - tensor_bucket = [] - output_bucket = [] if output else None - total = total_new - tensor_bucket.append(tensor) - if output != None: - output_bucket.append(output[idx]) - - # Flush the last remaining bucket - if len(tensor_bucket): - out_tensors.extend(_all_gather_coalesced(tensor_bucket, output_bucket)) - - assert len(out_tensors) == len(input_list) - - return out_tensors + buckets = CoalescingBuckets( + _all_gather_coalesced, + input_list, + output, + groups=groups, + bucket_cap_mb=bucket_cap_mb) + return buckets() def all_to_all(value, @@ -964,21 +984,6 @@ def reduce_scatter_bucketized(reduce_type, gets a shard split along the `scatter_dim`. All other dimensions are the same as the input. """ - token, devctx = _get_all_reduce_token() - - if not isinstance(input_list, list) or any( - not isinstance(v, torch.Tensor) for v in input_list): - raise TypeError( - f"`input_list` needs to be a list of Tensors, but given {type(input_list)}." 
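For reference, a hypothetical usage sketch of the two bucketized collectives being refactored here, modeled on the list-input tests in this series; the helper names, shapes, and world-size handling are assumptions.

# Hypothetical usage sketch, modeled on test_mp_reduce_scatter.py and
# test_mp_all_gather.py in this series; grad_list/shard_list are assumptions.
import torch
import torch_xla.core.xla_model as xm

def scatter_gradients(grad_list, world_size):
  # grad_list: list of tensors already padded so dimension 0 divides evenly
  # by world_size, as the ZeRO-1 optimizer does before reduce-scatter.
  device = xm.xla_device()
  xgrads = [g.to(device) for g in grad_list]
  # The tests in this series pass pin_layout=False for the bucketized variants.
  return xm.reduce_scatter_bucketized(
      xm.REDUCE_SUM,
      xgrads,
      scale=1.0 / world_size,
      scatter_dim=0,
      shard_count=world_size,
      pin_layout=False,
      bucket_cap_mb=160)

def gather_shards(shard_list):
  device = xm.xla_device()
  xshards = [s.to(device) for s in shard_list]
  # pin_layout=True is not yet supported for all_gather_bucketized.
  return xm.all_gather_bucketized(
      xshards, dim=0, pin_layout=False, bucket_cap_mb=160)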
- ) - if output != None: - if not isinstance(output, list) or any( - not isinstance(v, torch.Tensor) for v in output): - raise TypeError( - f"`output` needs to be a list of Tensors, but given {type(output)}.") - if len(output) != len(input_list): - raise ValueError("`output` length doesn't match `input_list` length: " - f"{len(output)} vs {len(input_list)}.") def _reduce_scatter_coalesced(_input_list, _output_list=None): return reduce_scatter( @@ -991,48 +996,13 @@ def _reduce_scatter_coalesced(_input_list, _output_list=None): output=_output_list, pin_layout=pin_layout) - total = 0 - tensor_bucket = [] - output_bucket = [] if output else None - out_tensors = [] - bucket_cap = bucket_cap_mb * 1024 * 1024 - - for idx, tensor in enumerate(input_list): - tensor_bytes = tensor.numel() * tensor.element_size() - - # Aim for target bucket_cap_mb: flush new tensor with bucket if bucket content - # is small (1/2 cap) but don't combine if combined total is over 2x cap - total_new = total + tensor_bytes - if tensor_bytes > bucket_cap and total < 0.5 * bucket_cap and total_new <= 2 * bucket_cap: - tensor_bucket.append(tensor) - if output != None: - output_bucket.append(output[idx]) - out_tensors.extend( - _reduce_scatter_coalesced(tensor_bucket, output_bucket)) - total = 0 - tensor_bucket = [] - output_bucket = [] if output else None - else: - # Bucketize till the total spills over - if total_new > bucket_cap: - if len(tensor_bucket): - out_tensors.extend( - _reduce_scatter_coalesced(tensor_bucket, output_bucket)) - total = 0 - tensor_bucket = [] - output_bucket = [] if output else None - total = total_new - tensor_bucket.append(tensor) - if output != None: - output_bucket.append(output[idx]) - - # Flush the last remaining bucket - if len(tensor_bucket): - out_tensors.extend(_reduce_scatter_coalesced(tensor_bucket, output_bucket)) - - assert len(out_tensors) == len(input_list) - - return out_tensors + buckets = CoalescingBuckets( + _reduce_scatter_coalesced, + input_list, + output, + groups=groups, + bucket_cap_mb=bucket_cap_mb) + return buckets() def add_step_closure(closure, args=(), run_async=False): diff --git a/torch_xla/distributed/zero_redundancy_optimizer.py b/torch_xla/distributed/zero_redundancy_optimizer.py index 71a8742de65..03c16e3e5bd 100644 --- a/torch_xla/distributed/zero_redundancy_optimizer.py +++ b/torch_xla/distributed/zero_redundancy_optimizer.py @@ -316,7 +316,6 @@ def step(self, closure=None, **kwargs): sharded_param_group['params']): if param.grad is not None: grad_shard = grad_shards[index] - if grad_shard.dtype != self.optimizer_dtype: grad_shard = grad_shard.to(dtype=self.optimizer_dtype) shard.grad = grad_shard From 5006388fc40e75bf3bfb926d456e2d04c2503ffd Mon Sep 17 00:00:00 2001 From: Jeffrey Huynh Date: Thu, 21 Mar 2024 22:34:28 +0000 Subject: [PATCH 15/15] Remove bucket-cap division logic; separate bucket cap for allgather/reducescatter --- test/test_mp_all_gather.py | 24 ++++++++++++++ test/test_mp_reduce_scatter.py | 31 +++++++++++++++++++ torch_xla/core/xla_model.py | 29 ++++++----------- .../distributed/zero_redundancy_optimizer.py | 21 +++++++------ 4 files changed, 76 insertions(+), 29 deletions(-) diff --git a/test/test_mp_all_gather.py b/test/test_mp_all_gather.py index be83b74631e..14f46043175 100644 --- a/test/test_mp_all_gather.py +++ b/test/test_mp_all_gather.py @@ -131,6 +131,30 @@ def _mp_fn(index): print(f'[{index}] {cpu_result}', file=sys.stderr) sys.exit(1) + # Testing with a single replica group and tensor list as input and output!=None (out-of-place) 
(Bucketized, zero bucket size) + # Reuse ordinal_tensors from previous test + output_tensors = [ + torch.zeros([world_size], dtype=torch.float).to(device) + for i in range(input_list_size) + ] + # TODO: add support for list input with pin_layout=True and output!=None + result_list = xm.all_gather_bucketized( + ordinal_tensors, + dim=0, + output=output_tensors, + pin_layout=False, + bucket_cap_mb=0) + + for i, result in enumerate(result_list): + cpu_result = result.cpu() + expected = i * 1000 + torch.arange(world_size, dtype=torch.float) + if not cpu_result.allclose(expected): + print( + 'xm.all_gather() produced wrong reductions for item {i} in result list', + file=sys.stderr) + print(f'[{index}] {cpu_result}', file=sys.stderr) + sys.exit(1) + # TODO: add test for torch.compile when support for list input is ready else: diff --git a/test/test_mp_reduce_scatter.py b/test/test_mp_reduce_scatter.py index ac02b6d2e7c..07173926091 100644 --- a/test/test_mp_reduce_scatter.py +++ b/test/test_mp_reduce_scatter.py @@ -139,6 +139,37 @@ def _mp_fn(index): assert res.cpu().allclose(expected) xm.rendezvous('test_reduce_scatter_list_input_output_bucketized') + + # Testing reduce-scatter with list input and output (buckettized, but zero bucket size) + output_list = [ + torch.rand((32, shard_size * world_size, 32)) + for _ in range(input_list_size) + ] + xoutput_list = [output.to(device) for output in output_list] + + # TODO: fix the broken case with pin_layout=True + res_list = xm.reduce_scatter_bucketized( + xm.REDUCE_SUM, + xrand_list, + scale, + scatter_dim, + world_size, + output=xoutput_list, + bucket_cap_mb=0, + pin_layout=False) + + assert (xoutput_list == res_list) + for i, res in enumerate(xoutput_list): + expected_world = xm.all_reduce(xm.REDUCE_SUM, xrand_list[i], scale) + xm.mark_step() + + slice_idx = torch.tensor( + list(range(index * shard_size, (index + 1) * shard_size))) + expected = expected_world.cpu().index_select(scatter_dim, slice_idx) + assert res.cpu().allclose(expected) + + xm.rendezvous( + 'test_reduce_scatter_list_input_output_bucketized, zero bucket size') else: print( 'Default device {} is not a TPU device'.format(device), file=sys.stderr) diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py index dc549db45bf..d93b26687ea 100644 --- a/torch_xla/core/xla_model.py +++ b/torch_xla/core/xla_model.py @@ -635,12 +635,7 @@ def all_gather(value, dim=0, groups=None, output=None, pin_layout=True): class CoalescingBuckets(object): - def __init__(self, - func, - input_list, - output_list=None, - groups=None, - bucket_cap_mb=160): + def __init__(self, func, input_list, output_list=None, bucket_cap_mb=160): if not isinstance(input_list, list) or any( not isinstance(v, torch.Tensor) for v in input_list): raise TypeError( @@ -663,15 +658,14 @@ def __init__(self, self._tensor_bucket = [] self._output_bucket = [] if output_list else None self._bucket_cap = bucket_cap_mb * 1024 * 1024 - if groups: - divisor = len(groups[0]) if type(groups[0]) == list else len(groups) - else: - divisor = xrt_world_size() - self._bucket_cap = self._bucket_cap / divisor self._out_tensors = [] def flush(self): - if len(self._tensor_bucket): + if len(self._tensor_bucket) == 1: + # Use non-coalesced CCOp if its just one tensor + output = self._output_bucket[0] if self._output_bucket else None + self._out_tensors.append(self._func(self._tensor_bucket[0], output)) + elif len(self._tensor_bucket): self._out_tensors.extend( self._func(self._tensor_bucket, self._output_bucket)) self._total = 0 @@ -712,7 
+706,7 @@ def all_gather_bucketized(input_list, dim=0, groups=None, output=None, - pin_layout=True, + pin_layout=False, bucket_cap_mb=160): """Performs an all-gather operation along a given dimension, with bucketization. @@ -739,11 +733,7 @@ def _all_gather_coalesced(_input_list, _output_list=None): pin_layout=pin_layout) buckets = CoalescingBuckets( - _all_gather_coalesced, - input_list, - output, - groups=groups, - bucket_cap_mb=bucket_cap_mb) + _all_gather_coalesced, input_list, output, bucket_cap_mb=bucket_cap_mb) return buckets() @@ -967,7 +957,7 @@ def reduce_scatter_bucketized(reduce_type, shard_count, groups=None, output=None, - pin_layout=True, + pin_layout=False, bucket_cap_mb=160): """Performs a XLA `ReduceScatter()` operation on a list of tensors (bucketized). @@ -1000,7 +990,6 @@ def _reduce_scatter_coalesced(_input_list, _output_list=None): _reduce_scatter_coalesced, input_list, output, - groups=groups, bucket_cap_mb=bucket_cap_mb) return buckets() diff --git a/torch_xla/distributed/zero_redundancy_optimizer.py b/torch_xla/distributed/zero_redundancy_optimizer.py index 03c16e3e5bd..9b21fe4ead8 100644 --- a/torch_xla/distributed/zero_redundancy_optimizer.py +++ b/torch_xla/distributed/zero_redundancy_optimizer.py @@ -63,7 +63,8 @@ def __init__( sharding_groups: Optional[Any] = None, grad_norm_groups: Optional[Any] = None, lazy_init: bool = False, - bucket_cap_mb: int = 0, + bucket_cap_mb_all_gather: int = 0, + bucket_cap_mb_reduce_scatter: int = 0, **defaults: Any, ): super().__init__(params, defaults) @@ -80,8 +81,10 @@ def __init__( self.grad_clipping = grad_clipping self.max_norm = max_norm if max_norm is not None else 1.0 self.pin_layout = pin_layout - self.bucket_cap_mb = bucket_cap_mb - self.coalesce_cc = bucket_cap_mb > 0 + self.bucket_cap_mb_all_gather = bucket_cap_mb_all_gather + self.bucket_cap_mb_reduce_scatter = bucket_cap_mb_reduce_scatter + self.coalesce_cc_all_gather = bucket_cap_mb_all_gather > 0 + self.coalesce_cc_reduce_scatter = bucket_cap_mb_reduce_scatter > 0 self._grad_norm = None @@ -282,7 +285,7 @@ def step(self, closure=None, **kwargs): if param.grad is not None: padded_grad = self._pad_to_world_size(param.grad, self.local_world_size) - if self.coalesce_cc: + if self.coalesce_cc_reduce_scatter: padded_grads.append(padded_grad) else: grad_shard = xm.reduce_scatter( @@ -298,7 +301,7 @@ def step(self, closure=None, **kwargs): grad_shard = grad_shard.to(dtype=self.optimizer_dtype) shard.grad = grad_shard - if self.coalesce_cc: + if self.coalesce_cc_reduce_scatter: grad_shards = xm.reduce_scatter_bucketized( xm.REDUCE_SUM, padded_grads, @@ -307,7 +310,7 @@ def step(self, closure=None, **kwargs): shard_count=self.local_world_size, pin_layout=self.pin_layout, groups=self.sharding_groups, - bucket_cap_mb=self.bucket_cap_mb, + bucket_cap_mb=self.bucket_cap_mb_reduce_scatter, ) index = 0 for param_group, sharded_param_group in zip( @@ -341,7 +344,7 @@ def step(self, closure=None, **kwargs): shard_data = shard.data if param.dtype != self.optimizer_dtype: shard_data = shard_data.to(dtype=param.dtype) - if self.coalesce_cc: + if self.coalesce_cc_all_gather: sharded_data.append(shard_data) else: padded_param = xm.all_gather( @@ -352,13 +355,13 @@ def step(self, closure=None, **kwargs): ) param.data.copy_(padded_param.data[:param.size(0)]) - if self.coalesce_cc: + if self.coalesce_cc_all_gather: padded_params = xm.all_gather_bucketized( sharded_data, dim=0, pin_layout=self.pin_layout, groups=self.sharding_groups, - bucket_cap_mb=self.bucket_cap_mb, + 
bucket_cap_mb=self.bucket_cap_mb_all_gather, ) index = 0 for param_group, sharded_param_group in zip(
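For reference, a hypothetical way to enable the new per-collective bucket caps when constructing the ZeRO-1 optimizer; everything other than the two bucket_cap_mb_* arguments follows the existing ZeroRedundancyOptimizer API, and the model and hyperparameters are placeholders.

# Hypothetical wiring of the per-collective bucket caps; model, wrapped
# optimizer class and hyperparameters are placeholders for illustration.
import torch
import torch_xla.core.xla_model as xm
from torch_xla.distributed.zero_redundancy_optimizer import ZeroRedundancyOptimizer

model = torch.nn.Linear(1024, 1024).to(xm.xla_device())
optimizer = ZeroRedundancyOptimizer(
    model.parameters(),
    torch.optim.SGD,
    lr=1e-3,
    # The bucketized collectives currently require pin_layout=False.
    pin_layout=False,
    # 0 (the default) keeps the per-tensor all_gather/reduce_scatter path;
    # a positive cap routes step() through the *_bucketized collectives.
    bucket_cap_mb_all_gather=160,
    bucket_cap_mb_reduce_scatter=160,
)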