Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Various ZeRO Stage3 Optimizations + Improvements (including bfloat16 support) #1453

Merged
merged 91 commits into from
Jan 21, 2022
Merged
Show file tree
Hide file tree
Changes from 52 commits
Commits
Show all changes
91 commits
Select commit Hold shift + click to select a range
fe26423
Changes for bfloat16 Zero2
raamjad Aug 14, 2021
8864f91
ZeRO stage3 optimizations, with some bug fixes
Sep 29, 2021
e66aedc
fix import in ut
Oct 12, 2021
350a7a0
ran yapf
Oct 12, 2021
b37a4f0
Merge branch 'master' into s3-pr
tjruwase Oct 13, 2021
f383947
improvements to cache flush warn log
Oct 13, 2021
b2a1c95
backwards compatibility with older versions of pytorch
Oct 14, 2021
d8678fa
handle edge case where reduced tensor smaller than world size
Oct 14, 2021
a0faca0
moved event synchronization to allgather handle wait() call
Oct 14, 2021
bf20c90
removed unnecessary barrier call
Oct 14, 2021
a353017
Merge branch 'master' into s3-pr
jfc4050 Oct 14, 2021
c51ba46
formatting fix after resolving merge conflict
Oct 14, 2021
ff01f5c
skip nvme prefetch when trace not complete
Oct 14, 2021
13093eb
opportunistically avoid memory allocation in allgather coalesced wher…
Oct 15, 2021
3cdcbdf
Merge branch 'master' into s3-pr
tjruwase Oct 20, 2021
64d74d1
Merge branch 'master' into s3-pr
tjruwase Oct 21, 2021
e30e6cc
Merge branch 'master' into s3-pr
tjruwase Oct 22, 2021
f19593d
fix indentation after merge
Oct 22, 2021
f72bc78
fixes to account for parameter offload
Oct 22, 2021
660df05
accounting for torch.cuda.memory_stats not being available
Oct 22, 2021
4f9477f
moved partition_all_params to optimizer step
Oct 22, 2021
818651c
Merge branch 'master' into s3-pr
jeffra Oct 26, 2021
f681201
Merge branch 'master' into s3-pr
jfc4050 Oct 26, 2021
bb34f90
allgathering on params before item gets called
Oct 25, 2021
9f3b504
fix param status checks
Oct 25, 2021
1772d41
fix grad accumulation with optimizer offload
Oct 25, 2021
5f213d8
grad norm computation fix for optimizer offload
Oct 26, 2021
3198805
change post divide in reduce-scatter to pre divide
Oct 26, 2021
2225659
fix gradient race condition w/ optimizer offload
Oct 26, 2021
5aa9bd5
improve inf/nan gradient tracking
Oct 26, 2021
a1a60ed
don't prefetch when not in training mode
Oct 26, 2021
df41659
format fix after merging
Oct 26, 2021
ab3a82a
fix prefetching issue when using NVME offload
Oct 27, 2021
025a41e
Merge branch 'master' into s3-pr
tjruwase Oct 29, 2021
6f9415b
Merge branch 'master' into s3-pr
jfc4050 Nov 1, 2021
8d12281
Merge branch 'master' into s3-pr
jfc4050 Nov 2, 2021
a26d1fb
improved defragmentation for fp16 parameters
Oct 31, 2021
937f04e
relative imports for bf16 tests
Nov 2, 2021
e74f509
changes for bwd compatibility with pytorch 1.2
Nov 2, 2021
6ee558d
remove buffered_reduce_fallback
Nov 2, 2021
14e22a2
removed unused parameter offset bookkeeping
Nov 3, 2021
16281df
fixed tracking for multiple param groups
Nov 3, 2021
38af6b1
Merge branch 'master' into s3-pr
tjruwase Nov 3, 2021
cc7011e
unbroke bfloat16 config after merge conflict
Nov 3, 2021
806b072
using base allgather params when only 1 param
Nov 3, 2021
bf0dd66
cleanup/fixes for fp16 partition defragmentation
Nov 3, 2021
73207ae
Merge branch 'master' into s3-pr
tjruwase Nov 5, 2021
d3ecb1f
Merge branch 'master' into s3-pr
tjruwase Nov 5, 2021
812fe67
Merge branch 'master' into s3-pr
tjruwase Nov 11, 2021
6dc21a6
switch to CRLF
jeffra Nov 18, 2021
2a38302
convert to same new-line style as master
jeffra Nov 18, 2021
16f1d21
align new line with master
jeffra Nov 18, 2021
11d590a
Merge branch 'master' into s3-pr
tjruwase Nov 23, 2021
2b5f6ea
Fix merge issues
tjruwase Nov 23, 2021
80b53d3
Merge branch 'master' into s3-pr
tjruwase Nov 24, 2021
6dfe693
Merge branch 'master' into s3-pr
tjruwase Nov 24, 2021
912e6f0
switch to CRLF
jeffra Nov 29, 2021
4b0133b
fix to LF line endings
jeffra Nov 30, 2021
b998206
minor merge fixes
jeffra Nov 30, 2021
d6deecb
remove extra bfloat16_enabled definition
Nov 30, 2021
2a4ef29
asserting params inflight for AllGatherHandle
Nov 30, 2021
90182b6
remove get_cuda_mem_allocated_str
Nov 30, 2021
ad847ed
Merge branch 'master' into s3-pr
tjruwase Dec 8, 2021
f590ba4
Format fixes
tjruwase Dec 8, 2021
9db815f
fix bfloat16 zero stage check (broken after merge commit)
Dec 8, 2021
259ec15
+self.communication_data_type, -self.allreduce_always_fp32; delete de…
tjruwase Dec 8, 2021
96d2247
Add self.reduce_scatter
tjruwase Dec 9, 2021
2630b75
Merge branch 'master' into s3-pr
tjruwase Dec 9, 2021
79fd42c
Merge branch 'master' into s3-pr
tjruwase Dec 11, 2021
8565e04
Merge branch 'master' into s3-pr
jeffra Dec 14, 2021
06eab1a
Merge branch 'master' into s3-pr
tjruwase Dec 30, 2021
0f8affe
Format fix
tjruwase Dec 30, 2021
3436422
Merge branch 'master' into s3-pr
tjruwase Dec 30, 2021
601d1f1
Fix merge issues
tjruwase Dec 30, 2021
5dcee36
Merge branch 's3-pr' of github.com:jfc4050/DeepSpeed into s3-pr
tjruwase Dec 30, 2021
580d25e
Merge branch 'master' into s3-pr
tjruwase Jan 3, 2022
872f451
Merge branch 'master' into s3-pr
jeffra Jan 7, 2022
e236293
Merge branch 'master' into s3-pr
tjruwase Jan 10, 2022
43b3b83
Merge branch 'master' into s3-pr
tjruwase Jan 11, 2022
83905ac
Merge branch 'master' into s3-pr
tjruwase Jan 12, 2022
31aecfc
iterate over params_to_fetch rather than make another iterator
Jan 12, 2022
8736700
add some TODOs
Jan 14, 2022
516379d
Merge branch 'master' into s3-pr
tjruwase Jan 14, 2022
0bf7bcd
remove unnecessary division by micro_step_id
Jan 19, 2022
43c00ff
rename config keys "bfloat16" -> "bf16"
Jan 19, 2022
4574bc7
rename stage3_gather_fp16_weights_on_model_save -> stage3_gather_16bi…
Jan 19, 2022
e04dc6a
add unit test to check backwards compatibility for gather_16bit_weights
Jan 19, 2022
391cecf
added test to confirm bf16 key bwd compatibility
Jan 19, 2022
3d26469
Merge branch 'master' into s3-pr
tjruwase Jan 19, 2022
536d171
Format fixes
tjruwase Jan 19, 2022
19f3538
Merge branch 'master' into s3-pr
tjruwase Jan 20, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 116 additions & 0 deletions deepspeed/runtime/comm/coalesced_collectives.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
"""batched collective operations for overhead amortization and better
bandwidth utilization"""

