Commit

Fixing bug
amithrm committed Feb 2, 2024
1 parent 0aee211 commit 5a61ef1
Showing 1 changed file with 84 additions and 25 deletions.
109 changes: 84 additions & 25 deletions torch_xla/distributed/zero_redundancy_optimizer.py
@@ -1,4 +1,6 @@
import copy
import inspect
import os
from typing import (Any, Iterator, Optional, Type, Union, List, Dict)

import torch
@@ -77,6 +79,8 @@ def __init__(
    self.max_norm = max_norm if max_norm is not None else 1.0
    self.pin_layout = pin_layout

    self._grad_norm = None

    self.inited = False
    if not lazy_init:
      self.init_zero()
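
For context, a minimal usage sketch of this wrapper. It assumes the class exported by this module is named ZeroRedundancyOptimizer and that the keyword names (optimizer_class as the second positional argument, optimizer_dtype, grad_clipping, max_norm, lazy_init) mirror the attributes set in __init__; the exact public signature may differ.

import torch
import torch_xla.core.xla_model as xm
from torch_xla.distributed.zero_redundancy_optimizer import ZeroRedundancyOptimizer

model = torch.nn.Linear(1024, 1024).to(xm.xla_device())
opt = ZeroRedundancyOptimizer(
    model.parameters(),
    torch.optim.AdamW,              # optimizer_class to shard
    optimizer_dtype=torch.float32,  # dtype used for the sharded optimizer state
    grad_clipping=True,
    max_norm=1.0,
    lazy_init=True,                 # defer sharding until init_zero() is called
    lr=1e-3)                        # remaining kwargs become the wrapped optimizer's defaults
opt.init_zero()                     # explicit init when lazy_init=True
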
@@ -91,17 +95,36 @@ def init_zero(self):
        group = list(group)
        self.local_rank = group.index(self.global_rank)
    if self.local_rank is None:
      raise ValueError(
          f"Current rank {self.global_rank} is missing from the sharding_groups {self.sharding_groups}"
      )
      raise ValueError(f"Current rank {self.global_rank} is missing from the sharding_groups {self.sharding_groups}")
    # Shard parameters for use in optimizer
    sharded_param_groups = self._shard_parameters()
    # Optimizer initialization
    # Here we pop the differentiable default because the Adam family of
    # optimizers does not accept differentiable as an argument. This is
    # fixed by https://github.com/pytorch/pytorch/pull/86183 and should be
    # available in torch==2.0; for 1.13 we patch it here.
    # When we re-init after loading weights, the defaults would be set by the
    # optimizer base class, which would break the Adam/AdamW initialization.
    # Hence, we pop the argument if the optimizer class doesn't accept it.
    # Assumption: if the optimizer class didn't have the argument, the base
    # class added it when loading the state dict.
    before_differentiable = None
    func_args = inspect.signature(self.optimizer_class.__init__)
    if "differentiable" not in func_args.parameters:
      before_differentiable = self.defaults.pop("differentiable", None)
    self.base_optimizer = self.optimizer_class(sharded_param_groups,
                                               **self.defaults)
    differentiable_value = getattr(self.base_optimizer, "differentiable", None)
    if differentiable_value is not None:
      assert before_differentiable == differentiable_value, \
          "differentiable argument changes value after initialization"
      self.defaults["differentiable"] = differentiable_value
    self._sync_param_groups(self.param_groups, self.base_optimizer.param_groups)
    self.inited = True

  @property
  def grad_norm(self):
    return self._grad_norm

  @property
  def sharding_groups(self):
    return self._sharding_groups
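
A standalone sketch of the signature check used in init_zero above: inspect.signature lets the wrapper decide whether the wrapped optimizer's constructor declares differentiable before handing it the defaults. Adam is only an illustration here; whether the key gets popped depends on the installed torch version.

import inspect
import torch

def constructor_accepts(optimizer_class, name):
  # True if `name` is an explicit parameter of the optimizer's __init__.
  return name in inspect.signature(optimizer_class.__init__).parameters

defaults = {"lr": 1e-3, "differentiable": False}
if not constructor_accepts(torch.optim.Adam, "differentiable"):
  # Older Adam signatures reject the kwarg, so drop it before construction.
  defaults.pop("differentiable", None)
optimizer = torch.optim.Adam(torch.nn.Linear(4, 4).parameters(), **defaults)
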
@@ -158,12 +181,17 @@ def _shard_parameters(self):
    """
    Shard all parameters.
    """
    self.device = None
    all_params = []
    for param_group in self.param_groups:
      for param in param_group['params']:
        all_params.append(param)
        if self.device is None:
          self.device = param.device
        else:
          assert self.device == param.device, "Params should be on the same device."
    assert self.device.type == 'xla'

    self.device = all_params[0].device
    xm.unlazy(all_params)

    sharded_params_groups = []
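
The sharding relies on each tensor being padded along dim 0 so it splits evenly across the shard group. A sketch of that idea; pad_to_multiple is a hypothetical stand-in for the _pad_to_world_size helper referenced in step() below.

import torch
import torch.nn.functional as F

def pad_to_multiple(tensor, world_size):
  # Zero-pad dim 0 so the tensor divides evenly into `world_size` shards.
  remainder = tensor.size(0) % world_size
  if remainder == 0:
    return tensor
  pad = [0, 0] * (tensor.dim() - 1) + [0, world_size - remainder]
  return F.pad(tensor, pad)

grad = torch.randn(10, 4)
padded = pad_to_multiple(grad, world_size=4)  # shape becomes (12, 4)
local_shard = padded.chunk(4, dim=0)[1]       # the slice rank 1 would keep
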
@@ -227,13 +255,13 @@ def _clip_grad_norm(
    """
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    total_norm = self._calc_grad_norm(norm_type)
    self._grad_norm = self._calc_grad_norm(norm_type)

    clip_coeff = torch.tensor(
        max_norm, device=self.device) / (
            total_norm + 1e-6)
    clip_value = torch.where(clip_coeff < 1, clip_coeff,
                             torch.tensor(1., device=self.device))
        max_norm, device=self.device, dtype=self.optimizer_dtype) / (
            self._grad_norm + 1e-6)
    clip_value = torch.where(clip_coeff < 1, clip_coeff,
                             torch.tensor(1., device=self.device, dtype=self.optimizer_dtype))
    for param_group in self.base_optimizer.param_groups:
      for p in param_group['params']:
        if p.grad is not None:
@@ -256,26 +284,36 @@ def step(self, closure=None, **kwargs):

    # Reduce full gradients across ranks
    # Assign gradient shards to the respective parameter shards
    padded_grads = []
    for param_group, sharded_param_group in zip(
        self.param_groups, self.base_optimizer.param_groups):
      for param, shard in zip(param_group['params'],
                              sharded_param_group['params']):
        if param.grad is not None:
          padded_grad = self._pad_to_world_size(param.grad,
                                                self.local_world_size)
          grad_shard = xm.reduce_scatter(
              xm.REDUCE_SUM,
              padded_grad,
              scale=1.0 / self.local_world_size,
              scatter_dim=0,
              shard_count=self.local_world_size,
              pin_layout=self.pin_layout,
              groups=self.sharding_groups,
          )

          padded_grads.append(padded_grad)

    grad_shards = xm.reduce_scatter(
        xm.REDUCE_SUM,
        padded_grads,
        scale=1.0 / self.local_world_size,
        scatter_dim=0,
        shard_count=self.local_world_size,
        pin_layout=self.pin_layout,
        groups=self.sharding_groups,
    )
    index = 0
    for param_group, sharded_param_group in zip(
        self.param_groups, self.base_optimizer.param_groups):
      for param, shard in zip(param_group['params'],
                              sharded_param_group['params']):
        if param.grad is not None:
          grad_shard = grad_shards[index]
          if grad_shard.dtype != self.optimizer_dtype:
            grad_shard = grad_shard.to(dtype=self.optimizer_dtype)
          shard.grad = grad_shard
          index += 1

    if self.grad_clipping:
      # Update unscale/clip with sub partitions
@@ -288,6 +326,7 @@ def step(self, closure=None, **kwargs):
    self.base_optimizer.zero_grad(set_to_none=True)

    # All gather the new weights across the ranks and assign them to the full parameters
    sharded_data = []
    for param_group, sharded_param_group in zip(
        self.param_groups, self.base_optimizer.param_groups):
      for param, shard in zip(param_group['params'],
@@ -296,13 +335,23 @@
          shard_data = shard.data
          if param.dtype != self.optimizer_dtype:
            shard_data = shard_data.to(dtype=param.dtype)
          padded_param = xm.all_gather(
              shard_data,
              dim=0,
              pin_layout=self.pin_layout,
              groups=self.sharding_groups,
          )
          sharded_data.append(shard_data)

    padded_params = xm.all_gather(
        sharded_data,
        dim=0,
        pin_layout=self.pin_layout,
        groups=self.sharding_groups,
    )
    index = 0
    for param_group, sharded_param_group in zip(
        self.param_groups, self.base_optimizer.param_groups):
      for param, shard in zip(param_group['params'],
                              sharded_param_group['params']):
        if param.grad is not None:
          padded_param = padded_params[index]
          param.data.copy_(padded_param.data[:param.size(0)])
          index += 1

    # sync back
    self._sync_param_groups(self.base_optimizer.param_groups, self.param_groups)
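
To make the clipping change above concrete: _clip_grad_norm now stores the computed norm in self._grad_norm (exposed through the grad_norm property for logging after step) and scales gradients by min(1, max_norm / (norm + 1e-6)), with the tensors created in optimizer_dtype so the math matches the sharded gradients. A dtype-free sketch of that scale factor, with illustrative numbers:

import torch

max_norm = 1.0
grad_norm = torch.tensor(2.5)   # stand-in for self._calc_grad_norm(norm_type)
clip_value = torch.clamp(max_norm / (grad_norm + 1e-6), max=1.0)
assert abs(clip_value.item() - 0.4) < 1e-5   # gradients get scaled by ~0.4

grad_norm = torch.tensor(0.3)   # norms below max_norm are left untouched
clip_value = torch.clamp(max_norm / (grad_norm + 1e-6), max=1.0)
assert clip_value.item() == 1.0
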
@@ -313,6 +362,7 @@ def state_dict(self):
    state_dict = super().state_dict()
    base_state = self.base_optimizer.state_dict()['state']
    state_dict['base_state'] = base_state
    state_dict['shape_info'] = self.get_shape_info()
    return state_dict

  def load_state_dict(self, state_dict):
@@ -326,3 +376,12 @@ def load_state_dict(self, state_dict):
    tmp = self.base_optimizer.state_dict()
    tmp['state'] = base_state
    self.base_optimizer.load_state_dict(tmp)

  def get_shape_info(self):
    shape_info = {}
    idx = 0
    for param_group in self.param_groups:
      for param in param_group['params']:
        shape_info[idx] = param.shape
        idx += 1
    return shape_info
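
The new shape_info entry records each unsharded parameter's shape by index at save time. One hedged use, assuming a caller wants to validate a checkpoint before calling load_state_dict; check_shape_info is a hypothetical helper, not part of this commit.

def check_shape_info(saved_shape_info, current_shape_info):
  # Raise if any parameter changed shape between checkpointing and reload.
  for idx, saved_shape in saved_shape_info.items():
    current_shape = current_shape_info.get(idx)
    if current_shape != saved_shape:
      raise ValueError(
          f"param {idx}: checkpoint shape {tuple(saved_shape)} "
          f"does not match model shape {current_shape}")

# e.g. check_shape_info(checkpoint['shape_info'], optimizer.get_shape_info())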
