ZeRO-3 Slowdown #1170

Closed
wants to merge 40 commits into from

Commits (40)
9a179e4
Fix tracing; Add timers; Disable norm()
tjruwase Jun 17, 2021
038da9d
Merge branch 'master' into olruwase/zero3_broken_tracing
tjruwase Jun 17, 2021
8d5c784
Merge branch 'master' into olruwase/zero3_broken_tracing
jeffra Jun 17, 2021
46eb232
Simplify prefetch code
tjruwase Jun 18, 2021
e93610c
Merge branch 'olruwase/zero3_broken_tracing' of github.com:microsoft/…
tjruwase Jun 18, 2021
6f7c1c2
Merge branch 'master' into olruwase/zero3_broken_tracing
tjruwase Jun 19, 2021
af66609
Move available parameter numel book-keeping to param partitioner.
tjruwase Jun 19, 2021
7432867
Merge branch 'olruwase/zero3_broken_tracing' of github.com:microsoft/…
tjruwase Jun 19, 2021
11e6c8f
Merge with master
tjruwase Jun 30, 2021
eceb707
Format fixes
tjruwase Jun 30, 2021
bb43da0
Bug fix
tjruwase Jun 30, 2021
f3048f2
Avoid unneeded synch
tjruwase Jun 30, 2021
0514811
Remove assert
tjruwase Jun 30, 2021
c81422b
Merge branch 'master' into olruwase/zero3_broken_tracing
tjruwase Jul 1, 2021
e09b80a
Remove dead code
tjruwase Jul 1, 2021
f938a73
Merge branch 'master' of github.com:microsoft/DeepSpeed into olruwase…
tjruwase Jul 1, 2021
34ec7da
Merge branch 'olruwase/zero3_broken_tracing' of github.com:microsoft/…
tjruwase Jul 1, 2021
3c4a949
Bug fix
tjruwase Jul 2, 2021
a3d92f0
Merge branch 'master' into olruwase/zero3_broken_tracing
tjruwase Jul 14, 2021
bc65ea7
Merge branch 'master' into olruwase/zero3_broken_tracing
tjruwase Jul 16, 2021
43f4bb4
Merge branch 'master' into olruwase/zero3_broken_tracing
tjruwase Jul 27, 2021
f495465
Merge branch 'master' into olruwase/zero3_broken_tracing
tjruwase Jul 29, 2021
96c87d7
Merge branch 'master' into olruwase/zero3_broken_tracing
tjruwase Aug 27, 2021
b2549df
Merge branch 'master' into olruwase/zero3_broken_tracing
tjruwase Sep 7, 2021
c2e4826
Merge branch 'master' into olruwase/zero3_broken_tracing
tjruwase Sep 8, 2021
fc61a5a
Merge branch 'master' into olruwase/zero3_broken_tracing
tjruwase Sep 10, 2021
0ad2ffe
Merge branch 'master' into olruwase/zero3_broken_tracing
tjruwase Sep 10, 2021
c648f17
Merge branch 'master' into olruwase/zero3_broken_tracing
tjruwase Oct 1, 2021
f218446
Remove redundant conditional
tjruwase Oct 1, 2021
a75e465
Avoid multiple zero3 contexts
tjruwase Oct 9, 2021
ec3eb56
Merge
tjruwase Oct 9, 2021
d5e49d1
Merge with master
tjruwase Oct 9, 2021
5c9b188
Formatting fixes
tjruwase Oct 9, 2021
038993e
Restore tracing fixes
tjruwase Oct 10, 2021
e7832d6
Merge branch 'master' into olruwase/zero3_broken_tracing
tjruwase Oct 10, 2021
13c0532
Merge branch 'master' into olruwase/zero3_broken_tracing
tjruwase Oct 11, 2021
f4cebc0
Merge branch 'master' into olruwase/zero3_broken_tracing
tjruwase Oct 12, 2021
98fa12f
Merge branch 'master' into olruwase/zero3_broken_tracing
tjruwase Oct 14, 2021
05acc48
Merge branch 'master' into olruwase/zero3_broken_tracing
tjruwase Oct 19, 2021
30afb5b
Merge branch 'master' into olruwase/zero3_broken_tracing
tjruwase Oct 21, 2021
97 changes: 86 additions & 11 deletions deepspeed/runtime/zero/partition_parameters.py
@@ -479,8 +479,11 @@ def get_model():

self._validate_remote_device(remote_device, _ds_config)

self.available_parameter_numel = 0

# Remote device is the device where parameter partitions are stored
# It can be same as local_device or it could be CPU or NVMe.

self.remote_device = self.local_device if remote_device is None else remote_device
self.pin_memory = pin_memory if (
self.remote_device == OFFLOAD_CPU_DEVICE) else False
@@ -494,11 +497,14 @@ def get_model():
# If we are provided an already-allocated module to prepare.
if module is not None:
assert isinstance(module, torch.nn.Module)
for param in module.parameters(recurse=True):
if is_zero_param(param):
continue
self._convert_to_deepspeed_param(param)
param.partition()
self._convert_to_zero_parameters(module.parameters(recurse=True))

def _convert_to_zero_parameters(self, param_list):
for param in param_list:
if is_zero_param(param):
continue
self._convert_to_deepspeed_param(param)
param.partition()

def _validate_remote_device(self, remote_device, ds_config):
if ds_config is not None:
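
Editorial note on the hunk above: the inline conversion loop in `__init__` is now routed through `_convert_to_zero_parameters`, and the same helper is later exposed on each parameter as `convert_to_zero_parameters`. A minimal usage sketch follows; the toy module, the `zero.Init` arguments, and the late-created parameter are illustrative assumptions, not part of this diff.

```python
import torch
import deepspeed

# Hypothetical sketch (assumes a distributed environment is already set up):
# parameters created under zero.Init are converted and partitioned through
# _convert_to_zero_parameters instead of the old inline loop.
with deepspeed.zero.Init(remote_device="cpu", pin_memory=True):
    model = torch.nn.Linear(4096, 4096)

# A parameter allocated after Init can be converted via the helper this PR
# attaches to already-converted parameters (convert_to_zero_parameters).
late_param = torch.nn.Parameter(torch.empty(4096))
next(model.parameters()).convert_to_zero_parameters([late_param])
```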
@@ -550,6 +556,11 @@ def _convert_to_deepspeed_param(self, param):
# Stores the number of elements in the original parameter without padding
param.ds_numel = param.numel()

# Update status book keeping
self._update_param_status(new_status=ZeroParamStatus.AVAILABLE,
old_status=ZeroParamStatus.NOT_AVAILABLE,
numel=param.ds_numel)

# Stores the partitioned copy of the tensor
param.ds_tensor = None

@@ -621,6 +632,25 @@ def padding_size():
def partitioned_size():
return self._partitioned_size(param)

def update_status(new_status, param_list=None, hierarchy=0):
cls = param
if param_list is None:
param_list = [cls]
self._update_status(param_list, new_status)

def get_available_parameter_numel():
return self._get_available_parameter_numel()

def synchronize_communication(param_list=None, handle_list=None, hierarchy=0):
cls = param
if param_list is None:
param_list = [cls]

self._synchronize_communication(param_list, handle_list)

def convert_to_zero_parameters(param_list):
self._convert_to_zero_parameters(param_list)

# Collectives for gathering and partitioning parameters
param.all_gather = all_gather
param.partition = partition
@@ -634,6 +664,13 @@ def partitioned_size():
param.padding_size = padding_size
param.partitioned_size = partitioned_size

# Status utilities
param.update_status = update_status
param.get_available_parameter_numel = get_available_parameter_numel

param.synchronize_communication = synchronize_communication
param.convert_to_zero_parameters = convert_to_zero_parameters

def _aligned_size(self, param):
return param.ds_numel + self._padding_size(param)
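
Editorial note: the new closures (`update_status`, `get_available_parameter_numel`, `synchronize_communication`, `convert_to_zero_parameters`) follow the file's existing pattern of binding per-parameter helpers onto the parameter object itself. Below is a standalone sketch of that pattern with purely illustrative names; it is not the DeepSpeed API.

```python
import torch

class ToyPartitioner:
    """Illustrative stand-in for the Init context that owns the bookkeeping."""
    def __init__(self):
        self.available_parameter_numel = 0

    def _update_status(self, params, new_status):
        for p in params:
            p.ds_status = new_status  # the real code also adjusts the numel counter

def attach_helpers(param, owner):
    # Mirrors how _convert_to_deepspeed_param binds closures onto `param`,
    # so call sites can write param.update_status(...) without holding `owner`.
    def update_status(new_status, param_list=None, hierarchy=0):
        owner._update_status(param_list or [param], new_status)

    def get_available_parameter_numel():
        return owner.available_parameter_numel

    param.update_status = update_status
    param.get_available_parameter_numel = get_available_parameter_numel
    return param

p = attach_helpers(torch.nn.Parameter(torch.zeros(8)), ToyPartitioner())
p.update_status("AVAILABLE")
print(p.get_available_parameter_numel())  # 0 in this toy model
```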

@@ -672,15 +709,24 @@ def _all_gather(self, param_list, async_op=False, hierarchy=None):
handle = self._allgather_param(param,
async_op=async_op,
hierarchy=hierarchy)
param.ds_status = ZeroParamStatus.INFLIGHT # if async_op else ZeroParamStatus.AVAILABLE
param.update_status(ZeroParamStatus.INFLIGHT)
handles.append(handle)
else:
all_gather_list.append(param)

if not async_op:
ret_value = self._allgather_params(all_gather_list, hierarchy=hierarchy)
avail_params = []
status_params = []
for param in all_gather_list:
param.ds_status = ZeroParamStatus.AVAILABLE
avail_params.append(param.ds_id)
status_params.append(param.ds_status)
param.update_status(ZeroParamStatus.AVAILABLE)

print_rank_0(
f'_all_gather marks available params = {avail_params} status = {status_params}',
force=False)

return ret_value

return handles
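
To summarize the control flow this hunk changes: on the synchronous path parameters are marked AVAILABLE as soon as `_allgather_params` returns, while on the async path they are marked INFLIGHT and only become AVAILABLE once the caller drains the returned handles. A hedged sketch of that call pattern, using the helpers this PR attaches to parameters (the function and its setup are assumptions, not DeepSpeed code):

```python
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

def prefetch_params(params):
    """Hedged sketch: overlap an async all-gather with other work, then
    resolve the INFLIGHT parameters via synchronize_communication."""
    handles = params[0].all_gather(param_list=params, async_op=True)
    # every param in `params` is now ZeroParamStatus.INFLIGHT

    # ... caller overlaps compute here ...

    # wait on each handle and flip the whole list to AVAILABLE in one place
    params[0].synchronize_communication(param_list=params, handle_list=handles)
    assert all(p.ds_status == ZeroParamStatus.AVAILABLE for p in params)
```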
@@ -690,8 +736,8 @@ def _partition(self, param_list, force=False, has_been_updated=False):
#print_rank_0(f"Before Partitioning Param {param.ds_id}")
# self._param_status(param)
self._partition_param(param, has_been_updated=has_been_updated)
param.ds_status = ZeroParamStatus.NOT_AVAILABLE
# if param.ds_tensor is not None:
param.update_status(ZeroParamStatus.NOT_AVAILABLE)
#if param.ds_tensor is not None:
# assert id(param.data) == id(param.ds_tensor.data), \
# "After the parameters are initially partitioned, make sure we are not recreating the partition."
#print_rank_0(f"After Partitioning Param {param.ds_id}")
@@ -854,7 +900,8 @@ def _allgather_param(self, param, async_op=False, hierarchy=0):
f'After allocate allgather param {debug_param2name_id_shape_status(param)} {aligned_param_size} {partition_size} ',
force=False)

torch.cuda.synchronize()
if not async_op:
torch.cuda.synchronize()

print_rank_0(
f"{'--'* hierarchy}----allgather param with {debug_param2name_id_shape_status(param)} partition size={partition_size}"
@@ -1088,6 +1135,34 @@ def _partition_gradient(self, param, partition_buffer=None, accumulate=False):
param.grad.data = dest_tensor_full_buffer.data
see_memory_usage("After partitioning gradients", force=False)

def _update_status(self, param_list, new_status):
if len(param_list) == 0:
return

for param in param_list:
old_status = param.ds_status
param.ds_status = new_status
self._update_param_status(new_status, old_status, param.ds_numel)

def _update_param_status(self, new_status, old_status, numel):
if old_status == ZeroParamStatus.AVAILABLE:
self.available_parameter_numel -= numel

if new_status == ZeroParamStatus.AVAILABLE:
self.available_parameter_numel += numel

assert self.available_parameter_numel >= 0, f'available_parameter_numel is negative: {self.available_parameter_numel}'

def _get_available_parameter_numel(self):
return self.available_parameter_numel

def _synchronize_communication(self, param_list, handle_list):
for param, handle in zip(param_list, handle_list):
if handle is not None:
handle.wait()

self._update_status(param_list=param_list, new_status=ZeroParamStatus.AVAILABLE)
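
A quick editorial illustration of the accounting in `_update_param_status` above: the counter grows only on transitions into AVAILABLE and shrinks only on transitions out of it, so it should never go negative. A toy, self-contained model of that invariant (the class and enum here are stand-ins, not DeepSpeed code):

```python
from enum import Enum

class Status(Enum):                 # stand-in for ZeroParamStatus in this sketch
    NOT_AVAILABLE = 0
    INFLIGHT = 1
    AVAILABLE = 2

class NumelBook:
    """Toy model of _update_param_status: track how many elements are
    currently materialized (AVAILABLE) across all parameters."""
    def __init__(self):
        self.available_parameter_numel = 0

    def update(self, new_status, old_status, numel):
        if old_status == Status.AVAILABLE:
            self.available_parameter_numel -= numel
        if new_status == Status.AVAILABLE:
            self.available_parameter_numel += numel
        assert self.available_parameter_numel >= 0

book = NumelBook()
book.update(Status.AVAILABLE, Status.NOT_AVAILABLE, 1024)   # param gathered
book.update(Status.INFLIGHT, Status.NOT_AVAILABLE, 512)     # another in flight
book.update(Status.AVAILABLE, Status.INFLIGHT, 512)         # gather completed
book.update(Status.NOT_AVAILABLE, Status.AVAILABLE, 1024)   # first param re-partitioned
assert book.available_parameter_numel == 512
```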


class GatheredParameters:
def __init__(self, params, modifier_rank=None, fwd_module=None, enabled=True):