import math
from typing import List

import torch
from torch import Tensor
import torch.distributed
from torch.distributed import ProcessGroup
import torch.nn.functional

from deepspeed.utils import instrument_w_nvtx
from deepspeed.utils.logging import logger

if hasattr(torch.distributed, "_reduce_scatter_base"):

def torch_reduce_scatter_fn(input_tensor: Tensor, output_tensor: Tensor, group):
instrument_w_nvtx(torch.distributed._reduce_scatter_base)(
output_tensor,
input_tensor,
group=group,
)
else:
logger.warning(
"unable to find torch.distributed._reduce_scatter_base. will fall back to "
"torch.distributed.reduce_scatter which will result in suboptimal performance. "
"please consider upgrading your pytorch installation.")

def torch_reduce_scatter_fn(input_tensor: Tensor, output_tensor: Tensor, group):
input_tensor_lst = list(
torch.chunk(input_tensor,
torch.distributed.get_world_size(group)))
instrument_w_nvtx(torch.distributed.reduce_scatter)(
output_tensor,
input_tensor_lst,
group=group,
)


@instrument_w_nvtx
@torch.no_grad()
def reduce_scatter_coalesced(
tensors: List[Tensor],
group: ProcessGroup = None,
) -> List[Tensor]:
"""simultaneously reduce-scatter a list of tensors - this can be done more
efficiently than individual reduce scatter calls

TODO. see if PyTorch team wants a c++ verson of this for ProcessGroupNCCL
"""
this_rank = torch.distributed.get_rank(group)
world_sz = torch.distributed.get_world_size(group)

