zero3 performance optimizations #3622

Merged
13 commits merged on Jun 8, 2023
4 changes: 3 additions & 1 deletion deepspeed/runtime/zero/partition_parameters.py
@@ -11,6 +11,7 @@
import functools
import itertools
from typing import List
import logging
import torch
from torch import Tensor
from deepspeed import comm as dist
@@ -898,7 +899,8 @@ def all_gather_coalesced(params: Iterable[Parameter], safe_mode: bool = False) -
# to debug correctness issues.
params = sorted(params, key=lambda p: p.ds_id)

debug_rank0(f"-allgather_coalesced: {[p.ds_id for p in params]}")
if logger.isEnabledFor(logging.DEBUG):
debug_rank0(f"-allgather_coalesced: {[p.ds_id for p in params]}")

if safe_mode:
# ensure that same list (with same ordering) of parameters are
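
The change above wraps the debug_rank0 call in a logger.isEnabledFor(logging.DEBUG) guard, so the f-string and the list comprehension inside it are never built when debug logging is disabled. A minimal standalone sketch of the effect (the logger name, parameter count, and iteration count are illustrative, not taken from the PR):

```python
import logging
import timeit

logger = logging.getLogger("sketch")
logger.setLevel(logging.INFO)  # DEBUG disabled, as in a typical training run

ds_ids = list(range(4096))  # stand-in for the ds_ids of a large all-gather bucket

def unguarded():
    # The f-string and its list comprehension are evaluated even though
    # the logger immediately discards the DEBUG record.
    logger.debug(f"-allgather_coalesced: {[i for i in ds_ids]}")

def guarded():
    # The guard skips message construction entirely when DEBUG is off.
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug(f"-allgather_coalesced: {[i for i in ds_ids]}")

print("unguarded:", timeit.timeit(unguarded, number=1000))
print("guarded:  ", timeit.timeit(guarded, number=1000))
```

Because all_gather_coalesced runs for every fetched bucket, the guard avoids repeated string construction on the hot path.
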
32 changes: 19 additions & 13 deletions deepspeed/runtime/zero/partitioned_param_coordinator.py
@@ -15,6 +15,7 @@
from deepspeed.runtime.swap_tensor.partitioned_param_swapper import PartitionedParamStatus
from deepspeed.utils.debug import debug_module2name_id, debug_param2name_id
from deepspeed.accelerator import get_accelerator
import logging


def debug_rank0(message: str) -> None:
@@ -235,25 +236,28 @@ def fetch_sub_module(self, current_submodule: Module) -> None:
2. kick off fetch for next few parameters we will need later (prefetch)
3. block on parameters in immediately required sub module
"""
debug_rank0(
f"{self.__step_id}: M{current_submodule.id}({type(current_submodule).__name__}) P{[p.ds_id for p in iter_params(current_submodule)]} "
+ str({
"avail": f"{self.__n_available_params:.1e}",
"queue_sz": f"{len(self.__param_queue or [])}",
"inflight": [p.ds_id for p in self.__inflight_param_registry],
}))
if logger.isEnabledFor(logging.DEBUG):
debug_rank0(
f"{self.__step_id}: M{current_submodule.id}({type(current_submodule).__name__}) P{[p.ds_id for p in iter_params(current_submodule)]} "
+ str({
"avail": f"{self.__n_available_params:.1e}",
"queue_sz": f"{len(self.__param_queue or [])}",
"inflight": [p.ds_id for p in self.__inflight_param_registry],
}))

params_to_fetch = frozenset(iter_params(current_submodule))

# kick off all gather for params in the immediately required submodule
for param in params_to_fetch:
debug_rank0(f"-fetch: {param.ds_summary()}")
if logger.isEnabledFor(logging.DEBUG):
for param in params_to_fetch:
debug_rank0(f"-fetch: {param.ds_summary()}")
self.__all_gather_params(params_to_fetch)

# wait for parameters in the immediately needed submodule to become available
for param in params_to_fetch:
param.ds_active_sub_modules.add(current_submodule.id)
debug_rank0(f"-wait: {param.ds_summary()}")
if logger.isEnabledFor(logging.DEBUG):
debug_rank0(f"-wait: {param.ds_summary()}")
if param in self.__inflight_param_registry:
with get_accelerator().stream(self.__allgather_stream):
while self.__ongoing_fetch_events and self.__ongoing_fetch_events[0].query():
@@ -328,8 +332,9 @@ def _is_currently_on_nvme(param):
params_to_prefetch.add(param_in_trace.param)
numel_prefetching += param_in_trace.param.ds_numel

for param in params_to_prefetch:
debug_rank0(f"-prefetch: {param.ds_summary()}")
if logger.isEnabledFor(logging.DEBUG):
for param in params_to_prefetch:
debug_rank0(f"-prefetch: {param.ds_summary()}")
self.__all_gather_params(params_to_prefetch)

if self.__prefetch_nvme:
@@ -394,7 +399,8 @@ def __all_gather_params(self, params: Set[Parameter]) -> None:
@instrument_w_nvtx
def __release_param(self, param: Parameter) -> None:
if param.ds_status == ZeroParamStatus.AVAILABLE and not param.ds_active_sub_modules:
debug_rank0(f"-release: {param.ds_summary()}")
if logger.isEnabledFor(logging.DEBUG):
debug_rank0(f"-release: {param.ds_summary()}")
param.partition()
self.__n_available_params -= param.ds_numel

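
In partitioned_param_coordinator.py the same guard also wraps entire loops of debug_rank0 calls. Lazy %-style logging would not be enough here, because the costly part is evaluating arguments such as param.ds_summary() for every parameter rather than formatting the message. A small sketch under assumed names (FakeParam and its ds_summary are stand-ins, not DeepSpeed APIs):

```python
import logging

logger = logging.getLogger("sketch")
logger.setLevel(logging.INFO)  # DEBUG disabled

class FakeParam:
    """Stand-in for a ZeRO-3 partitioned parameter."""
    summary_calls = 0

    def ds_summary(self):
        FakeParam.summary_calls += 1  # the real method builds a dict per parameter
        return {"id": 0, "numel": 1}

params = [FakeParam() for _ in range(8)]

# Lazy %-formatting defers string interpolation, but the ds_summary()
# argument is still evaluated for every parameter before debug() runs.
for p in params:
    logger.debug("-fetch: %s", p.ds_summary())
print(FakeParam.summary_calls)  # 8

# Guarding the whole loop, as the PR does, skips the loop body and the
# ds_summary() calls whenever DEBUG is disabled.
if logger.isEnabledFor(logging.DEBUG):
    for p in params:
        logger.debug("-fetch: %s", p.ds_summary())
print(FakeParam.summary_calls)  # still 8
```
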
106 changes: 66 additions & 40 deletions deepspeed/runtime/zero/stage3.py
@@ -261,7 +261,6 @@ def __init__(self,
if self.swap_optimizer:
self._configure_tensor_swapping(offload_optimizer_config, aio_config)

self.params_in_ipg_bucket = []
self.is_gradient_accumulation_boundary: bool = True

self.param_reduce_events: Deque[get_accelerator().Event] = collections.deque()
@@ -277,7 +276,6 @@ def __init__(self,
self.grads_in_ipg_bucket = []
self.params_in_ipg_bucket = []

self.params_already_reduced = []
self.is_gradient_accumulation_boundary = True
self._release_ipg_buffers()
self.previous_reduced_grads = None
@@ -291,7 +289,6 @@ def __init__(self,
unique_id = id(param)
self.param_id[unique_id] = count
self.param_dict[count] = param
self.params_already_reduced.append(False)
count = count + 1

#Largest partitioned param
@@ -307,7 +304,6 @@ def __init__(self,

if self.offload_optimizer:
self.norm_for_param_grads = {}
self.local_overflow = False

# stores if a partition has been reduced in this step
self.is_partition_reduced = {}
@@ -397,20 +393,20 @@ def _setup_for_real_optimizer(self):
dtype=self.dtype,
device=get_accelerator().current_device_name())

grad_partitions_flat_buffer = None
self.grad_partitions_flat_buffer = None
self.__param_id_to_grad_partition: Dict[int, Tensor] = {}

all_params = list(itertools.chain.from_iterable(self.fp16_groups))

grad_partitions_flat_buffer: Tensor = torch.zeros(sum(p.partition_numel() for p in all_params),
dtype=self.dtype,
device=self.device)
self.grad_partitions_flat_buffer: Tensor = torch.zeros(sum(p.partition_numel() for p in all_params),
dtype=self.dtype,
device=self.device)
if self.offload_optimizer_pin_memory:
grad_partitions_flat_buffer = get_accelerator().pin_memory(grad_partitions_flat_buffer)
self.grad_partitions_flat_buffer = get_accelerator().pin_memory(self.grad_partitions_flat_buffer)

offset = 0
for param in all_params:
self.__param_id_to_grad_partition[param.ds_id] = grad_partitions_flat_buffer.narrow(
self.__param_id_to_grad_partition[param.ds_id] = self.grad_partitions_flat_buffer.narrow(
0, offset, param.partition_numel())
offset += param.partition_numel()

@@ -966,11 +962,6 @@ def independent_gradient_partition_epilogue(self):

self.reduce_and_partition_stream.synchronize()

# if dist.get_rank() == 0:
# logger.info("Params already reduced %s", self.params_already_reduced)
for i in range(len(self.params_already_reduced)):
self.params_already_reduced[i] = False

#in case of cpu offload, averaged gradients are already in fp32_partitioned_groups_flat.grad
#TODO: use a similar code path for both cpu_offload and non-cpu offload
if not self.offload_optimizer:
@@ -1045,18 +1036,11 @@ def reduce_independent_p_g_buckets_and_remove_grads(self, param, i):
# 0). Otherwise if the incoming param.ds_numel is large, this branch may get triggered on a
# garbage data and `self.average_tensor()` will crash because its params_to_reduce will be
# empty, while reduction_list will have that garbage data.
if self.elements_in_ipg_bucket > 0 and self.elements_in_ipg_bucket + param.ds_numel > self.reduce_bucket_size:
if self.elements_in_ipg_bucket + param.ds_numel > self.reduce_bucket_size and self.elements_in_ipg_bucket > 0:
self.report_ipg_memory_usage("In ipg_remove_grads before reduce_ipg_grads", param.ds_numel)

self.__reduce_and_partition_ipg_grads()

param_id = self.get_param_id(param)

assert self.params_already_reduced[param_id] == False, \
f"The parameter {param_id} has already been reduced. \
Gradient computed twice for this partition. \
Multiple gradient reduction is currently not supported"

self.__add_grad_to_ipg_bucket(param)

@instrument_w_nvtx
@@ -1087,8 +1071,6 @@ def __reduce_and_partition_ipg_grads(self, safe_mode: bool = False) -> None:
raise RuntimeError(f"{param.grad.numel()} != {param.ds_numel} Cannot reduce scatter "
f"gradients whose size is not same as the params")

self.params_in_ipg_bucket.sort(key=lambda p: p.ds_id)

assert len(set(p.ds_id for p in self.params_in_ipg_bucket)) == len(self.params_in_ipg_bucket)

while self.param_reduce_events and self.param_reduce_events[0].query():
@@ -1100,7 +1082,13 @@ def __reduce_and_partition_ipg_grads(self, safe_mode: bool = False) -> None:
if safe_mode:
assert_ints_same_as_other_ranks([p.ds_id for p in self.params_in_ipg_bucket])

grad_partitions = self.__avg_scatter_grads(self.params_in_ipg_bucket)
if self.contiguous_gradients and not self.reduce_scatter:
grad_bucket = self.__ipg_bucket_flat_buffer.narrow(0, 0, self.elements_in_ipg_bucket)
grad_partitions = self.__avg_scatter_contiguous_grads(grad_bucket)
else:
self.params_in_ipg_bucket.sort(key=lambda p: p.ds_id)
grad_partitions = self.__avg_scatter_grads(self.params_in_ipg_bucket)

self.partition_grads(self.params_in_ipg_bucket, grad_partitions)

self.params_in_ipg_bucket.clear()
@@ -1109,6 +1097,47 @@ def __reduce_and_partition_ipg_grads(self, safe_mode: bool = False) -> None:
event.record()
self.param_reduce_events.append(event)

@instrument_w_nvtx
def __avg_scatter_contiguous_grads(self, buffer_to_reduce: Tensor) -> List[Tensor]:
dtype = buffer_to_reduce.dtype
if self.communication_data_type != self.dtype:
buffer_to_reduce = buffer_to_reduce.to(self.communication_data_type)
if self.postscale_gradients and self.gradient_predivide_factor != 1.0:
buffer_to_reduce = buffer_to_reduce.div_(self.gradient_predivide_factor)

world_sz = dist.get_world_size(self.dp_process_group)
rank = dist.get_rank(self.dp_process_group)
buffer_to_reduce.div_(world_sz)

dist.all_reduce(buffer_to_reduce, group=self.dp_process_group)

if self.postscale_gradients and self.gradient_predivide_factor != world_sz:
buffer_to_reduce = buffer_to_reduce.mul(self.gradient_predivide_factor)

if self.communication_data_type != self.dtype:
buffer_to_reduce = buffer_to_reduce.to(self.dtype)

grad_partitions = []
grad_offset_in_buffer = 0
for param in self.params_in_ipg_bucket:
grad = param.grad
chunk_sz = math.ceil(grad.numel() / world_sz)

start_offset = grad_offset_in_buffer + min(rank * chunk_sz, grad.numel())
end_offset = grad_offset_in_buffer + min(rank * chunk_sz + chunk_sz, grad.numel())

partition = buffer_to_reduce[start_offset:end_offset]
if param.partition_numel() != partition.numel():
padded_partition = torch.empty(param.partition_numel(), device=grad.device, dtype=grad.dtype)
if partition.numel() > 0:
padded_partition[:partition.numel()] = partition
grad_partitions.append(padded_partition)
else:
grad_partitions.append(partition)
grad_offset_in_buffer += grad.numel()

return grad_partitions

@instrument_w_nvtx
def __avg_scatter_grads(self, params_to_reduce: List[Parameter]) -> List[Tensor]:
"""average gradients and scatter partitions across ranks"""
@@ -1222,15 +1251,6 @@ def partition_grads(self, params_to_release: List[Parameter], grad_partitions: L
# operations and so it can be used asynchronously
grad_buffer = cuda_grad_buffer

if hasattr(self.inf_or_nan_tracker, "logical_or_"):
self.inf_or_nan_tracker.logical_or_(torch.isinf(grad_buffer).any())
self.inf_or_nan_tracker.logical_or_(torch.isnan(grad_buffer).any())
else:
# logical_or_ not available in older versions of pytorch
self.inf_or_nan_tracker += torch.isinf(grad_buffer).any()
self.inf_or_nan_tracker += torch.isnan(grad_buffer).any()
self.inf_or_nan_tracker = self.inf_or_nan_tracker > 0

# offload the gradient partition if applicable
if self.offload_optimizer:
i, dest_offset, _ = self.grad_position[self.get_param_id(param)]
@@ -1561,7 +1581,6 @@ def free_grad_in_param_list(self, param_list):

def reset_cpu_buffers(self):
self.norm_for_param_grads = {}
self.local_overflow = False

def log_timers(self, timer_names):
if self.timers is None:
@@ -1895,12 +1914,19 @@ def has_overflow_partitioned_grads_serial(self):
def has_overflow(self, partition_gradients=True):
if partition_gradients:
with get_accelerator().stream(self.reduce_and_partition_stream):
self.local_overflow = bool(self.inf_or_nan_tracker.item())
if hasattr(self.inf_or_nan_tracker, "logical_or_"):
self.inf_or_nan_tracker.logical_or_(torch.isinf(self.grad_partitions_flat_buffer).any())
self.inf_or_nan_tracker.logical_or_(torch.isnan(self.grad_partitions_flat_buffer).any())
else:
# logical_or_ not available in older versions of pytorch
self.inf_or_nan_tracker += torch.isinf(self.grad_partitions_flat_buffer).any()
self.inf_or_nan_tracker += torch.isnan(self.grad_partitions_flat_buffer).any()
self.inf_or_nan_tracker = self.inf_or_nan_tracker > 0

overflow_gpu = self.inf_or_nan_tracker.clone().to(torch.uint8)
self.inf_or_nan_tracker.zero_()

overflow = self.local_overflow
#overflow = self.has_overflow_partitioned_grads_serial()
overflow_gpu = get_accelerator().ByteTensor([overflow])
get_accelerator().default_stream().wait_stream(self.reduce_and_partition_stream)
dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=self.dp_process_group)

else:
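
The new __avg_scatter_contiguous_grads path (taken when contiguous_gradients is on and reduce_scatter is off) all-reduces the flat IPG buffer once and then has every rank slice its own partition of each gradient out of that buffer locally. The offset arithmetic can be illustrated with a toy, single-process sketch; the sizes, the rank value, and the use of ceil(numel / world_size) as a stand-in for partition_numel() are assumptions, and torch.distributed is left out:

```python
import math
import torch

world_sz, rank = 4, 2
grad_numels = [10, 7, 16]  # per-parameter gradient sizes in the bucket

# Stand-in for the already all-reduced flat buffer: gradients laid out back to back.
flat = torch.arange(sum(grad_numels), dtype=torch.float32)

partitions = []
offset = 0
for numel in grad_numels:
    chunk_sz = math.ceil(numel / world_sz)
    # Rank r owns chunk r of each gradient; min() clamps the final, short chunk.
    start = offset + min(rank * chunk_sz, numel)
    end = offset + min(rank * chunk_sz + chunk_sz, numel)
    partition = flat[start:end]
    if partition.numel() != chunk_sz:
        # Pad short (or empty) tail chunks, analogous to the padded_partition
        # branch in the PR, so every rank holds an equally sized partition.
        padded = torch.zeros(chunk_sz, dtype=flat.dtype)
        padded[:partition.numel()] = partition
        partition = padded
    partitions.append(partition)
    offset += numel

print([p.tolist() for p in partitions])
```

The same file also moves the inf/NaN tracking out of partition_grads and into has_overflow, where the whole grad_partitions_flat_buffer is scanned once per step rather than once per gradient buffer.
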
4 changes: 3 additions & 1 deletion tests/unit/runtime/zero/test_zero.py
@@ -736,10 +736,11 @@ def create_tensor(vals, dtype: torch.dtype = None) -> Tensor:


@pytest.mark.parametrize("init_context_manager", [True, False])
@pytest.mark.parametrize("reduce_scatter", [True, False])
class TestZero3ParamPartitioningLargeParam(DistributedTest):
world_size = 4

def test(self, init_context_manager: bool, param_sz: int = 8100) -> None:
def test(self, init_context_manager: bool, reduce_scatter: bool, param_sz: int = 8100) -> None:

class LargeParamModel(Module):

@@ -767,6 +768,7 @@ def forward(self, x: Tensor) -> Tensor:
"stage3_max_reuse_distance": 0,
"contiguous_gradients": True,
"overlap_comm": True,
"reduce_scatter": reduce_scatter,
},
"optimizer": {
"type": "Adam",