ZeRO3, improved parameter all-gather operation (#1188)
* remove norm(), avoid memcpy after allgather

1) Removing the norm computation in debug printing
2) Changing _all_gather to a synchronous op in fetch_sub_module
    Reason: the async version is not actually asynchronous, because each
    all_gather call invokes torch.cuda.synchronize() to guarantee that the
    previous communication op has completed
3) Adding a new function, _allgather_params_split_launch:
    the existing _allgather_params performs an explicit memcpy after the
    all-gather op. Avoiding that explicit copy on the Python side
    improves performance (see the sketch below).

Known issue:
    `torch.distributed.all_gather` performs an implicit memcpy
    at the end of each `ncclAllGather`.
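
As an illustration only (not code from this commit), here is a minimal sketch of the idea behind point 3 and the known issue: gather directly into per-rank views of one preallocated flat buffer, preferring the private `_all_gather_base` API when the installed PyTorch provides it. The helper name and the assumption that the process group is already initialized are mine.

```python
import torch
import torch.distributed as dist

def gather_into_flat_buffer(shard: torch.Tensor, group=None) -> torch.Tensor:
    """Hypothetical helper: all-gather `shard` (this rank's CUDA partition)
    into one contiguous buffer without an explicit copy afterwards."""
    world_size = dist.get_world_size(group=group)
    numel = shard.numel()
    flat = torch.empty(numel * world_size, dtype=shard.dtype, device=shard.device)

    if hasattr(dist, "_all_gather_base"):
        # Newer PyTorch: gather straight into the contiguous output buffer.
        dist._all_gather_base(flat, shard.contiguous(), group=group)
    else:
        # Older PyTorch: hand the collective views of `flat`, so each rank's
        # partition lands in place and no Python-side memcpy is needed.
        views = [flat.narrow(0, i * numel, numel) for i in range(world_size)]
        dist.all_gather(views, shard.contiguous(), group=group)
    return flat
```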

* WIP: wrapped ncclAllgather as customized op in DS

A micro-benchmark shows that all-gathering a transformer layer with
9,834,560 elements in half precision improves by about 1.1ms on an
aws-p4d instance (see the timing sketch below).
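
For illustration, a hedged sketch of how such a micro-benchmark might be timed with CUDA events; the warm-up/iteration counts and the even sharding of the 9,834,560 elements across ranks are assumptions, not details from the commit.

```python
import torch
import torch.distributed as dist

def time_allgather(total_numel=9834560, iters=50, warmup=5):
    """Rough timing of all_gather for an fp16 parameter sharded across ranks."""
    world_size = dist.get_world_size()
    shard = torch.randn(total_numel // world_size, dtype=torch.float16, device="cuda")
    out = [torch.empty_like(shard) for _ in range(world_size)]

    for _ in range(warmup):          # warm-up so NCCL setup is excluded
        dist.all_gather(out, shard)
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        dist.all_gather(out, shard)
    end.record()
    torch.cuda.synchronize()

    if dist.get_rank() == 0:
        print(f"all_gather: {start.elapsed_time(end) / iters:.3f} ms/iter")
```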

* WIP: integrated into partition_parameters

Performance improvement for 5.1B BERT on aws-p4d:
fwd: 300ms -> 200ms
bwd: 680ms -> 610ms

* Fix format

* cleaned dead code, modified unit test

* removed customized c++ extension

reverted to using the torch.distributed API

* change torch.ones to torch.empty

* typo

* warn if not cuda tensor for allgather

* fix formatting

* fix: move ds_tensor to cuda device

but it is strange that the ds_tensor had not already been moved to CUDA (see the sketch below)
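
A hypothetical sketch (the helper name is mine) of the defensive device move this fix implies; the committed code simply calls `param.ds_tensor.cuda()` before launching the collective, since NCCL requires CUDA tensors.

```python
import torch

def ensure_shard_on_gpu(param):
    """Move the partitioned shard to the current GPU if it is still on the host
    (e.g. after CPU/NVMe offload), so the all-gather can run over NCCL."""
    if not param.ds_tensor.is_cuda:
        param.ds_tensor = param.ds_tensor.to(torch.cuda.current_device(),
                                             non_blocking=True)
    return param.ds_tensor
```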

* remove try clause on the path for fetching params

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
3 people committed Oct 31, 2021
1 parent 7f5a3ad commit c0eeb69
Showing 2 changed files with 107 additions and 15 deletions.
120 changes: 106 additions & 14 deletions deepspeed/runtime/zero/partition_parameters.py
@@ -11,7 +11,8 @@
import itertools

import torch
from torch.distributed.distributed_c10d import _get_global_rank
from torch.distributed.distributed_c10d import _get_global_rank, group
import torch.distributed as dist

from .linear import LinearModuleForZeroStage3, LinearFunctionForZeroStage3
from .offload_constants import *
@@ -496,6 +497,14 @@ def get_model():
assert isinstance(module, torch.nn.Module)
self._convert_to_zero_parameters(module.parameters(recurse=True))

self.use_all_gather_base = False
try:
from torch.distributed.distributed_c10d import _all_gather_base as all_gather
self.use_all_gather_base = True
except:
logger.info(
f"_all_gather_base API is not available in torch {torch.__version__}")

def _convert_to_zero_parameters(self, param_list):
for param in param_list:
if is_zero_param(param):
@@ -686,7 +695,9 @@ def _all_gather(self, param_list, async_op=False, hierarchy=None):
all_gather_list.append(param)

if not async_op:
ret_value = self._allgather_params(all_gather_list, hierarchy=hierarchy)
# ret_value = self._allgather_params(all_gather_list, hierarchy=hierarchy)
ret_value = self._allgather_params_coalesced(all_gather_list, hierarchy)

for param in all_gather_list:
param.ds_status = ZeroParamStatus.AVAILABLE
return ret_value
@@ -732,8 +743,10 @@ def _partition_param(self, param, buffer=None, has_been_updated=False):
see_memory_usage(
f'Before partitioning param {param.ds_id} {param.shape}',
force=False)

# param.data does not store anything meaningful in partitioned state
param.data = torch.ones(1, dtype=self.dtype).to(param.device)
param.data = torch.empty(1, dtype=self.dtype, device=param.device)

see_memory_usage(f'After partitioning param {param.ds_id} {param.shape}',
force=False)

@@ -754,7 +767,7 @@ def _partition_param(self, param, buffer=None, has_been_updated=False):
numel=partition_size):
final_location = OFFLOAD_NVME_DEVICE
buffer = self.param_swapper.get_buffer(param, partition_size)
partitioned_tensor = torch.zeros(1,
partitioned_tensor = torch.empty(1,
dtype=param.dtype,
device=buffer.device)
partitioned_tensor.data = buffer.data
@@ -763,7 +776,7 @@
)

else:
partitioned_tensor = torch.zeros(
partitioned_tensor = torch.empty(
partition_size,
dtype=param.dtype,
device=OFFLOAD_CPU_DEVICE
@@ -873,22 +886,101 @@ def _allgather_param(self, param, async_op=False, hierarchy=0):
# param.ds_numel).view(param.ds_shape)
# param.data = replicated_tensor.data
# return None
partitions = []
for i in range(self.world_size):
partitions.append(flat_tensor.narrow(0, partition_size * i, partition_size))
if self.use_all_gather_base:
# try the _all_gather_base on PyTorch master branch
handle = dist._all_gather_base(flat_tensor,
param.ds_tensor,
group=self.ds_process_group,
async_op=async_op)
else:
partitions = []
for i in range(self.world_size):
partitions.append(
flat_tensor.narrow(0,
partition_size * i,
partition_size))

if i == torch.distributed.get_rank(group=self.ds_process_group):
partitions[i].data.copy_(param.ds_tensor.data, non_blocking=True)
if i == dist.get_rank(group=self.ds_process_group):
partitions[i].data.copy_(param.ds_tensor.data, non_blocking=True)

handle = torch.distributed.all_gather(partitions,
partitions[self.rank],
group=self.ds_process_group,
async_op=async_op)
handle = dist.all_gather(partitions,
partitions[self.rank],
group=self.ds_process_group,
async_op=async_op)

replicated_tensor = flat_tensor.narrow(0, 0, param.ds_numel).view(param.ds_shape)
param.data = replicated_tensor.data
return handle

def _allgather_params_coalesced(self, param_list, hierarchy=0):
""" blocking call
avoid explicit memory copy in _allgather_params
"""
if len(param_list) == 0:
return
# collect local tensors and partition sizes
partition_sizes = []
local_tensors = []
for param in param_list:
partition_sizes.append(param.ds_tensor.ds_numel)
local_tensors.append(param.ds_tensor.cuda())

# allocate memory for allgather params
allgather_params = []
for psize in partition_sizes:
tensor_size = psize * self.world_size
flat_tensor = torch.empty(tensor_size,
dtype=param_list[0].dtype,
device=self.local_device).view(-1)
flat_tensor.requires_grad = False
allgather_params.append(flat_tensor)

# launch
launch_handles = []
# backend = get_backend(self.ds_process_group)
# with _batch_p2p_manager(backend):
for param_idx, param in enumerate(param_list):
input_tensor = local_tensors[param_idx].view(-1)

if self.use_all_gather_base:
# try the _all_gather_base from Pytorch master
h = dist._all_gather_base(allgather_params[param_idx],
input_tensor,
group=self.ds_process_group,
async_op=True)
else:
output_list = []
for i in range(self.world_size):
psize = partition_sizes[param_idx]
partition = allgather_params[param_idx].narrow(0, i * psize, psize)
output_list.append(partition)
if not partition.is_cuda:
logger.warning(
f'param {param_idx}, partition {i} is not on CUDA, partition shape {partition.size()}'
)

# back to old all_gather function signature
h = dist.all_gather(output_list,
input_tensor,
group=self.ds_process_group,
async_op=True)
launch_handles.append(h)

# Wait ensures the operation is enqueued, but not necessarily complete.
launch_handles[-1].wait()

# assign to param.data (not copy)
for i, param in enumerate(param_list):
gathered_tensor = allgather_params[i]
param.data = gathered_tensor.narrow(0,
0,
param.ds_numel).view(param.ds_shape).data

# guarantee the communication to be completed
torch.cuda.synchronize()

return None

def _allgather_params(self, param_list, hierarchy=0):
if len(param_list) == 0:
return
2 changes: 1 addition & 1 deletion deepspeed/runtime/zero/stage3.py
@@ -444,7 +444,7 @@ def fetch_sub_module(self, sub_module):
self.hierarchy += 1

# parameters are partitioned and need to be allgathered
self._all_gather(partitioned_params, async_op=True)
self._all_gather(partitioned_params, async_op=False)

# parameters are inflight and communication needs to be completed
if partitioned_params or params_in_flight:
