zero3 performance optimizations (#3622)
* Remove dead code

params_already_reduced is not used

* Prevent evaluation of debug strings

Debug message f-strings are evaluated even when DEBUG logging is disabled; guard the calls with logger.isEnabledFor(logging.DEBUG) so the formatting cost is only paid when the messages are actually emitted (see the sketch after the partition_parameters.py diff below).

* Use the contiguous gradients tensor instead of reduce scatter between ranks

When contiguous gradients are enabled and reduce scatter is disabled, reduce the flat gradient bucket with a single allreduce and let each rank slice out its own partition; this lowers CPU overhead compared with per-parameter reduce scatter (see the sketch after the stage3.py diff below).

* Move the overflow tracker to optimizer.step

Don't check gradients for overflow in every bucket. Do the overflow check once on the flat gradient buffer just before the optimizer step (see the sketch after the stage3.py diff below).

---------

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
BacharL and tjruwase authored Jun 8, 2023
1 parent df42509 commit 0977106
Showing 4 changed files with 91 additions and 55 deletions.
4 changes: 3 additions & 1 deletion deepspeed/runtime/zero/partition_parameters.py
@@ -11,6 +11,7 @@
import functools
import itertools
from typing import List
import logging
import torch
from torch import Tensor
from deepspeed import comm as dist
@@ -898,7 +899,8 @@ def all_gather_coalesced(params: Iterable[Parameter], safe_mode: bool = False) -
# to debug correctness issues.
params = sorted(params, key=lambda p: p.ds_id)

debug_rank0(f"-allgather_coalesced: {[p.ds_id for p in params]}")
if logger.isEnabledFor(logging.DEBUG):
debug_rank0(f"-allgather_coalesced: {[p.ds_id for p in params]}")

if safe_mode:
# ensure that same list (with same ordering) of parameters are
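A minimal standalone sketch of the guard pattern introduced above (the logger name and the id list are illustrative, not taken from DeepSpeed): an f-string argument is built before logger.debug() is even called, so the formatting cost is paid whether or not the record is emitted; isEnabledFor() skips it entirely.

import logging

logger = logging.getLogger("sketch")
logger.setLevel(logging.INFO)  # DEBUG disabled, as in a normal training run

ds_ids = list(range(50_000))   # stand-in for [p.ds_id for p in params]

# Unguarded: the f-string (including rendering the whole id list) is built on
# every call, even though the DEBUG record is then discarded by the logger.
logger.debug(f"-allgather_coalesced: {ds_ids}")

# Guarded: when DEBUG is off, isEnabledFor() returns False and the message is
# never constructed. This is the pattern the patch applies to debug_rank0 calls.
if logger.isEnabledFor(logging.DEBUG):
    logger.debug(f"-allgather_coalesced: {ds_ids}")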
32 changes: 19 additions & 13 deletions deepspeed/runtime/zero/partitioned_param_coordinator.py
@@ -15,6 +15,7 @@
from deepspeed.runtime.swap_tensor.partitioned_param_swapper import PartitionedParamStatus
from deepspeed.utils.debug import debug_module2name_id, debug_param2name_id
from deepspeed.accelerator import get_accelerator
import logging


def debug_rank0(message: str) -> None:
@@ -235,25 +236,28 @@ def fetch_sub_module(self, current_submodule: Module) -> None:
2. kick off fetch for next few parameters we will need later (prefetch)
3. block on parameters in immediately required sub module
"""
debug_rank0(
f"{self.__step_id}: M{current_submodule.id}({type(current_submodule).__name__}) P{[p.ds_id for p in iter_params(current_submodule)]} "
+ str({
"avail": f"{self.__n_available_params:.1e}",
"queue_sz": f"{len(self.__param_queue or [])}",
"inflight": [p.ds_id for p in self.__inflight_param_registry],
}))
if logger.isEnabledFor(logging.DEBUG):
debug_rank0(
f"{self.__step_id}: M{current_submodule.id}({type(current_submodule).__name__}) P{[p.ds_id for p in iter_params(current_submodule)]} "
+ str({
"avail": f"{self.__n_available_params:.1e}",
"queue_sz": f"{len(self.__param_queue or [])}",
"inflight": [p.ds_id for p in self.__inflight_param_registry],
}))

params_to_fetch = frozenset(iter_params(current_submodule))

# kick off all gather for params in the immediately required submodule
for param in params_to_fetch:
debug_rank0(f"-fetch: {param.ds_summary()}")
if logger.isEnabledFor(logging.DEBUG):
for param in params_to_fetch:
debug_rank0(f"-fetch: {param.ds_summary()}")
self.__all_gather_params(params_to_fetch)

# wait for parameters in the immediately needed submodule to become available
for param in params_to_fetch:
param.ds_active_sub_modules.add(current_submodule.id)
debug_rank0(f"-wait: {param.ds_summary()}")
if logger.isEnabledFor(logging.DEBUG):
debug_rank0(f"-wait: {param.ds_summary()}")
if param in self.__inflight_param_registry:
with get_accelerator().stream(self.__allgather_stream):
while self.__ongoing_fetch_events and self.__ongoing_fetch_events[0].query():
@@ -328,8 +332,9 @@ def _is_currently_on_nvme(param):
params_to_prefetch.add(param_in_trace.param)
numel_prefetching += param_in_trace.param.ds_numel

for param in params_to_prefetch:
debug_rank0(f"-prefetch: {param.ds_summary()}")
if logger.isEnabledFor(logging.DEBUG):
for param in params_to_prefetch:
debug_rank0(f"-prefetch: {param.ds_summary()}")
self.__all_gather_params(params_to_prefetch)

if self.__prefetch_nvme:
@@ -394,7 +399,8 @@ def __all_gather_params(self, params: Set[Parameter]) -> None:
@instrument_w_nvtx
def __release_param(self, param: Parameter) -> None:
if param.ds_status == ZeroParamStatus.AVAILABLE and not param.ds_active_sub_modules:
debug_rank0(f"-release: {param.ds_summary()}")
if logger.isEnabledFor(logging.DEBUG):
debug_rank0(f"-release: {param.ds_summary()}")
param.partition()
self.__n_available_params -= param.ds_numel

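The same guard is applied above to the fetch, wait, prefetch and release paths, which run for every submodule on every step. A rough way to see the saving (a sketch only; the message format and the timings are illustrative and depend on the machine):

import logging
import timeit

logger = logging.getLogger("bench")
logger.setLevel(logging.INFO)           # DEBUG disabled

summary = {"id": 123, "status": "AVAILABLE", "numel": 4096}  # stand-in for ds_summary()

def unguarded():
    logger.debug(f"-fetch: {summary}")  # message built, record discarded

def guarded():
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug(f"-fetch: {summary}")  # never built when DEBUG is off

print("unguarded:", timeit.timeit(unguarded, number=100_000))
print("guarded:  ", timeit.timeit(guarded, number=100_000))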
106 changes: 66 additions & 40 deletions deepspeed/runtime/zero/stage3.py
@@ -261,7 +261,6 @@ def __init__(self,
if self.swap_optimizer:
self._configure_tensor_swapping(offload_optimizer_config, aio_config)

self.params_in_ipg_bucket = []
self.is_gradient_accumulation_boundary: bool = True

self.param_reduce_events: Deque[get_accelerator().Event] = collections.deque()
@@ -277,7 +276,6 @@ def __init__(self,
self.grads_in_ipg_bucket = []
self.params_in_ipg_bucket = []

self.params_already_reduced = []
self.is_gradient_accumulation_boundary = True
self._release_ipg_buffers()
self.previous_reduced_grads = None
@@ -291,7 +289,6 @@ def __init__(self,
unique_id = id(param)
self.param_id[unique_id] = count
self.param_dict[count] = param
self.params_already_reduced.append(False)
count = count + 1

#Largest partitioned param
@@ -307,7 +304,6 @@ def __init__(self,

if self.offload_optimizer:
self.norm_for_param_grads = {}
self.local_overflow = False

# stores if a partition has been reduced in this step
self.is_partition_reduced = {}
@@ -397,20 +393,20 @@ def _setup_for_real_optimizer(self):
dtype=self.dtype,
device=get_accelerator().current_device_name())

grad_partitions_flat_buffer = None
self.grad_partitions_flat_buffer = None
self.__param_id_to_grad_partition: Dict[int, Tensor] = {}

all_params = list(itertools.chain.from_iterable(self.fp16_groups))

grad_partitions_flat_buffer: Tensor = torch.zeros(sum(p.partition_numel() for p in all_params),
dtype=self.dtype,
device=self.device)
self.grad_partitions_flat_buffer: Tensor = torch.zeros(sum(p.partition_numel() for p in all_params),
dtype=self.dtype,
device=self.device)
if self.offload_optimizer_pin_memory:
grad_partitions_flat_buffer = get_accelerator().pin_memory(grad_partitions_flat_buffer)
self.grad_partitions_flat_buffer = get_accelerator().pin_memory(self.grad_partitions_flat_buffer)

offset = 0
for param in all_params:
self.__param_id_to_grad_partition[param.ds_id] = grad_partitions_flat_buffer.narrow(
self.__param_id_to_grad_partition[param.ds_id] = self.grad_partitions_flat_buffer.narrow(
0, offset, param.partition_numel())
offset += param.partition_numel()

@@ -966,11 +962,6 @@ def independent_gradient_partition_epilogue(self):

self.reduce_and_partition_stream.synchronize()

# if dist.get_rank() == 0:
# logger.info("Params already reduced %s", self.params_already_reduced)
for i in range(len(self.params_already_reduced)):
self.params_already_reduced[i] = False

#in case of cpu offload, averaged gradients are already in fp32_partitioned_groups_flat.grad
#TODO: use a similar code path for both cpu_offload and non-cpu offload
if not self.offload_optimizer:
@@ -1045,18 +1036,11 @@ def reduce_independent_p_g_buckets_and_remove_grads(self, param, i):
# 0). Otherwise if the incoming param.ds_numel is large, this branch may get triggered on a
# garbage data and `self.average_tensor()` will crash because its params_to_reduce will be
# empty, while reduction_list will have that garbage data.
if self.elements_in_ipg_bucket > 0 and self.elements_in_ipg_bucket + param.ds_numel > self.reduce_bucket_size:
if self.elements_in_ipg_bucket + param.ds_numel > self.reduce_bucket_size and self.elements_in_ipg_bucket > 0:
self.report_ipg_memory_usage("In ipg_remove_grads before reduce_ipg_grads", param.ds_numel)

self.__reduce_and_partition_ipg_grads()

param_id = self.get_param_id(param)

assert self.params_already_reduced[param_id] == False, \
f"The parameter {param_id} has already been reduced. \
Gradient computed twice for this partition. \
Multiple gradient reduction is currently not supported"

self.__add_grad_to_ipg_bucket(param)

@instrument_w_nvtx
@@ -1087,8 +1071,6 @@ def __reduce_and_partition_ipg_grads(self, safe_mode: bool = False) -> None:
raise RuntimeError(f"{param.grad.numel()} != {param.ds_numel} Cannot reduce scatter "
f"gradients whose size is not same as the params")

self.params_in_ipg_bucket.sort(key=lambda p: p.ds_id)

assert len(set(p.ds_id for p in self.params_in_ipg_bucket)) == len(self.params_in_ipg_bucket)

while self.param_reduce_events and self.param_reduce_events[0].query():
@@ -1100,7 +1082,13 @@ def __reduce_and_partition_ipg_grads(self, safe_mode: bool = False) -> None:
if safe_mode:
assert_ints_same_as_other_ranks([p.ds_id for p in self.params_in_ipg_bucket])

grad_partitions = self.__avg_scatter_grads(self.params_in_ipg_bucket)
if self.contiguous_gradients and not self.reduce_scatter:
grad_bucket = self.__ipg_bucket_flat_buffer.narrow(0, 0, self.elements_in_ipg_bucket)
grad_partitions = self.__avg_scatter_contiguous_grads(grad_bucket)
else:
self.params_in_ipg_bucket.sort(key=lambda p: p.ds_id)
grad_partitions = self.__avg_scatter_grads(self.params_in_ipg_bucket)

self.partition_grads(self.params_in_ipg_bucket, grad_partitions)

self.params_in_ipg_bucket.clear()
@@ -1109,6 +1097,47 @@ def __reduce_and_partition_ipg_grads(self, safe_mode: bool = False) -> None:
event.record()
self.param_reduce_events.append(event)

@instrument_w_nvtx
def __avg_scatter_contiguous_grads(self, buffer_to_reduce: Tensor) -> List[Tensor]:
dtype = buffer_to_reduce.dtype
if self.communication_data_type == self.dtype:
buffer_to_reduce = buffer_to_reduce.to(self.communication_data_type)
if self.postscale_gradients and self.gradient_predivide_factor != 1.0:
buffer_to_reduce = buffer_to_reduce.div_(self.gradient_predivide_factor)

world_sz = dist.get_world_size(self.dp_process_group)
rank = dist.get_rank(self.dp_process_group)
buffer_to_reduce.div_(world_sz)

dist.all_reduce(buffer_to_reduce, group=self.dp_process_group)

if self.postscale_gradients and self.gradient_predivide_factor != world_sz:
buffer_to_reduce = buffer_to_reduce.mul(self.gradient_predivide_factor)

if self.communication_data_type != self.dtype:
buffer_to_reduce = buffer_to_reduce.to(self.dtype)

grad_partitions = []
grad_offset_in_buffer = 0
for param in self.params_in_ipg_bucket:
grad = param.grad
chunk_sz = math.ceil(grad.numel() / world_sz)

start_offset = grad_offset_in_buffer + min(rank * chunk_sz, grad.numel())
end_offset = grad_offset_in_buffer + min(rank * chunk_sz + chunk_sz, grad.numel())

partition = buffer_to_reduce[start_offset:end_offset]
if param.partition_numel() != partition.numel():
padded_partition = torch.empty(param.partition_numel(), device=grad.device, dtype=grad.dtype)
if partition.numel() > 0:
padded_partition[:partition.numel()] = partition
grad_partitions.append(padded_partition)
else:
grad_partitions.append(partition)
grad_offset_in_buffer += grad.numel()

return grad_partitions

@instrument_w_nvtx
def __avg_scatter_grads(self, params_to_reduce: List[Parameter]) -> List[Tensor]:
"""average gradients and scatter partitions across ranks"""
@@ -1223,15 +1252,6 @@ def partition_grads(self, params_to_release: List[Parameter], grad_partitions: L
# operations and so it can be used asynchronously
grad_buffer = cuda_grad_buffer

if hasattr(self.inf_or_nan_tracker, "logical_or_"):
self.inf_or_nan_tracker.logical_or_(torch.isinf(grad_buffer).any())
self.inf_or_nan_tracker.logical_or_(torch.isnan(grad_buffer).any())
else:
# logical_or_ not available in older versions of pytorch
self.inf_or_nan_tracker += torch.isinf(grad_buffer).any()
self.inf_or_nan_tracker += torch.isnan(grad_buffer).any()
self.inf_or_nan_tracker = self.inf_or_nan_tracker > 0

# offload the gradient partition if applicable
if self.offload_optimizer:
i, dest_offset, _ = self.grad_position[self.get_param_id(param)]
@@ -1567,7 +1587,6 @@ def free_grad_in_param_list(self, param_list):

def reset_cpu_buffers(self):
self.norm_for_param_grads = {}
self.local_overflow = False

def log_timers(self, timer_names):
if self.timers is None:
@@ -1901,12 +1920,19 @@ def has_overflow_partitioned_grads_serial(self):
def has_overflow(self, partition_gradients=True):
if partition_gradients:
with get_accelerator().stream(self.reduce_and_partition_stream):
self.local_overflow = bool(self.inf_or_nan_tracker.item())
if hasattr(self.inf_or_nan_tracker, "logical_or_"):
self.inf_or_nan_tracker.logical_or_(torch.isinf(self.grad_partitions_flat_buffer).any())
self.inf_or_nan_tracker.logical_or_(torch.isnan(self.grad_partitions_flat_buffer).any())
else:
# logical_or_ not available in older versions of pytorch
self.inf_or_nan_tracker += torch.isinf(self.grad_partitions_flat_buffer).any()
self.inf_or_nan_tracker += torch.isnan(self.grad_partitions_flat_buffer).any()
self.inf_or_nan_tracker = self.inf_or_nan_tracker > 0

overflow_gpu = self.inf_or_nan_tracker.clone().to(torch.uint8)
self.inf_or_nan_tracker.zero_()

overflow = self.local_overflow
#overflow = self.has_overflow_partitioned_grads_serial()
overflow_gpu = get_accelerator().ByteTensor([overflow])
get_accelerator().default_stream().wait_stream(self.reduce_and_partition_stream)
dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=self.dp_process_group)

else:
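A standalone sketch of the two ideas in this file's changes, in plain PyTorch with no initialized process group (tensor sizes are made up, the division only simulates the averaging that dist.all_reduce plus pre-division performs, and the padding mirrors __avg_scatter_contiguous_grads only loosely): the flat bucket is reduced once with an allreduce, each rank then keeps its ceil(numel / world_size)-sized slice of every gradient, and overflow is checked once on the flat gradient buffer instead of once per bucket.

import math
import torch

world_sz = 4
grads = [torch.randn(10), torch.randn(7)]              # stand-ins for param.grad
flat_bucket = torch.cat([g.flatten() for g in grads])  # the IPG bucket buffer

# In the real code: flat_bucket.div_(world_sz) followed by
# dist.all_reduce(flat_bucket, group=dp_process_group). No communication here.
flat_bucket = flat_bucket / world_sz

for rank in range(world_sz):
    offset = 0
    partitions = []
    for g in grads:
        chunk_sz = math.ceil(g.numel() / world_sz)
        start = offset + min(rank * chunk_sz, g.numel())
        end = offset + min(rank * chunk_sz + chunk_sz, g.numel())
        part = flat_bucket[start:end]
        if part.numel() != chunk_sz:                    # pad the ragged last chunk
            padded = torch.zeros(chunk_sz, dtype=g.dtype)
            padded[:part.numel()] = part
            part = padded
        partitions.append(part)
        offset += g.numel()
    print(f"rank {rank}: partition sizes {[p.numel() for p in partitions]}")

# Single overflow check over the whole flat buffer, mirroring has_overflow():
# one isinf/isnan pass just before optimizer.step instead of one per bucket.
tracker = torch.zeros(1, dtype=torch.bool)
tracker.logical_or_(torch.isinf(flat_bucket).any())
tracker.logical_or_(torch.isnan(flat_bucket).any())
print("overflow:", bool(tracker.item()))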
4 changes: 3 additions & 1 deletion tests/unit/runtime/zero/test_zero.py
@@ -736,10 +736,11 @@ def create_tensor(vals, dtype: torch.dtype = None) -> Tensor:


@pytest.mark.parametrize("init_context_manager", [True, False])
@pytest.mark.parametrize("reduce_scatter", [True, False])
class TestZero3ParamPartitioningLargeParam(DistributedTest):
world_size = 4

def test(self, init_context_manager: bool, param_sz: int = 8100) -> None:
def test(self, init_context_manager: bool, reduce_scatter: bool, param_sz: int = 8100) -> None:

class LargeParamModel(Module):

@@ -767,6 +768,7 @@ def forward(self, x: Tensor) -> Tensor:
"stage3_max_reuse_distance": 0,
"contiguous_gradients": True,
"overlap_comm": True,
"reduce_scatter": reduce_scatter,
},
"optimizer": {
"type": "Adam",
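For reference, a minimal ZeRO-3 config sketch showing the setting that drives the test down the new path; with "reduce_scatter": False and "contiguous_gradients": True the stage3.py changes above take the contiguous all-reduce branch, while True keeps the existing reduce-scatter path. The surrounding values are illustrative, and the fp16 block is an assumption borrowed from typical ZeRO unit tests rather than from this diff.

ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "zero_optimization": {
        "stage": 3,
        "contiguous_gradients": True,
        "overlap_comm": True,
        "reduce_scatter": False,  # False selects the new contiguous all-reduce path
    },
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
    "fp16": {"enabled": True},    # assumption, not part of the diff above
}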
