Various ZeRO Stage3 Optimizations + Improvements (including bfloat16 support) (#1453)

* Changes for bfloat16 Zero2
* ZeRO stage3 optimizations, with some bug fixes

  Optimizations for stage3:
  - prefetching improvements
  - batching allgather calls to amortize fixed overhead and improve bandwidth utilization
  - batching reduce_scatter calls to amortize fixed overhead and improve bandwidth utilization
  - using *_base variants of allgather and reduce-scatter to reduce memory allocations and data movement
  - more fine-grained synchronization for communication, which allows blocking on less work
  - precomputation of fetching code: using a fetch queue rather than deciding what to (pre)fetch at each iteration
  - limiting queued coalesced communication ops to reduce memory pressure on the PyTorch CUDA caching allocator (not an elegant solution)

  Optimizations for stage3-offload:
  - made some host-device tensor copies async to improve performance

  Bug fixes and QoL improvements:
  - fix init context method when parent modules modify child weights
  - speed up model initialization by moving the model to GPU before weight initialization
  - fixed unit test imports so that unit tests can be run from any directory
  - changed performance logging to include memory consumption
  - added logging with model size when done partitioning the model

  New features:
  - bfloat16 support for ZeRO 3

* fix import in unit tests
* ran yapf
* improvements to cache flush warning log
* backwards compatibility with older versions of PyTorch
* handle edge case where the reduced tensor is smaller than world size
* moved event synchronization to allgather handle wait() call
* removed unnecessary barrier call
* formatting fix after resolving merge conflict
* skip NVMe prefetch when trace not complete
* opportunistically avoid memory allocation in allgather coalesced where possible
* fix indentation after merge
* fixes to account for parameter offload
* account for torch.cuda.memory_stats not being available
* moved partition_all_params to optimizer step
* allgather on params before item() gets called
* fix param status checks needed after moving partition_all_parameters call to optimizer step
* fix gradient accumulation with optimizer offload
* gradient norm computation fix for optimizer offload
* change post-divide in reduce-scatter to pre-divide
* fix gradient race condition with optimizer offload
* improve inf/nan gradient tracking
* don't prefetch when not in training mode
* format fix after merging
* fix prefetching issue when using NVMe offload
* improved defragmentation for fp16 parameters
* relative imports for bf16 tests
* changes for backwards compatibility with PyTorch 1.2
* remove buffered_reduce_fallback
* removed unused parameter offset bookkeeping
* fixed tracking for multiple param groups
* unbroke bfloat16 config after merge conflict
* use base allgather params when there is only 1 param
* cleanup/fixes for fp16 partition defragmentation
* switch to CRLF
* convert to same newline style as master
* align newline style with master
* fix merge issues
* switch to CRLF
* fix to LF line endings
* minor merge fixes
* remove extra bfloat16_enabled definition
* assert params inflight for AllGatherHandle
* remove get_cuda_mem_allocated_str
* format fixes
* fix bfloat16 zero stage check (broken after merge commit)
* add self.communication_data_type, remove self.allreduce_always_fp32; delete dead code
* add self.reduce_scatter
* format fix
* fix merge issues
* iterate over params_to_fetch rather than making another iterator
* add some TODOs
* remove unnecessary division by micro_step_id
* rename config keys "bfloat16" -> "bf16"
* rename stage3_gather_fp16_weights_on_model_save -> stage3_gather_16bit_weights_on_model_save
* add unit test to check backwards compatibility for gather_16bit_weights
* added test to confirm bf16 key backwards compatibility
* format fixes

Co-authored-by: Rana Ali Amjad <raamjad@amazon.com>
Co-authored-by: Justin Chiu <justchiu@amazon.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
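Several of the commit's headline items batch many small allgather/reduce-scatter calls into one large coalesced call. As a rough illustration of why that helps, here is a hypothetical alpha-beta cost model; the latency and bandwidth numbers below are made-up assumptions for illustration, not measured DeepSpeed figures:

```python
# Hypothetical alpha-beta cost model: total time = (number of calls) * alpha
# + (bytes moved) * beta. Batching amortizes the fixed per-call cost alpha.
alpha = 20e-6      # assumed per-collective launch/latency overhead, in seconds
beta = 1 / 50e9    # assumed seconds per byte (~50 GB/s effective bandwidth)

def collective_time(nbytes: int, ncalls: int) -> float:
    """Estimated wall time to move nbytes using ncalls collective calls."""
    return ncalls * alpha + nbytes * beta

total_bytes = 64 * 1024 * 1024  # e.g. 64 MiB of gradients

separate = collective_time(total_bytes, ncalls=1024)  # 1024 small calls
batched = collective_time(total_bytes, ncalls=1)      # one coalesced call
# `separate` pays ~20 ms of pure launch latency that `batched` avoids
```

Under these assumed numbers the single batched call is dominated by bandwidth, while the 1024 separate calls spend most of their time in fixed per-call overhead; the same reasoning motivates the `*_base` collective variants, which further cut per-call allocations.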
1 parent 2d51f61 · commit 4912e0a · Showing 53 changed files with 2,295 additions and 1,057 deletions.
Submodule DeepSpeedExamples updated: 50 files
@@ -0,0 +1,116 @@
```python
"""batched collective operations for overhead amortization and better
bandwidth utilization"""

import math
from typing import List

import torch
from torch import Tensor
import torch.distributed
from torch.distributed import ProcessGroup
import torch.nn.functional

from deepspeed.utils import instrument_w_nvtx
from deepspeed.utils.logging import logger

if hasattr(torch.distributed, "_reduce_scatter_base"):

    def torch_reduce_scatter_fn(input_tensor: Tensor, output_tensor: Tensor, group):
        instrument_w_nvtx(torch.distributed._reduce_scatter_base)(
            output_tensor,
            input_tensor,
            group=group,
        )
else:
    logger.warning(
        "unable to find torch.distributed._reduce_scatter_base. will fall back to "
        "torch.distributed.reduce_scatter which will result in suboptimal performance. "
        "please consider upgrading your pytorch installation.")

    def torch_reduce_scatter_fn(input_tensor: Tensor, output_tensor: Tensor, group):
        input_tensor_lst = list(
            torch.chunk(input_tensor,
                        torch.distributed.get_world_size(group)))
        instrument_w_nvtx(torch.distributed.reduce_scatter)(
            output_tensor,
            input_tensor_lst,
            group=group,
        )


@instrument_w_nvtx
@torch.no_grad()
def reduce_scatter_coalesced(
    tensors: List[Tensor],
    group: ProcessGroup = None,
) -> List[Tensor]:
    """simultaneously reduce-scatter a list of tensors - this can be done more
    efficiently than individual reduce-scatter calls

    TODO. see if PyTorch team wants a c++ version of this for ProcessGroupNCCL
    """
    this_rank = torch.distributed.get_rank(group)
    world_sz = torch.distributed.get_world_size(group)

    partition_lst_for_each_tensor = [None] * len(tensors)
    for tensor_idx, tensor in enumerate(tensors):
        flattened_tensor = tensor.view(-1)
        chunk_sz = math.ceil(tensor.numel() / world_sz)
        partition_lst_for_each_tensor[tensor_idx] = [
            flattened_tensor[rank * chunk_sz:rank * chunk_sz + chunk_sz]
            for rank in range(0, world_sz)
        ]

    padded_partition_sz_for_each_tensor = tuple(
        math.ceil(t.numel() / world_sz) for t in tensors)

    if len(tensors) == 1 and tensors[0].numel() % world_sz == 0:
        # if there's only one tensor being reduced and we don't need to pad
        # we have an opportunity to avoid a memory allocation
        tensor_partition_flat_buffer = tensors[0].view(-1)
    else:
        # interleave tensor partitions such that the correct reduced
        # partitions of each tensor end up at each rank
        tensor_partitions_lst_with_padding = []
        for rank in range(world_sz):
            for tensor_idx in range(len(tensors)):
                # add tensor content
                tensor_chunk = partition_lst_for_each_tensor[tensor_idx][rank]
                tensor_partitions_lst_with_padding.append(tensor_chunk)

                # add padding if necessary
                padding_sz = padded_partition_sz_for_each_tensor[
                    tensor_idx] - tensor_chunk.numel()
                if padding_sz > 0:
                    tensor_partitions_lst_with_padding.append(
                        torch.empty(padding_sz,
                                    dtype=tensor_chunk.dtype,
                                    device=tensor_chunk.device))

        tensor_partition_flat_buffer = instrument_w_nvtx(
            torch.cat)(tensor_partitions_lst_with_padding)

    tensor_partition_flat_buffer.div_(world_sz)  # pre-divide
    tensor_partition_buffer_for_each_rank: List[Tensor] = torch.chunk(
        tensor_partition_flat_buffer,
        world_sz)

    # batched reduce-scatter call
    torch_reduce_scatter_fn(tensor_partition_flat_buffer,
                            tensor_partition_buffer_for_each_rank[this_rank],
                            group)

    # reverse procedure of the interleaving done previously, done on the
    # result of the batched reduce-scatter
    output_lst: List[Tensor] = [None] * len(tensors)
    offset = 0
    for tensor_idx in range(len(tensors)):
        output_lst[tensor_idx] = tensor_partition_buffer_for_each_rank[this_rank].narrow(
            0,
            offset,
            partition_lst_for_each_tensor[tensor_idx][this_rank].numel())

        offset += padded_partition_sz_for_each_tensor[tensor_idx]

    return output_lst
```
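The interleaving that `reduce_scatter_coalesced` performs can be hard to visualize. Below is a hypothetical single-process sketch (plain Python lists standing in for tensors; no torch or distributed setup) of just the layout logic: after interleaving, chunk r of the flat buffer holds rank r's zero-padded partition of every input tensor, which is what lets a single reduce-scatter call serve all tensors at once:

```python
import math

def interleave(tensors, world_sz):
    """Build the flat, interleaved, zero-padded buffer layout used by
    reduce_scatter_coalesced (list-based stand-in for torch.cat)."""
    flat = []
    for rank in range(world_sz):
        for t in tensors:
            chunk_sz = math.ceil(len(t) / world_sz)
            part = t[rank * chunk_sz:(rank + 1) * chunk_sz]
            flat.extend(part)
            flat.extend([0] * (chunk_sz - len(part)))  # pad a short tail partition
    return flat

tensors = [[1, 2, 3], [4, 5, 6, 7]]
flat = interleave(tensors, world_sz=2)
# flat == [1, 2, 4, 5, 3, 0, 6, 7]:
# rank 0's chunk [1, 2, 4, 5] holds partition 0 of both tensors,
# rank 1's chunk [3, 0, 6, 7] holds the (padded) partition 1 of both.
```

The final loop of the real function simply inverts this layout within one rank's chunk, using the padded partition sizes as offsets while returning views of the unpadded lengths.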