partition_lst_for_each_tensor = [None] * len(tensors)
for tensor_idx, tensor in enumerate(tensors):
flattened_tensor = tensor.view(-1)
chunk_sz = math.ceil(tensor.numel() / world_sz)
tjruwase marked this conversation as resolved.
Show resolved Hide resolved
partition_lst_for_each_tensor[tensor_idx] = [
flattened_tensor[rank * chunk_sz:rank * chunk_sz + chunk_sz]
for rank in range(0,
world_sz)
]

padded_partition_sz_for_each_tensor = tuple(
math.ceil(t.numel() / world_sz) for t in tensors)

if len(tensors) == 1 and tensors[0].numel() % world_sz == 0:
# if there's only one tensor being reduced and we don't need to pad
# we have an opportunity to avoid a memory allocation
tensor_partition_flat_buffer = tensors[0].view(-1)
else:
# interleave tensor partitions such that the correct reduced partitions of each tensor
# end up at each rank
tensor_partitions_lst_with_padding = []
for rank in range(world_sz):
for tensor_idx in range(len(tensors)):
# add tensor content
tensor_chunk = partition_lst_for_each_tensor[tensor_idx][rank]
tensor_partitions_lst_with_padding.append(tensor_chunk)

# add padding if necessary
padding_sz = padded_partition_sz_for_each_tensor[
tensor_idx] - tensor_chunk.numel()
if padding_sz > 0:
tensor_partitions_lst_with_padding.append(
torch.empty(padding_sz,
dtype=tensor_chunk.dtype,
device=tensor_chunk.device))

tensor_partition_flat_buffer = instrument_w_nvtx(
torch.cat)(tensor_partitions_lst_with_padding)

tensor_partition_flat_buffer.div_(world_sz) # pre-divide
tensor_partition_buffer_for_each_rank: List[Tensor] = torch.chunk(
tensor_partition_flat_buffer,
world_sz)

# batched reduce-scatter call
torch_reduce_scatter_fn(tensor_partition_flat_buffer,
tensor_partition_buffer_for_each_rank[this_rank],
group)

# reverse procedure of the interleaving done previously, done on the
# result of the batched reduce-scatter
output_lst: List[Tensor] = [None] * len(tensors)
offset = 0
for tensor_idx in range(len(tensors)):
output_lst[tensor_idx] = tensor_partition_buffer_for_each_rank[this_rank].narrow(
0,
offset,
partition_lst_for_each_tensor[tensor_idx][this_rank].numel())

offset += padded_partition_sz_for_each_tensor[tensor_idx]

return output_lst
Loading