diff --git a/torch_xla/distributed/zero_redundancy_optimizer.py b/torch_xla/distributed/zero_redundancy_optimizer.py
index f00929eeb86b..cd6b72c9d80e 100644
--- a/torch_xla/distributed/zero_redundancy_optimizer.py
+++ b/torch_xla/distributed/zero_redundancy_optimizer.py
@@ -1,4 +1,6 @@
 import copy
+import inspect
+import os
 from typing import (Any, Iterator, Optional, Type, Union, List, Dict)
 
 import torch
@@ -77,6 +79,8 @@ def __init__(
     self.max_norm = max_norm if max_norm is not None else 1.0
     self.pin_layout = pin_layout
 
+    self._grad_norm = None
+
     self.inited = False
     if not lazy_init:
       self.init_zero()
@@ -91,17 +95,36 @@ def init_zero(self):
          group = list(group)
        self.local_rank = group.index(self.global_rank)
    if self.local_rank is None:
-      raise ValueError(
-          f"Current rank {self.global_rank} is missing from the sharding_groups {self.sharding_groups}"
-      )
+      raise ValueError(f"Current rank {self.global_rank} is missing from the sharding_groups {self.sharding_groups}")
 
     # Shard parameters for use in optimizer
     sharded_param_groups = self._shard_parameters()
 
     # Optimizer initialization
+    # Here we pop the differentiable default because the adam family of
+    # optimizers doesn't accept differentiable as an argument. This should
+    # be fixed by https://github.com/pytorch/pytorch/pull/86183 and should
+    # be available in torch==2.0. For 1.13, we are patching it here.
+    # When we re-init after loading weights, the defaults would be set by the
+    # optimizer base class, which would break the adamw/adam initialization.
+    # Hence, we pop the argument if the optimizer class doesn't accept one.
+    # Assumption: if the optimizer class didn't have one, the base class added
+    # it when loading the state dict.
+    func_args = inspect.signature(self.optimizer_class.__init__)
+    if "differentiable" not in func_args.parameters:
+      before_differentiable = self.defaults.pop("differentiable", None)
     self.base_optimizer = self.optimizer_class(sharded_param_groups, **self.defaults)
+    differentiable_value = getattr(self.base_optimizer, "differentiable", None)
+    if differentiable_value is not None:
+      assert before_differentiable == differentiable_value, \
+          "differentiable argument changes value after initialization"
+      self.defaults["differentiable"] = differentiable_value
     self._sync_param_groups(self.param_groups, self.base_optimizer.param_groups)
     self.inited = True
 
+  @property
+  def grad_norm(self):
+    return self._grad_norm
+
   @property
   def sharding_groups(self):
     return self._sharding_groups
@@ -158,12 +181,17 @@ def _shard_parameters(self):
     """
     Shard all parameters.
     """
+    self.device = None
     all_params = []
     for param_group in self.param_groups:
       for param in param_group['params']:
         all_params.append(param)
+        if self.device is None:
+          self.device = param.device
+        else:
+          assert self.device == param.device, "Params should be on the same device."
+    assert self.device.type == 'xla'
 
-    self.device = all_params[0].device
     xm.unlazy(all_params)
 
     sharded_params_groups = []
@@ -227,13 +255,13 @@ def _clip_grad_norm(
     """
     max_norm = float(max_norm)
     norm_type = float(norm_type)
-    total_norm = self._calc_grad_norm(norm_type)
+    self._grad_norm = self._calc_grad_norm(norm_type)
     clip_coeff = torch.tensor(
-        max_norm, device=self.device) / (
-            total_norm + 1e-6)
-    clip_value = torch.where(clip_coeff < 1, clip_coeff,
-                             torch.tensor(1., device=self.device))
+        max_norm, device=self.device, dtype=self.optimizer_dtype) / (
+            self._grad_norm + 1e-6)
+    clip_value = torch.where(clip_coeff < 1, clip_coeff,
+                             torch.tensor(1., device=self.device, dtype=self.optimizer_dtype))
     for param_group in self.base_optimizer.param_groups:
       for p in param_group['params']:
         if p.grad is not None:
@@ -256,6 +284,7 @@ def step(self, closure=None, **kwargs):
 
     # Reduce full gradients across ranks
     # Assign gradient shards to the respective parameter shards
+    padded_grads = []
     for param_group, sharded_param_group in zip(
         self.param_groups, self.base_optimizer.param_groups):
       for param, shard in zip(param_group['params'],
@@ -263,19 +292,28 @@ def step(self, closure=None, **kwargs):
                               sharded_param_group['params']):
         if param.grad is not None:
           padded_grad = self._pad_to_world_size(param.grad,
                                                 self.local_world_size)
-          grad_shard = xm.reduce_scatter(
-              xm.REDUCE_SUM,
-              padded_grad,
-              scale=1.0 / self.local_world_size,
-              scatter_dim=0,
-              shard_count=self.local_world_size,
-              pin_layout=self.pin_layout,
-              groups=self.sharding_groups,
-          )
-
+          padded_grads.append(padded_grad)
+
+    grad_shards = xm.reduce_scatter(
+        xm.REDUCE_SUM,
+        padded_grads,
+        scale=1.0 / self.local_world_size,
+        scatter_dim=0,
+        shard_count=self.local_world_size,
+        pin_layout=self.pin_layout,
+        groups=self.sharding_groups,
+    )
+    index = 0
+    for param_group, sharded_param_group in zip(
+        self.param_groups, self.base_optimizer.param_groups):
+      for param, shard in zip(param_group['params'],
+                              sharded_param_group['params']):
+        if param.grad is not None:
+          grad_shard = grad_shards[index]
           if grad_shard.dtype != self.optimizer_dtype:
             grad_shard = grad_shard.to(dtype=self.optimizer_dtype)
           shard.grad = grad_shard
+          index += 1
 
     if self.grad_clipping:
       # Update unscale/clip with sub partitions
@@ -288,6 +326,7 @@ def step(self, closure=None, **kwargs):
     self.base_optimizer.zero_grad(set_to_none=True)
 
     # All gather the new weights across the ranks and assign them to the full parameters
+    sharded_data = []
     for param_group, sharded_param_group in zip(
         self.param_groups, self.base_optimizer.param_groups):
       for param, shard in zip(param_group['params'],
@@ -296,13 +335,23 @@ def step(self, closure=None, **kwargs):
                               sharded_param_group['params']):
         if param.grad is not None:
           shard_data = shard.data
           if param.dtype != self.optimizer_dtype:
             shard_data = shard_data.to(dtype=param.dtype)
-          padded_param = xm.all_gather(
-              shard_data,
-              dim=0,
-              pin_layout=self.pin_layout,
-              groups=self.sharding_groups,
-          )
+          sharded_data.append(shard_data)
+
+    padded_params = xm.all_gather(
+        sharded_data,
+        dim=0,
+        pin_layout=self.pin_layout,
+        groups=self.sharding_groups,
+    )
+    index = 0
+    for param_group, sharded_param_group in zip(
+        self.param_groups, self.base_optimizer.param_groups):
+      for param, shard in zip(param_group['params'],
+                              sharded_param_group['params']):
+        if param.grad is not None:
+          padded_param = padded_params[index]
           param.data.copy_(padded_param.data[:param.size(0)])
+          index += 1
 
     # sync back
     self._sync_param_groups(self.base_optimizer.param_groups, self.param_groups)
@@ -313,6 +362,7 @@ def state_dict(self):
     state_dict = super().state_dict()
     base_state = self.base_optimizer.state_dict()['state']
     state_dict['base_state'] = base_state
+    state_dict['shape_info'] = self.get_shape_info()
     return state_dict
 
   def load_state_dict(self, state_dict):
@@ -326,3 +376,12 @@ def load_state_dict(self, state_dict):
     tmp = self.base_optimizer.state_dict()
     tmp['state'] = base_state
     self.base_optimizer.load_state_dict(tmp)
+
+  def get_shape_info(self):
+    shape_info = {}
+    idx = 0
+    for param_group in self.param_groups:
+      for param in param_group['params']:
+        shape_info[idx] = param.shape
+        idx += 1
+    return shape_info
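
Note: the differentiable handling added to init_zero() reduces to the standalone sketch below. It is illustrative only; build_base_optimizer is a hypothetical helper, not something this patch adds. The point is that torch 1.13 can inject a differentiable key into the wrapped optimizer's defaults when a state dict is reloaded, while e.g. torch.optim.Adam does not accept that keyword until https://github.com/pytorch/pytorch/pull/86183 lands in torch 2.0, so the key is dropped whenever the constructor signature lacks it.

import inspect

import torch


def build_base_optimizer(optimizer_class, params, defaults):
  # Drop 'differentiable' when the wrapped optimizer's __init__ does not
  # accept it (the Adam family on torch 1.13); otherwise pass it through.
  if "differentiable" not in inspect.signature(
      optimizer_class.__init__).parameters:
    defaults = {k: v for k, v in defaults.items() if k != "differentiable"}
  return optimizer_class(params, **defaults)


params = [torch.zeros(4, requires_grad=True)]
opt = build_base_optimizer(torch.optim.Adam, params,
                           {"lr": 1e-3, "differentiable": False})
print(type(opt).__name__, opt.defaults.get("differentiable"))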
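
The two user-visible additions, the grad_norm property and the shape_info entry in state_dict(), can be exercised as in the sketch below. This is an assumption-laden usage example, not part of the patch: the grad_clipping and max_norm keywords are inferred from the attributes the diff touches, an XLA device is assumed to be available, and grad_norm is only populated after a step in which gradient clipping has run.

import torch
import torch_xla.core.xla_model as xm
from torch_xla.distributed.zero_redundancy_optimizer import ZeroRedundancyOptimizer

device = xm.xla_device()
model = torch.nn.Linear(16, 4).to(device)
# grad_clipping/max_norm are assumed constructor kwargs matching the
# attributes used in _clip_grad_norm above.
opt = ZeroRedundancyOptimizer(
    model.parameters(), torch.optim.AdamW, lr=1e-3,
    grad_clipping=True, max_norm=1.0)

loss = model(torch.randn(8, 16, device=device)).sum()
loss.backward()
opt.step()  # reduce-scatter grads, step the shards, all-gather updated params
xm.mark_step()

print(opt.grad_norm)  # pre-clipping gradient norm recorded by _clip_grad_norm
print(opt.state_dict()['shape_info'])  # {param index: torch.Size, ...}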