diff --git a/test/test_mp_all_gather.py b/test/test_mp_all_gather.py
index 5d38158d287..14f46043175 100644
--- a/test/test_mp_all_gather.py
+++ b/test/test_mp_all_gather.py
@@ -95,6 +95,66 @@ def _mp_fn(index):
           file=sys.stderr)
       print(f'[{index}] {cpu_result}', file=sys.stderr)
       sys.exit(1)
+
+    # Testing with a single replica group and tensor list as input (Bucketized)
+    # TODO: add support for list input with pin_layout=True and output=None
+    result_list = xm.all_gather_bucketized(
+        ordinal_tensors, dim=0, pin_layout=False)
+
+    for i, result in enumerate(result_list):
+      cpu_result = result.cpu()
+      expected = i * 1000 + torch.arange(world_size, dtype=torch.float)
+      if not cpu_result.allclose(expected):
+        print(
+            f'xm.all_gather_bucketized() produced wrong reductions for item {i} in result list',
+            file=sys.stderr)
+        print(f'[{index}] {cpu_result}', file=sys.stderr)
+        sys.exit(1)
+
+    # Testing with a single replica group and tensor list as input and output!=None (out-of-place) (Bucketized)
+    # Reuse ordinal_tensors from previous test
+    output_tensors = [
+        torch.zeros([world_size], dtype=torch.float).to(device)
+        for i in range(input_list_size)
+    ]
+    # TODO: add support for list input with pin_layout=True and output!=None
+    result_list = xm.all_gather_bucketized(
+        ordinal_tensors, dim=0, output=output_tensors, pin_layout=False)
+
+    for i, result in enumerate(result_list):
+      cpu_result = result.cpu()
+      expected = i * 1000 + torch.arange(world_size, dtype=torch.float)
+      if not cpu_result.allclose(expected):
+        print(
+            f'xm.all_gather_bucketized() produced wrong reductions for item {i} in result list',
+            file=sys.stderr)
+        print(f'[{index}] {cpu_result}', file=sys.stderr)
+        sys.exit(1)
+
+    # Testing with a single replica group and tensor list as input and output!=None (out-of-place) (Bucketized, zero bucket size)
+    # Reuse ordinal_tensors from previous test
+    output_tensors = [
+        torch.zeros([world_size], dtype=torch.float).to(device)
+        for i in range(input_list_size)
+    ]
+    # TODO: add support for list input with pin_layout=True and output!=None
+    result_list = xm.all_gather_bucketized(
+        ordinal_tensors,
+        dim=0,
+        output=output_tensors,
+        pin_layout=False,
+        bucket_cap_mb=0)
+
+    for i, result in enumerate(result_list):
+      cpu_result = result.cpu()
+      expected = i * 1000 + torch.arange(world_size, dtype=torch.float)
+      if not cpu_result.allclose(expected):
+        print(
+            f'xm.all_gather_bucketized() produced wrong reductions for item {i} in result list',
+            file=sys.stderr)
+        print(f'[{index}] {cpu_result}', file=sys.stderr)
+        sys.exit(1)
+    # TODO: add test for torch.compile when support for list input is ready
   else:
diff --git a/test/test_mp_reduce_scatter.py b/test/test_mp_reduce_scatter.py
index 1ef61d3aa79..07173926091 100644
--- a/test/test_mp_reduce_scatter.py
+++ b/test/test_mp_reduce_scatter.py
@@ -55,6 +55,33 @@ def _mp_fn(index):
 
     xm.rendezvous('test_reduce_scatter_list_input')
 
+    # Testing reduce-scatter with list input (bucketized)
+    rand_list = [
+        torch.rand((32, shard_size * world_size, 32))
+        for _ in range(input_list_size)
+    ]
+    xrand_list = [rand.to(device) for rand in rand_list]
+
+    # TODO: fix the broken case with pin_layout=True
+    res_list = xm.reduce_scatter_bucketized(
+        xm.REDUCE_SUM,
+        xrand_list,
+        scale,
+        scatter_dim,
+        world_size,
+        pin_layout=False)
+
+    for i, res in enumerate(res_list):
+      expected_world = xm.all_reduce(xm.REDUCE_SUM, xrand_list[i], scale)
+      xm.mark_step()
+
+      slice_idx = torch.tensor(
+          list(range(index * shard_size, (index + 1) * shard_size)))
+      expected = expected_world.cpu().index_select(scatter_dim, slice_idx)
+      assert res.cpu().allclose(expected)
+
+    xm.rendezvous('test_reduce_scatter_list_input_bucketized')
+
     # Testing reduce-scatter with list input and output
     output_list = [
         torch.rand((32, shard_size * world_size, 32))
         for _ in range(input_list_size)
     ]
     xoutput_list = [output.to(device) for output in output_list]
@@ -83,6 +110,66 @@ def _mp_fn(index):
       assert res.cpu().allclose(expected)
 
     xm.rendezvous('test_reduce_scatter_list_input_output')
+
+    # Testing reduce-scatter with list input and output (bucketized)
+    output_list = [
+        torch.rand((32, shard_size * world_size, 32))
+        for _ in range(input_list_size)
+    ]
+    xoutput_list = [output.to(device) for output in output_list]
+
+    # TODO: fix the broken case with pin_layout=True
+    res_list = xm.reduce_scatter_bucketized(
+        xm.REDUCE_SUM,
+        xrand_list,
+        scale,
+        scatter_dim,
+        world_size,
+        output=xoutput_list,
+        pin_layout=False)
+
+    assert (xoutput_list == res_list)
+    for i, res in enumerate(xoutput_list):
+      expected_world = xm.all_reduce(xm.REDUCE_SUM, xrand_list[i], scale)
+      xm.mark_step()
+
+      slice_idx = torch.tensor(
+          list(range(index * shard_size, (index + 1) * shard_size)))
+      expected = expected_world.cpu().index_select(scatter_dim, slice_idx)
+      assert res.cpu().allclose(expected)
+
+    xm.rendezvous('test_reduce_scatter_list_input_output_bucketized')
+
+    # Testing reduce-scatter with list input and output (bucketized, zero bucket size)
+    output_list = [
+        torch.rand((32, shard_size * world_size, 32))
+        for _ in range(input_list_size)
+    ]
+    xoutput_list = [output.to(device) for output in output_list]
+
+    # TODO: fix the broken case with pin_layout=True
+    res_list = xm.reduce_scatter_bucketized(
+        xm.REDUCE_SUM,
+        xrand_list,
+        scale,
+        scatter_dim,
+        world_size,
+        output=xoutput_list,
+        bucket_cap_mb=0,
+        pin_layout=False)
+
+    assert (xoutput_list == res_list)
+    for i, res in enumerate(xoutput_list):
+      expected_world = xm.all_reduce(xm.REDUCE_SUM, xrand_list[i], scale)
+      xm.mark_step()
+
+      slice_idx = torch.tensor(
+          list(range(index * shard_size, (index + 1) * shard_size)))
+      expected = expected_world.cpu().index_select(scatter_dim, slice_idx)
+      assert res.cpu().allclose(expected)
+
+    xm.rendezvous(
+        'test_reduce_scatter_list_input_output_bucketized, zero bucket size')
   else:
     print(
         'Default device {} is not a TPU device'.format(device), file=sys.stderr)
diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py
old mode 100755
new mode 100644
index 9b2cd139e88..a5b7e6693ab
--- a/torch_xla/core/xla_model.py
+++ b/torch_xla/core/xla_model.py
@@ -634,6 +634,110 @@ def all_gather(value, dim=0, groups=None, output=None, pin_layout=True):
                     f"given {type(value)}.")
 
 
+class CoalescingBuckets(object):
+
+  def __init__(self, func, input_list, output_list=None, bucket_cap_mb=160):
+    if not isinstance(input_list, list) or any(
+        not isinstance(v, torch.Tensor) for v in input_list):
+      raise TypeError(
+          f"`input_list` needs to be a list of Tensors, but given {type(input_list)}."
+      )
+    if output_list is not None:
+      if not isinstance(output_list, list) or any(
+          not isinstance(v, torch.Tensor) for v in output_list):
+        raise TypeError(
+            f"`output_list` needs to be a list of Tensors, but given {type(output_list)}."
+        )
+      if len(output_list) != len(input_list):
+        raise ValueError(
+            "`output_list` length doesn't match `input_list` length: "
+            f"{len(output_list)} vs {len(input_list)}.")
+    self._func = func
+    self._input_list = input_list
+    self._output_list = output_list
+    self._total = 0
+    self._tensor_bucket = []
+    self._output_bucket = [] if output_list else None
+    self._bucket_cap = bucket_cap_mb * 1024 * 1024
+    self._out_tensors = []
+
+  def flush(self):
+    if len(self._tensor_bucket) == 1:
+      # Use the non-coalesced CC op if it's just one tensor
+      output = self._output_bucket[0] if self._output_bucket else None
+      self._out_tensors.append(self._func(self._tensor_bucket[0], output))
+    elif len(self._tensor_bucket):
+      self._out_tensors.extend(
+          self._func(self._tensor_bucket, self._output_bucket))
+    self._total = 0
+    self._tensor_bucket = []
+    self._output_bucket = [] if self._output_list else None
+
+  def add(self, tensor, idx):
+    self._total += tensor.numel() * tensor.element_size()
+    self._tensor_bucket.append(tensor)
+    if self._output_list is not None:
+      self._output_bucket.append(self._output_list[idx])
+
+  def __call__(self):
+    for idx, tensor in enumerate(self._input_list):
+      tensor_bytes = tensor.numel() * tensor.element_size()
+
+      # Aim for the target bucket_cap_mb: an oversized tensor is still combined
+      # with the current bucket when the bucket is small (< 1/2 cap) and the
+      # combined total stays under 2x cap; the bucket is flushed right after.
+      total_new = self._total + tensor_bytes
+      if tensor_bytes > self._bucket_cap and self._total < 0.5 * self._bucket_cap and total_new <= 2 * self._bucket_cap:
+        self.add(tensor, idx)
+        self.flush()
+      else:
+        # Otherwise bucketize until the total spills over the cap
+        if total_new > self._bucket_cap:
+          self.flush()
+        self.add(tensor, idx)
+
+    # Flush the last remaining bucket
+    self.flush()
+
+    assert len(self._out_tensors) == len(self._input_list)
+
+    return self._out_tensors
+
+
+def all_gather_bucketized(input_list,
+                          dim=0,
+                          groups=None,
+                          output=None,
+                          pin_layout=False,
+                          bucket_cap_mb=160):
+  """Performs an all-gather operation along a given dimension, with bucketization.
+
+  Args:
+    input_list: List of input tensors.
+    bucket_cap_mb: Number of megabytes of the tensor bucket to fill before doing the all-gather.
+    See all_gather for the remaining args: dim, groups, output, pin_layout.
+
+  Returns:
+    A list of tensors each of which has, in the ``dim`` dimension, all the values from the
+    participating replicas.
+  """
+  # sanity checks
+  if pin_layout:
+    raise RuntimeError(
+        "For xm.all_gather_bucketized, pin_layout=True is not yet supported.")
+
+  def _all_gather_coalesced(_input_list, _output_list=None):
+    return all_gather(
+        value=_input_list,
+        dim=dim,
+        groups=groups,
+        output=_output_list,
+        pin_layout=pin_layout)
+
+  buckets = CoalescingBuckets(
+      _all_gather_coalesced, input_list, output, bucket_cap_mb=bucket_cap_mb)
+  return buckets()
+
+
 def all_to_all(value,
                split_dimension,
                concat_dimension,
@@ -847,6 +951,50 @@ def reduce_scatter(reduce_type,
                     f"given {type(input)}.")
 
 
+def reduce_scatter_bucketized(reduce_type,
+                              input_list,
+                              scale,
+                              scatter_dim,
+                              shard_count,
+                              groups=None,
+                              output=None,
+                              pin_layout=False,
+                              bucket_cap_mb=160):
+  """Performs an XLA `ReduceScatter()` operation on a list of tensors (bucketized).
+
+  See: https://www.tensorflow.org/xla/operation_semantics#reducescatter
+
+  Args:
+    input_list: List of input tensors.
+    output: Optional list of output torch.Tensors.
+    bucket_cap_mb: Number of megabytes of the tensor bucket to fill before doing the reduce-scatter.
+    See reduce_scatter for the remaining args: reduce_type, scale, scatter_dim, shard_count, groups, pin_layout.
+
+  Returns:
+    A list of `torch.Tensors` with all the values reduced across replicas. Each process
+    gets a shard split along the `scatter_dim`. All other dimensions are
+    the same as the input.
+  """
+
+  def _reduce_scatter_coalesced(_input_list, _output_list=None):
+    return reduce_scatter(
+        reduce_type=reduce_type,
+        input=_input_list,
+        scale=scale,
+        scatter_dim=scatter_dim,
+        shard_count=shard_count,
+        groups=groups,
+        output=_output_list,
+        pin_layout=pin_layout)
+
+  buckets = CoalescingBuckets(
+      _reduce_scatter_coalesced,
+      input_list,
+      output,
+      bucket_cap_mb=bucket_cap_mb)
+  return buckets()
+
+
 def add_step_closure(closure, args=(), run_async=False):
   """Adds a closure to the list of the ones to be run at the end of the step.
diff --git a/torch_xla/distributed/zero_redundancy_optimizer.py b/torch_xla/distributed/zero_redundancy_optimizer.py
index f00929eeb86..9b21fe4ead8 100644
--- a/torch_xla/distributed/zero_redundancy_optimizer.py
+++ b/torch_xla/distributed/zero_redundancy_optimizer.py
@@ -40,6 +40,9 @@ class ZeroRedundancyOptimizer(Optimizer):
            If specified, ZeRO-1 will use this ``grad_norm_groups`` for the
            EXTRA all-reduce op in grad norm calculation. This can be model
            parallel groups when mixing ZeRO-1 with model parallelism such as
            Megatron.
+        bucket_cap_mb_all_gather, bucket_cap_mb_reduce_scatter:
+            If non-zero, the maximum number of megabytes of tensors to combine into a
+            bucket before doing the corresponding all-gather/reduce-scatter operation.
         **defaults: any trailing arguments, which are forwarded to the local
            optimizer.
@@ -60,6 +63,8 @@ def __init__(
       sharding_groups: Optional[Any] = None,
       grad_norm_groups: Optional[Any] = None,
       lazy_init: bool = False,
+      bucket_cap_mb_all_gather: int = 0,
+      bucket_cap_mb_reduce_scatter: int = 0,
       **defaults: Any,
   ):
     super().__init__(params, defaults)
@@ -76,6 +81,12 @@ def __init__(
     self.grad_clipping = grad_clipping
     self.max_norm = max_norm if max_norm is not None else 1.0
     self.pin_layout = pin_layout
+    self.bucket_cap_mb_all_gather = bucket_cap_mb_all_gather
+    self.bucket_cap_mb_reduce_scatter = bucket_cap_mb_reduce_scatter
+    self.coalesce_cc_all_gather = bucket_cap_mb_all_gather > 0
+    self.coalesce_cc_reduce_scatter = bucket_cap_mb_reduce_scatter > 0
+
+    self._grad_norm = None
 
     self.inited = False
     if not lazy_init:
@@ -102,6 +113,10 @@ def init_zero(self):
     self._sync_param_groups(self.param_groups, self.base_optimizer.param_groups)
     self.inited = True
 
+  @property
+  def grad_norm(self):
+    return self._grad_norm
+
   @property
   def sharding_groups(self):
     return self._sharding_groups
@@ -158,12 +173,17 @@ def _shard_parameters(self):
     """
     Shard all parameters.
     """
+    self.device = None
     all_params = []
     for param_group in self.param_groups:
       for param in param_group['params']:
         all_params.append(param)
+        if self.device is None:
+          self.device = param.device
+        else:
+          assert self.device == param.device, "Params should be on the same device."
+    assert self.device.type == 'xla'
 
-    self.device = all_params[0].device
     xm.unlazy(all_params)
 
     sharded_params_groups = []
@@ -227,13 +247,14 @@ def _clip_grad_norm(
     """
     max_norm = float(max_norm)
     norm_type = float(norm_type)
-    total_norm = self._calc_grad_norm(norm_type)
+    self._grad_norm = self._calc_grad_norm(norm_type)
 
     clip_coeff = torch.tensor(
-        max_norm, device=self.device) / (
-            total_norm + 1e-6)
-    clip_value = torch.where(clip_coeff < 1, clip_coeff,
-                             torch.tensor(1., device=self.device))
+        max_norm, device=self.device, dtype=self.optimizer_dtype) / (
+            self._grad_norm + 1e-6)
+    clip_value = torch.where(
+        clip_coeff < 1, clip_coeff,
+        torch.tensor(1., device=self.device, dtype=self.optimizer_dtype))
     for param_group in self.base_optimizer.param_groups:
       for p in param_group['params']:
         if p.grad is not None:
@@ -256,6 +277,7 @@ def step(self, closure=None, **kwargs):
 
     # Reduce full gradients across ranks
     # Assign gradient shards to the respective parameter shards
+    padded_grads = []
     for param_group, sharded_param_group in zip(
         self.param_groups, self.base_optimizer.param_groups):
       for param, shard in zip(param_group['params'],
@@ -263,19 +285,44 @@ def step(self, closure=None, **kwargs):
         if param.grad is not None:
           padded_grad = self._pad_to_world_size(param.grad,
                                                 self.local_world_size)
-          grad_shard = xm.reduce_scatter(
-              xm.REDUCE_SUM,
-              padded_grad,
-              scale=1.0 / self.local_world_size,
-              scatter_dim=0,
-              shard_count=self.local_world_size,
-              pin_layout=self.pin_layout,
-              groups=self.sharding_groups,
-          )
-
-          if grad_shard.dtype != self.optimizer_dtype:
-            grad_shard = grad_shard.to(dtype=self.optimizer_dtype)
-          shard.grad = grad_shard
+          if self.coalesce_cc_reduce_scatter:
+            padded_grads.append(padded_grad)
+          else:
+            grad_shard = xm.reduce_scatter(
+                xm.REDUCE_SUM,
+                padded_grad,
+                scale=1.0 / self.local_world_size,
+                scatter_dim=0,
+                shard_count=self.local_world_size,
+                pin_layout=self.pin_layout,
+                groups=self.sharding_groups,
+            )
+            if grad_shard.dtype != self.optimizer_dtype:
+              grad_shard = grad_shard.to(dtype=self.optimizer_dtype)
+            shard.grad = grad_shard
+
+    if self.coalesce_cc_reduce_scatter:
+      grad_shards = xm.reduce_scatter_bucketized(
+          xm.REDUCE_SUM,
+          padded_grads,
+          scale=1.0 / self.local_world_size,
+          scatter_dim=0,
+          shard_count=self.local_world_size,
+          pin_layout=self.pin_layout,
+          groups=self.sharding_groups,
+          bucket_cap_mb=self.bucket_cap_mb_reduce_scatter,
+      )
+      index = 0
+      for param_group, sharded_param_group in zip(
+          self.param_groups, self.base_optimizer.param_groups):
+        for param, shard in zip(param_group['params'],
+                                sharded_param_group['params']):
+          if param.grad is not None:
+            grad_shard = grad_shards[index]
+            if grad_shard.dtype != self.optimizer_dtype:
+              grad_shard = grad_shard.to(dtype=self.optimizer_dtype)
+            shard.grad = grad_shard
+            index += 1
 
     if self.grad_clipping:
       # Update unscale/clip with sub partitions
@@ -288,6 +335,7 @@ def step(self, closure=None, **kwargs):
       self.base_optimizer.zero_grad(set_to_none=True)
 
     # All gather the new weights across the ranks and assign them to the full parameters
+    sharded_data = []
     for param_group, sharded_param_group in zip(
         self.param_groups, self.base_optimizer.param_groups):
       for param, shard in zip(param_group['params'],
@@ -296,13 +344,34 @@ def step(self, closure=None, **kwargs):
         shard_data = shard.data
         if param.dtype != self.optimizer_dtype:
           shard_data = shard_data.to(dtype=param.dtype)
-        padded_param = xm.all_gather(
-            shard_data,
-            dim=0,
-            pin_layout=self.pin_layout,
-            groups=self.sharding_groups,
-        )
-        param.data.copy_(padded_param.data[:param.size(0)])
+        if self.coalesce_cc_all_gather:
+          sharded_data.append(shard_data)
+        else:
+          padded_param = xm.all_gather(
+              shard_data,
+              dim=0,
+              pin_layout=self.pin_layout,
+              groups=self.sharding_groups,
+          )
+          param.data.copy_(padded_param.data[:param.size(0)])
+
+    if self.coalesce_cc_all_gather:
+      padded_params = xm.all_gather_bucketized(
+          sharded_data,
+          dim=0,
+          pin_layout=self.pin_layout,
+          groups=self.sharding_groups,
+          bucket_cap_mb=self.bucket_cap_mb_all_gather,
+      )
+      index = 0
+      for param_group, sharded_param_group in zip(
+          self.param_groups, self.base_optimizer.param_groups):
+        for param, shard in zip(param_group['params'],
+                                sharded_param_group['params']):
+          if param.grad is not None:
+            padded_param = padded_params[index]
+            param.data.copy_(padded_param.data[:param.size(0)])
+            index += 1
 
     # sync back
     self._sync_param_groups(self.base_optimizer.param_groups, self.param_groups)
@@ -313,6 +382,7 @@ def state_dict(self):
     state_dict = super().state_dict()
     base_state = self.base_optimizer.state_dict()['state']
     state_dict['base_state'] = base_state
+    state_dict['shape_info'] = self.get_shape_info()
     return state_dict
 
   def load_state_dict(self, state_dict):
@@ -326,3 +396,12 @@ def load_state_dict(self, state_dict):
     tmp = self.base_optimizer.state_dict()
     tmp['state'] = base_state
     self.base_optimizer.load_state_dict(tmp)
+
+  def get_shape_info(self):
+    shape_info = {}
+    idx = 0
+    for param_group in self.param_groups:
+      for param in param_group['params']:
+        shape_info[idx] = param.shape
+        idx += 1
+    return shape_info
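
Usage sketch (illustrative only, not part of the patch): the snippet below shows how the bucketized collectives and the new ZeRO-1 bucket caps added above might be exercised. The `torch.optim.SGD` local-optimizer argument and `xm.xrt_world_size()` come from the existing torch_xla / ZeRO-1 interfaces and are assumed here; non-zero bucket caps enable coalescing, and pin_layout must be False on the bucketized paths.

import torch
import torch_xla.core.xla_model as xm
from torch_xla.distributed.zero_redundancy_optimizer import ZeroRedundancyOptimizer

device = xm.xla_device()
world_size = xm.xrt_world_size()  # assumed: existing helper in xla_model

# Several small all-gathers coalesced into ~100MB buckets; one gathered
# tensor is returned per input tensor.
inputs = [torch.ones(128, device=device) * i for i in range(8)]
gathered = xm.all_gather_bucketized(
    inputs, dim=0, pin_layout=False, bucket_cap_mb=100)

# Same idea for reduce-scatter: each element of `reduced` is this replica's
# shard of the corresponding summed tensor.
grads = [torch.rand(16, 128 * world_size, device=device) for _ in range(8)]
reduced = xm.reduce_scatter_bucketized(
    xm.REDUCE_SUM,
    grads,
    scale=1.0 / world_size,
    scatter_dim=1,
    shard_count=world_size,
    pin_layout=False,
    bucket_cap_mb=100)

# ZeRO-1 with coalesced collectives: non-zero caps turn bucketing on, and
# pin_layout=False is required by the bucketized code paths.
model = torch.nn.Linear(1024, 1024).to(device)
optimizer = ZeroRedundancyOptimizer(
    model.parameters(),
    torch.optim.SGD,  # assumed: local optimizer class, as in existing ZeRO-1 usage
    lr=0.01,
    pin_layout=False,
    bucket_cap_mb_all_gather=100,
    bucket_cap_mb_reduce_scatter=100)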