ZeRO3, improved parameter all-gather operation #1188

Merged · 39 commits · Oct 31, 2021
Changes shown from 23 of the 39 commits.

Commits
1e73e75
remove norm(), avoid memcpy after allgather
zarzen Jun 25, 2021
67b3db3
WIP: wrapped ncclAllgather as customized op in DS
zarzen Jun 30, 2021
70e681f
WIP: integrated into partition_parameters
zarzen Jun 30, 2021
81b4fc4
Fix format
zarzen Jul 1, 2021
8a14e43
Merge branch 'master' into impr_allgather_params
zarzen Jul 6, 2021
32c8fa7
cleaned dead code, modified unit test
zarzen Jul 6, 2021
c4728f5
Merge branch 'master' into impr_allgather_params
zarzen Jul 6, 2021
e075fd4
Merge branch 'master' into impr_allgather_params
tjruwase Jul 14, 2021
5208508
removed customized c++ extension
zarzen Jul 23, 2021
ffd3d3b
Merge remote-tracking branch 'origin/master' into impr_allgather_params
zarzen Jul 23, 2021
1ed96ce
change torch.ones to torch empty
zarzen Jul 23, 2021
220f2e0
Merge branch 'master' into impr_allgather_params
tjruwase Jul 27, 2021
8f65594
Merge branch 'master' into impr_allgather_params
tjruwase Jul 29, 2021
0e6d8e0
typo
zarzen Aug 10, 2021
691749f
Merge branch 'master' into impr_allgather_params
tjruwase Sep 7, 2021
88e750e
Merge branch 'master' into impr_allgather_params
tjruwase Sep 8, 2021
497ee7d
Merge branch 'master' into impr_allgather_params
tjruwase Sep 9, 2021
bd8839c
Merge branch 'master' into impr_allgather_params
tjruwase Sep 9, 2021
2582910
Merge branch 'master' into impr_allgather_params
tjruwase Sep 10, 2021
4ca0d39
Merge branch 'master' into impr_allgather_params
tjruwase Sep 11, 2021
56de9ad
Merge branch 'master' into impr_allgather_params
tjruwase Oct 5, 2021
056cf10
Merge branch 'master' into impr_allgather_params
tjruwase Oct 6, 2021
aac09cd
Merge branch 'master' into impr_allgather_params
tjruwase Oct 7, 2021
6201b29
Merge branch 'master' into impr_allgather_params
tjruwase Oct 12, 2021
50a9215
warn if not cuda tensor for allgather
zarzen Oct 15, 2021
c554a58
Merge branch 'master' into impr_allgather_params
jeffra Oct 15, 2021
b7e131d
Merge branch 'master' into impr_allgather_params
tjruwase Oct 21, 2021
813cb22
fix formatting
zarzen Oct 21, 2021
588d3d0
Merge branch 'master' into impr_allgather_params
tjruwase Oct 22, 2021
eb0a540
Merge branch 'master' into impr_allgather_params
tjruwase Oct 22, 2021
c092b78
fix: move ds_tensor to cuda device
zarzen Oct 22, 2021
e73809d
Merge branch 'master' into impr_allgather_params
tjruwase Oct 27, 2021
d1d3c28
Merge branch 'master' into impr_allgather_params
tjruwase Oct 27, 2021
62cb104
Merge branch 'master' into impr_allgather_params
tjruwase Oct 27, 2021
ab64b17
Merge branch 'master' into impr_allgather_params
tjruwase Oct 28, 2021
7a80172
remove try clause on the path for fetching params
zarzen Oct 30, 2021
f01dad8
Merge branch 'microsoft:master' into impr_allgather_params
zarzen Oct 30, 2021
524e609
Merge branch 'master' into impr_allgather_params
tjruwase Oct 30, 2021
d7fff58
Merge branch 'master' into impr_allgather_params
tjruwase Oct 31, 2021
112 changes: 98 additions & 14 deletions deepspeed/runtime/zero/partition_parameters.py
@@ -11,7 +11,12 @@
import itertools

import torch
from torch.distributed.distributed_c10d import _get_global_rank
from torch.distributed.distributed_c10d import _get_global_rank, group

try:
from torch.distributed.distributed_c10d import _all_gather_base as all_gather
except:
from torch.distributed.distributed_c10d import all_gather

from .linear import LinearModuleForZeroStage3, LinearFunctionForZeroStage3
from .offload_constants import *
@@ -678,7 +683,9 @@ def _all_gather(self, param_list, async_op=False, hierarchy=None):
all_gather_list.append(param)

if not async_op:
ret_value = self._allgather_params(all_gather_list, hierarchy=hierarchy)
# ret_value = self._allgather_params(all_gather_list, hierarchy=hierarchy)
ret_value = self._allgather_params_coalesced(all_gather_list, hierarchy)

for param in all_gather_list:
param.ds_status = ZeroParamStatus.AVAILABLE
return ret_value
@@ -724,8 +731,10 @@ def _partition_param(self, param, buffer=None, has_been_updated=False):
see_memory_usage(
f'Before partitioning param {param.ds_id} {param.shape}',
force=False)

# param.data does not store anything meaningful in partitioned state
param.data = torch.ones(1, dtype=self.dtype).to(param.device)
param.data = torch.empty(1, dtype=self.dtype, device=param.device)

see_memory_usage(f'After partitioning param {param.ds_id} {param.shape}',
force=False)

@@ -746,7 +755,7 @@ def _partition_param(self, param, buffer=None, has_been_updated=False):
numel=partition_size):
final_location = OFFLOAD_NVME_DEVICE
buffer = self.param_swapper.get_buffer(param, partition_size)
partitioned_tensor = torch.zeros(1,
partitioned_tensor = torch.empty(1,
dtype=param.dtype,
device=buffer.device)
partitioned_tensor.data = buffer.data
@@ -755,7 +764,7 @@ def _partition_param(self, param, buffer=None, has_been_updated=False):
)

else:
partitioned_tensor = torch.zeros(
partitioned_tensor = torch.empty(
partition_size,
dtype=param.dtype,
device=OFFLOAD_CPU_DEVICE
@@ -865,22 +874,97 @@ def _allgather_param(self, param, async_op=False, hierarchy=0):
# param.ds_numel).view(param.ds_shape)
# param.data = replicated_tensor.data
# return None
partitions = []
for i in range(self.world_size):
partitions.append(flat_tensor.narrow(0, partition_size * i, partition_size))
try:
# try the _all_gather_base on PyTorch master branch
handle = all_gather(flat_tensor,
param.ds_tensor,
group=self.ds_process_group,
async_op=async_op)
except:
partitions = []
for i in range(self.world_size):
partitions.append(
flat_tensor.narrow(0,
partition_size * i,
partition_size))

if i == torch.distributed.get_rank(group=self.ds_process_group):
partitions[i].data.copy_(param.ds_tensor.data, non_blocking=True)
if i == torch.distributed.get_rank(group=self.ds_process_group):
partitions[i].data.copy_(param.ds_tensor.data, non_blocking=True)

handle = torch.distributed.all_gather(partitions,
partitions[self.rank],
group=self.ds_process_group,
async_op=async_op)
handle = torch.distributed.all_gather(partitions,
partitions[self.rank],
group=self.ds_process_group,
async_op=async_op)

replicated_tensor = flat_tensor.narrow(0, 0, param.ds_numel).view(param.ds_shape)
param.data = replicated_tensor.data
return handle

def _allgather_params_coalesced(self, param_list, hierarchy=0):
""" blocking call
avoid explicit memory copy in _allgather_params
"""
if len(param_list) == 0:
return
# collect local tensors and partition sizes
partition_sizes = []
local_tensors = []
for param in param_list:
partition_sizes.append(param.ds_tensor.ds_numel)
local_tensors.append(param.ds_tensor)

# allocate memory for allgather params
allgather_params = []
for psize in partition_sizes:
tensor_size = psize * self.world_size
flat_tensor = torch.empty(tensor_size,
dtype=param_list[0].dtype,
device=self.local_device).view(-1)
flat_tensor.requires_grad = False
allgather_params.append(flat_tensor)

# launch
launch_handles = []
# backend = get_backend(self.ds_process_group)
# with _batch_p2p_manager(backend):
for param_idx, param in enumerate(param_list):
input_tensor = local_tensors[param_idx].view(-1)

try:
# try the _all_gather_base from Pytorch master
h = all_gather(allgather_params[param_idx],
input_tensor,
group=self.ds_process_group,
async_op=True)
jeffra (Collaborator) commented on Oct 7, 2021:
It seems one of the errors is coming from this line, which results in
RuntimeError: Invalid function argument. Expected parameter `tensor_list` to be of type List[torch.Tensor].
I think allgather_params[param_idx] in this case is just a Tensor, not a list of Tensors?

jeffra (Collaborator):
Reading more of the context here, sorry. I see you're trying to use _all_gather_base; which version of PyTorch introduced it? The CI runs in question here are on torch 1.8.2 LTS.

zarzen (Contributor, Author):
_all_gather_base is available on the PyTorch master branch, probably version 1.10+. That's why I used a try clause so the all-gather falls back to the distributed.all_gather call.

zarzen (Contributor, Author):
@jeffra @tjruwase Sorry for the late reply. I just confirmed this PR works with PyTorch 1.8.0. I also checked the failure log again: it does raise a RuntimeError at line 935, but that exception is caught by the except clause, which builds a list of tensors for the all_gather API.

except:
output_list = []
for i in range(self.world_size):
psize = partition_sizes[param_idx]
partition = allgather_params[param_idx].narrow(0, i * psize, psize)
output_list.append(partition)

# back to old all_gather function signature
h = all_gather(output_list,
input_tensor,
group=self.ds_process_group,
async_op=True)
launch_handles.append(h)

# Wait ensures the operation is enqueued, but not necessarily complete.
launch_handles[-1].wait()

# assign to param.data (not copy)
for i, param in enumerate(param_list):
gathered_tensor = allgather_params[i]
param.data = gathered_tensor.narrow(0,
0,
param.ds_numel).view(param.ds_shape).data

# guarantee the communication to be completed
torch.cuda.synchronize()

return None

def _allgather_params(self, param_list, hierarchy=0):
if len(param_list) == 0:
return
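
For reference, the fallback discussed in the review thread above can be written as a small standalone helper. The sketch below is not the PR's code: the helper name allgather_flat is invented for illustration, it uses hasattr feature detection where the PR uses try/except, and it assumes an already-initialized torch.distributed process group (e.g. launched via torchrun) with the shard on the right device for the backend.

```python
# Minimal sketch (not DeepSpeed code): gather each rank's flat 1-D shard into one
# contiguous output tensor, preferring _all_gather_base when PyTorch provides it.
import torch
import torch.distributed as dist


def allgather_flat(shard: torch.Tensor, group=None) -> torch.Tensor:
    """shard: this rank's 1-D partition (same numel on every rank).
    Returns a tensor of size world_size * shard.numel()."""
    world_size = dist.get_world_size(group=group)
    output = torch.empty(world_size * shard.numel(),
                         dtype=shard.dtype,
                         device=shard.device)

    if hasattr(dist, "_all_gather_base"):
        # Newer PyTorch (master / roughly 1.10+ per the thread above):
        # one flat-output all-gather, no per-rank output list or extra copies.
        dist._all_gather_base(output, shard, group=group)
    else:
        # Older PyTorch: list-based all_gather into narrow() views of the same
        # flat buffer, so the result is still contiguous with no copy afterwards.
        partitions = [output.narrow(0, r * shard.numel(), shard.numel())
                      for r in range(world_size)]
        dist.all_gather(partitions, shard, group=group)
    return output
```

Avoiding the per-rank output list and the memcpy after the gather is the same motivation as the PR's "remove norm(), avoid memcpy after allgather" commit; on older PyTorch the narrow()-view fallback produces an equivalent flat result.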
2 changes: 1 addition & 1 deletion deepspeed/runtime/zero/stage3.py
@@ -444,7 +444,7 @@ def fetch_sub_module(self, sub_module):
self.hierarchy += 1

# parameters are partitioned and need to be allgathered
self._all_gather(partitioned_params, async_op=True)
self._all_gather(partitioned_params, async_op=False)

# parameters are inflight and communication needs to be completed
if partitioned_params or params_in_flight:
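
To make the control flow of the new coalesced path easier to follow, here is a toy sketch of the same pattern with no DeepSpeed types. It is an illustration under stated assumptions, not the PR's implementation: the function name allgather_coalesced and its arguments are made up, an initialized NCCL process group with CUDA shards is assumed, and it sticks to the list-based all_gather so it also runs on PyTorch 1.8.

```python
# Toy sketch of the coalesced pattern: queue one async all-gather per parameter,
# then block once before handing the gathered tensors back.
import math

import torch
import torch.distributed as dist


def allgather_coalesced(shards, shapes, group=None):
    """shards: list of 1-D local partitions (same size on every rank);
    shapes: original shapes of the corresponding full parameters."""
    world_size = dist.get_world_size(group=group)
    outputs, handles = [], []
    for shard in shards:
        flat = torch.empty(world_size * shard.numel(),
                           dtype=shard.dtype,
                           device=shard.device)
        outputs.append(flat)
        # Gather directly into narrow() views of the flat buffer;
        # async_op=True lets the collectives for different params overlap.
        views = [flat.narrow(0, r * shard.numel(), shard.numel())
                 for r in range(world_size)]
        handles.append(dist.all_gather(views, shard, group=group, async_op=True))

    # wait() only guarantees the last op has been enqueued on the communication
    # stream; the device synchronize guarantees completion before the data is read.
    handles[-1].wait()
    torch.cuda.synchronize()

    # Trim any padding and re-view each flat buffer as the original shape.
    return [flat.narrow(0, 0, math.prod(shape)).view(shape)
            for flat, shape in zip(outputs, shapes)]
```

The PR follows the same shape: it waits on the last handle, re-points each param.data at a narrow()+view() of its gathered buffer, and then calls torch.cuda.synchronize() to guarantee the communication has completed. That blocking behavior is why fetch_sub_module in stage3.py now calls _all_gather with async_op=False, routing the fetch through the coalesced path.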