Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Various ZeRO Stage3 Optimizations + Improvements (including bfloat16 support) #1453

Merged
merged 91 commits into from
Jan 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
91 commits
Select commit Hold shift + click to select a range
fe26423
Changes for bfloat16 Zero2
raamjad Aug 14, 2021
8864f91
ZeRO stage3 optimizations, with some bug fixes
Sep 29, 2021
e66aedc
fix import in ut
Oct 12, 2021
350a7a0
ran yapf
Oct 12, 2021
b37a4f0
Merge branch 'master' into s3-pr
tjruwase Oct 13, 2021
f383947
improvements to cache flush warn log
Oct 13, 2021
b2a1c95
backwards compatibility with older versions of pytorch
Oct 14, 2021
d8678fa
handle edge case where reduced tensor smaller than world size
Oct 14, 2021
a0faca0
moved event synchronization to allgather handle wait() call
Oct 14, 2021
bf20c90
removed unnecessary barrier call
Oct 14, 2021
a353017
Merge branch 'master' into s3-pr
jfc4050 Oct 14, 2021
c51ba46
formatting fix after resolving merge conflict
Oct 14, 2021
ff01f5c
skip nvme prefetch when trace not complete
Oct 14, 2021
13093eb
opportunistically avoid memory allocation in allgather coalesced wher…
Oct 15, 2021
3cdcbdf
Merge branch 'master' into s3-pr
tjruwase Oct 20, 2021
64d74d1
Merge branch 'master' into s3-pr
tjruwase Oct 21, 2021
e30e6cc
Merge branch 'master' into s3-pr
tjruwase Oct 22, 2021
f19593d
fix indentation after merge
Oct 22, 2021
f72bc78
fixes to account for parameter offload
Oct 22, 2021
660df05
accounting for torch.cuda.memory_stats not being available
Oct 22, 2021
4f9477f
moved partition_all_params to optimizer step
Oct 22, 2021
818651c
Merge branch 'master' into s3-pr
jeffra Oct 26, 2021
f681201
Merge branch 'master' into s3-pr
jfc4050 Oct 26, 2021
bb34f90
allgathering on params before item gets called
Oct 25, 2021
9f3b504
fix param status checks
Oct 25, 2021
1772d41
fix grad accumulation with optimizer offload
Oct 25, 2021
5f213d8
grad norm computation fix for optimizer offload
Oct 26, 2021
3198805
change post divide in reduce-scatter to pre divide
Oct 26, 2021
2225659
fix gradient race condition w/ optimizer offload
Oct 26, 2021
5aa9bd5
improve inf/nan gradient tracking
Oct 26, 2021
a1a60ed
don't prefetch when not in training mode
Oct 26, 2021
df41659
format fix after merging
Oct 26, 2021
ab3a82a
fix prefetching issue when using NVME offload
Oct 27, 2021
025a41e
Merge branch 'master' into s3-pr
tjruwase Oct 29, 2021
6f9415b
Merge branch 'master' into s3-pr
jfc4050 Nov 1, 2021
8d12281
Merge branch 'master' into s3-pr
jfc4050 Nov 2, 2021
a26d1fb
improved defragmentation for fp16 parameters
Oct 31, 2021
937f04e
relative imports for bf16 tests
Nov 2, 2021
e74f509
changes for bwd compatibility with pytorch 1.2
Nov 2, 2021
6ee558d
remove buffered_reduce_fallback
Nov 2, 2021
14e22a2
removed unused parameter offset bookkeeping
Nov 3, 2021
16281df
fixed tracking for multiple param groups
Nov 3, 2021
38af6b1
Merge branch 'master' into s3-pr
tjruwase Nov 3, 2021
cc7011e
unbroke bfloat16 config after merge conflict
Nov 3, 2021
806b072
using base allgather params when only 1 param
Nov 3, 2021
bf0dd66
cleanup/fixes for fp16 partition defragmentation
Nov 3, 2021
73207ae
Merge branch 'master' into s3-pr
tjruwase Nov 5, 2021
d3ecb1f
Merge branch 'master' into s3-pr
tjruwase Nov 5, 2021
812fe67
Merge branch 'master' into s3-pr
tjruwase Nov 11, 2021
6dc21a6
switch to CRLF
jeffra Nov 18, 2021
2a38302
convert to same new-line style as master
jeffra Nov 18, 2021
16f1d21
align new line with master
jeffra Nov 18, 2021
11d590a
Merge branch 'master' into s3-pr
tjruwase Nov 23, 2021
2b5f6ea
Fix merge issues
tjruwase Nov 23, 2021
80b53d3
Merge branch 'master' into s3-pr
tjruwase Nov 24, 2021
6dfe693
Merge branch 'master' into s3-pr
tjruwase Nov 24, 2021
912e6f0
switch to CRLF
jeffra Nov 29, 2021
4b0133b
fix to LF line endings
jeffra Nov 30, 2021
b998206
minor merge fixes
jeffra Nov 30, 2021
d6deecb
remove extra bfloat16_enabled definition
Nov 30, 2021
2a4ef29
asserting params inflight for AllGatherHandle
Nov 30, 2021
90182b6
remove get_cuda_mem_allocated_str
Nov 30, 2021
ad847ed
Merge branch 'master' into s3-pr
tjruwase Dec 8, 2021
f590ba4
Format fixes
tjruwase Dec 8, 2021
9db815f
fix bfloat16 zero stage check (broken after merge commit)
Dec 8, 2021
259ec15
+self.communication_data_type, -self.allreduce_always_fp32; delete de…
tjruwase Dec 8, 2021
96d2247
Add self.reduce_scatter
tjruwase Dec 9, 2021
2630b75
Merge branch 'master' into s3-pr
tjruwase Dec 9, 2021
79fd42c
Merge branch 'master' into s3-pr
tjruwase Dec 11, 2021
8565e04
Merge branch 'master' into s3-pr
jeffra Dec 14, 2021
06eab1a
Merge branch 'master' into s3-pr
tjruwase Dec 30, 2021
0f8affe
Format fix
tjruwase Dec 30, 2021
3436422
Merge branch 'master' into s3-pr
tjruwase Dec 30, 2021
601d1f1
Fix merge issues
tjruwase Dec 30, 2021
5dcee36
Merge branch 's3-pr' of github.com:jfc4050/DeepSpeed into s3-pr
tjruwase Dec 30, 2021
580d25e
Merge branch 'master' into s3-pr
tjruwase Jan 3, 2022
872f451
Merge branch 'master' into s3-pr
jeffra Jan 7, 2022
e236293
Merge branch 'master' into s3-pr
tjruwase Jan 10, 2022
43b3b83
Merge branch 'master' into s3-pr
tjruwase Jan 11, 2022
83905ac
Merge branch 'master' into s3-pr
tjruwase Jan 12, 2022
31aecfc
iterate over params_to_fetch rather than make another iterator
Jan 12, 2022
8736700
add some TODOs
Jan 14, 2022
516379d
Merge branch 'master' into s3-pr
tjruwase Jan 14, 2022
0bf7bcd
remove unnecessary division by micro_step_id
Jan 19, 2022
43c00ff
rename config keys "bfloat16" -> "bf16"
Jan 19, 2022
4574bc7
rename stage3_gather_fp16_weights_on_model_save -> stage3_gather_16bi…
Jan 19, 2022
e04dc6a
add unit test to check backwards compatibility for gather_16bit_weights
Jan 19, 2022
391cecf
added test to confirm bf16 key bwd compatibility
Jan 19, 2022
3d26469
Merge branch 'master' into s3-pr
tjruwase Jan 19, 2022
536d171
Format fixes
tjruwase Jan 19, 2022
19f3538
Merge branch 'master' into s3-pr
tjruwase Jan 20, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DeepSpeedExamples
Submodule DeepSpeedExamples updated 50 files
+0 −300 HelloDeepSpeed/README.md
+0 −8 HelloDeepSpeed/requirements.txt
+0 −0 HelloDeepSpeed/tests/__init__.py
+0 −108 HelloDeepSpeed/tests/test_train_bert.py
+0 −791 HelloDeepSpeed/train_bert.py
+0 −805 HelloDeepSpeed/train_bert_ds.py
+0 −1 Megatron-LM-v1.1.5-ZeRO3/curriculum_learning/README.md
+3 −3 Megatron-LM-v1.1.5-ZeRO3/curriculum_learning/ds_pretrain_gpt2.sh
+7 −4 Megatron-LM-v1.1.5-ZeRO3/curriculum_learning/ds_train.sh
+0 −2 Megatron-LM-v1.1.5-ZeRO3/megatron/arguments.py
+1 −2 Megatron-LM-v1.1.5-ZeRO3/megatron/checkpointing.py
+15 −38 Megatron-LM-v1.1.5-ZeRO3/megatron/learning_rates.py
+1 −5 Megatron-LM-v1.1.5-ZeRO3/megatron/model/gpt2_model.py
+1 −1 Megatron-LM-v1.1.5-ZeRO3/megatron/training.py
+2 −2 Megatron-LM-v1.1.5-ZeRO3/pretrain_gpt2.py
+1 −1 MoQ/huggingface-transformers/tests/fixtures/tests_samples/GermEval/dev.txt
+0 −2 README.md
+0 −4 autotuning/.gitignore
+0 −3 autotuning/README.md
+0 −62 autotuning/hf/README.md
+0 −58 autotuning/hf/bert-base/README.md
+0 −12 autotuning/hf/bert-base/ds_config_tune.json
+0 −114 autotuning/hf/bert-base/test_tune.sh
+0 −55 autotuning/hf/bert-large/README.md
+0 −11 autotuning/hf/bert-large/ds_config_tune.json
+0 −114 autotuning/hf/bert-large/test_tune.sh
+0 −72 autotuning/hf/deberta/README.md
+0 −16 autotuning/hf/deberta/ds_config_fp16_tune.json
+0 −127 autotuning/hf/deberta/test_tune.sh
+0 −69 autotuning/hf/distilbert/README.md
+0 −12 autotuning/hf/distilbert/ds_config_tune.json
+0 −119 autotuning/hf/distilbert/test_tune.sh
+0 −15 autotuning/hf/dsconfigs/ds_config_fp16_tune.json
+0 −9 autotuning/hf/dsconfigs/ds_config_fp16_z0.json
+0 −9 autotuning/hf/dsconfigs/ds_config_fp16_z1.json
+0 −9 autotuning/hf/dsconfigs/ds_config_fp16_z2.json
+0 −9 autotuning/hf/dsconfigs/ds_config_fp16_z3.json
+0 −12 autotuning/hf/dsconfigs/ds_config_tune.json
+0 −6 autotuning/hf/dsconfigs/ds_config_z0.json
+0 −6 autotuning/hf/dsconfigs/ds_config_z1.json
+0 −6 autotuning/hf/dsconfigs/ds_config_z2.json
+0 −6 autotuning/hf/dsconfigs/ds_config_z3.json
+0 −59 autotuning/hf/gpt2-large/README.md
+0 −132 autotuning/hf/gpt2-large/test_tune.sh
+0 −57 autotuning/hf/gpt2-medium/README.md
+0 −142 autotuning/hf/gpt2-medium/test_tune.sh
+0 −56 autotuning/hf/gpt2-xl/README.md
+0 −142 autotuning/hf/gpt2-xl/test_tune.sh
+0 −59 autotuning/hf/gpt2/README.md
+0 −133 autotuning/hf/gpt2/test_tune.sh
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
"stage3_max_reuse_distance": 1e9,
"stage3_prefetch_bucket_size": 5e8,
"stage3_param_persistence_threshold": 1e6,
"stage3_gather_fp16_weights_on_model_save": false,
"stage3_gather_16bit_weights_on_model_save": false,
"sub_group_size": 1e12
}
}
116 changes: 116 additions & 0 deletions deepspeed/runtime/comm/coalesced_collectives.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
"""batched collective operations for overhead amortization and better
bandwidth utilization"""

import math
from typing import List

import torch
from torch import Tensor
import torch.distributed
from torch.distributed import ProcessGroup
import torch.nn.functional

from deepspeed.utils import instrument_w_nvtx
from deepspeed.utils.logging import logger

if hasattr(torch.distributed, "_reduce_scatter_base"):

def torch_reduce_scatter_fn(input_tensor: Tensor, output_tensor: Tensor, group):
instrument_w_nvtx(torch.distributed._reduce_scatter_base)(
output_tensor,
input_tensor,
group=group,
)
else:
logger.warning(
"unable to find torch.distributed._reduce_scatter_base. will fall back to "
"torch.distributed.reduce_scatter which will result in suboptimal performance. "
"please consider upgrading your pytorch installation.")

def torch_reduce_scatter_fn(input_tensor: Tensor, output_tensor: Tensor, group):
input_tensor_lst = list(
torch.chunk(input_tensor,
torch.distributed.get_world_size(group)))
instrument_w_nvtx(torch.distributed.reduce_scatter)(
output_tensor,
input_tensor_lst,
group=group,
)


@instrument_w_nvtx
@torch.no_grad()
def reduce_scatter_coalesced(
tensors: List[Tensor],
group: ProcessGroup = None,
) -> List[Tensor]:
"""simultaneously reduce-scatter a list of tensors - this can be done more
efficiently than individual reduce scatter calls

TODO. see if PyTorch team wants a c++ verson of this for ProcessGroupNCCL
"""
this_rank = torch.distributed.get_rank(group)
world_sz = torch.distributed.get_world_size(group)

partition_lst_for_each_tensor = [None] * len(tensors)
for tensor_idx, tensor in enumerate(tensors):
flattened_tensor = tensor.view(-1)
chunk_sz = math.ceil(tensor.numel() / world_sz)
tjruwase marked this conversation as resolved.
Show resolved Hide resolved
partition_lst_for_each_tensor[tensor_idx] = [
flattened_tensor[rank * chunk_sz:rank * chunk_sz + chunk_sz]
for rank in range(0,
world_sz)
]

padded_partition_sz_for_each_tensor = tuple(
math.ceil(t.numel() / world_sz) for t in tensors)

if len(tensors) == 1 and tensors[0].numel() % world_sz == 0:
# if there's only one tensor being reduced and we don't need to pad
# we have an opportunity to avoid a memory allocation
tensor_partition_flat_buffer = tensors[0].view(-1)
else:
# interleave tensor partitions such that the correct reduced partitions of each tensor
# end up at each rank
tensor_partitions_lst_with_padding = []
for rank in range(world_sz):
for tensor_idx in range(len(tensors)):
# add tensor content
tensor_chunk = partition_lst_for_each_tensor[tensor_idx][rank]
tensor_partitions_lst_with_padding.append(tensor_chunk)

# add padding if necessary
padding_sz = padded_partition_sz_for_each_tensor[
tensor_idx] - tensor_chunk.numel()
if padding_sz > 0:
tensor_partitions_lst_with_padding.append(
torch.empty(padding_sz,
dtype=tensor_chunk.dtype,
device=tensor_chunk.device))

tensor_partition_flat_buffer = instrument_w_nvtx(
torch.cat)(tensor_partitions_lst_with_padding)

tensor_partition_flat_buffer.div_(world_sz) # pre-divide
tensor_partition_buffer_for_each_rank: List[Tensor] = torch.chunk(
tensor_partition_flat_buffer,
world_sz)

# batched reduce-scatter call
torch_reduce_scatter_fn(tensor_partition_flat_buffer,
tensor_partition_buffer_for_each_rank[this_rank],
group)

# reverse procedure of the interleaving done previously, done on the
# result of the batched reduce-scatter
output_lst: List[Tensor] = [None] * len(tensors)
offset = 0
for tensor_idx in range(len(tensors)):
output_lst[tensor_idx] = tensor_partition_buffer_for_each_rank[this_rank].narrow(
0,
offset,
partition_lst_for_each_tensor[tensor_idx][this_rank].numel())

offset += padded_partition_sz_for_each_tensor[tensor_idx]

return output_lst
14 changes: 7 additions & 7 deletions deepspeed/runtime/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,12 @@ def get_fp16_enabled(param_dict):


def get_bfloat16_enabled(param_dict):
if BFLOAT16 in param_dict.keys():
return get_scalar_param(param_dict[BFLOAT16],
BFLOAT16_ENABLED,
BFLOAT16_ENABLED_DEFAULT)
else:
return False
for key in [BFLOAT16, BFLOAT16_OLD]:
if key in param_dict.keys():
return get_scalar_param(param_dict[key],
BFLOAT16_ENABLED,
BFLOAT16_ENABLED_DEFAULT)
return False


def get_fp16_master_weights_and_grads_enabled(param_dict):
Expand Down Expand Up @@ -899,7 +899,7 @@ def _initialize_params(self, param_dict):
self.fp16_enabled = get_fp16_enabled(param_dict)
self.bfloat16_enabled = get_bfloat16_enabled(param_dict)
assert not (self.fp16_enabled and self.bfloat16_enabled), 'bfloat16 and fp16 modes cannot be simultaneously enabled'
assert not (self.bfloat16_enabled and (self.zero_optimization_stage != 2)), 'bfloat16 mode is only enabled for Zero2 currently'
assert not (self.bfloat16_enabled and (self.zero_optimization_stage not in {2, 3})), f'bfloat16 mode is only enabled for Zero 2 and 3 currently. got {self.zero_optimization_stage}'
self.fp16_master_weights_and_gradients = get_fp16_master_weights_and_grads_enabled(
param_dict)
self.amp_enabled = get_amp_enabled(param_dict)
Expand Down
5 changes: 3 additions & 2 deletions deepspeed/runtime/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,12 @@
# Users can configure in ds_config.json as below example:
BFLOAT16_FORMAT = '''
BFLOAT16 parameters should be of the format:
"bfloat16": {
"bf16": {
"enabled": true
}
'''
BFLOAT16 = "bfloat16"
BFLOAT16 = "bf16"
BFLOAT16_OLD = "bfloat16" # keeping for backwards compatibility

BFLOAT16_ENABLED = "enabled"
BFLOAT16_ENABLED_DEFAULT = False
Expand Down
52 changes: 33 additions & 19 deletions deepspeed/runtime/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
import deepspeed.runtime.lr_schedules as lr_schedules
import deepspeed.utils.groups as groups
from deepspeed.runtime.utils import get_grad_norm
from deepspeed.utils import logger, log_dist, init_distributed
from deepspeed.utils import logger, log_dist, init_distributed, instrument_w_nvtx
from deepspeed.utils.timer import ThroughputTimer, SynchronizedWallClockTimer
from deepspeed.utils.debug import debug_extract_module_and_param_names
from deepspeed.runtime.progressive_layer_drop import ProgressiveLayerDrop
Expand Down Expand Up @@ -706,8 +706,8 @@ def zero_prefetch_bucket_size(self):
def zero_param_persistence_threshold(self):
return self._config.zero_config.param_persistence_threshold

def zero_gather_fp16_weights_on_model_save(self):
return self._config.zero_config.gather_fp16_weights_on_model_save
def zero_gather_16bit_weights_on_model_save(self):
return self._config.zero_config.gather_16bit_weights_on_model_save

def zero_grad_hooks(self):
return self._config.zero_config.grad_hooks
Expand Down Expand Up @@ -969,6 +969,16 @@ def is_replicated(p):
self.broadcast_src_rank,
group=self.data_parallel_group)

@staticmethod
def __check_params(model: Module, dtype: torch.dtype) -> None:
if not all(param.dtype == dtype
for param in model.parameters()) and dist.get_rank() == 0:
raise ValueError(
f"{dtype} is enabled but the following parameters have dtype that is "
f"not {dtype}: "
f"{[(n, p.dtype) for n, p in model.named_parameters() if p.dtype != dtype]}"
)

def _configure_distributed_model(self, model):
self.module = model
if self.fp16_enabled():
Expand All @@ -986,17 +996,13 @@ def _configure_distributed_model(self, model):
)
self.module.half()
elif self.bfloat16_enabled():
if self.zero_optimization_partition_weights() and any(
hasattr(param,
'ds_id') for param in self.module.parameters()):
self.__check_params(self.module, torch.bfloat16)
self.module.bfloat16()
else:
if not all(
[param.dtype == torch.float for param in self.module.parameters()]):
names = [
n for n,
p in self.module.named_parameters() if p.dtype != torch.float
]
raise ValueError(
f"fp32 is enabled but the following parameters have dtype that is not fp32: {', '.join(names)}"
)
self.__check_params(self.module, torch.float)

if not self.dont_change_device:
self.module.to(self.device)
Expand Down Expand Up @@ -1542,6 +1548,7 @@ def _scale_loss_by_gas(self, prescaled_loss):

return scaled_loss

@instrument_w_nvtx
def forward(self, *inputs, **kwargs):
r"""Execute forward propagation
Arguments:
Expand Down Expand Up @@ -1637,6 +1644,7 @@ def print_forward_breakdown(self, fwd_time):
f"rank={torch.distributed.get_rank()} time (ms) | forward: {fwd_time:.2f} (forward_moe: {moe_time:.2f}, 1st alltoall: {falltoall:.2f}, 2nd alltoall: {salltoall:.2f}, top-k: {gate_time:.2f})",
ranks=[0])

@instrument_w_nvtx
def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
# Pass (PP) gas boundary flag to optimizer (required for zero)
self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary(
Expand All @@ -1654,6 +1662,7 @@ def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
else:
self.buffered_allreduce_fallback(elements_per_buffer=bucket_size)

@instrument_w_nvtx
def backward(self, loss, allreduce_gradients=True, release_loss=False):
r"""Execute backward pass on the loss

Expand Down Expand Up @@ -3013,7 +3022,7 @@ def _save_zero_checkpoint(self, save_path, tag):
self._copy_recovery_script(save_path)
logger.info('zero checkpoint saved {}'.format(zero_checkpoint_name))

def _zero3_consolidated_fp16_state_dict(self):
def _zero3_consolidated_16bit_state_dict(self):
"""

Get a full non-partitioned state_dict with fp16 weights on cpu.
Expand Down Expand Up @@ -3082,17 +3091,22 @@ def get_layer_state_dict(module, prefix=""):
return state_dict

def save_fp16_model(self, save_dir, save_filename="pytorch_model.bin"):
r"""Save fp16 model weights
"""has been renamed to save_16bit_model, keeping this around for backwards
compatibility"""
return self.save_16bit_model(save_dir, save_filename)

def save_16bit_model(self, save_dir, save_filename="pytorch_model.bin"):
r"""Save 16bit model weights

This method saves the fp16 model weights at the desired destination.
This method saves the 16bit model weights at the desired destination.

Arguments:
save_dir: Required. Directory for saving the model
save_filename: Optional. Filename to save to. Defaults to ``pytorch_model.bin``

Returns:
``True`` when a model has been saved, ``False`` otherwise. It will not be saved if
stage3_gather_fp16_weights_on_model_save is ``False``.
stage3_gather_16bit_weights_on_model_save is ``False``.

Important: all processes must call this method and not just the process with rank 0. It is
because the processes need to work in sync to gather the weights. This method will hang
Expand All @@ -3103,13 +3117,13 @@ def save_fp16_model(self, save_dir, save_filename="pytorch_model.bin"):
path = os.path.join(save_dir, save_filename)

if self.zero_optimization_partition_weights():
if self.zero_gather_fp16_weights_on_model_save():
if self.zero_gather_16bit_weights_on_model_save():
# consolidation is expensive in time and memory and therefore isn't a default
state_dict = self._zero3_consolidated_fp16_state_dict()
state_dict = self._zero3_consolidated_16bit_state_dict()
else:
# the model will be bogus if not consolidated so don't confuse the user by saving it
logger.info(
f"Did not save the model {path} because `stage3_gather_fp16_weights_on_model_save` is False"
f"Did not save the model {path} because `stage3_gather_16bit_weights_on_model_save` is False"
)
return False
else:
Expand Down
9 changes: 9 additions & 0 deletions deepspeed/runtime/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -858,3 +858,12 @@ def call_to_str(base, *args, **kwargs):
name += ', '.join(f'{key}={repr(arg)}' for key, arg in kwargs.items())
name += ')'
return name


def get_only_unique_item(items):
item_set = set(items)
if len(item_set) != 1:
raise RuntimeError(f"expected there to be only one unique element in {items}")
unique_item, = item_set

return unique_item
16 changes: 11 additions & 5 deletions deepspeed/runtime/zero/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(self, param_dict):
self.param_persistence_threshold = None
self.max_live_parameters = None
self.max_reuse_distance = None
self.gather_fp16_weights_on_model_save = None
self.gather_16bit_weights_on_model_save = None

self.ignore_unused_parameters = None
self.round_robin_gradients = None
Expand Down Expand Up @@ -171,10 +171,16 @@ def _initialize(self, zero_config_dict):
ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD,
ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD_DEFAULT)

self.gather_fp16_weights_on_model_save = get_scalar_param(
zero_config_dict,
ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE,
ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE_DEFAULT)
# config key has been renamed to use "16bit" instead of "fp16." falling back
# to old config name in order to preserve backwards compatibility
self.gather_16bit_weights_on_model_save = ZERO_OPTIMIZATION_GATHER_16BIT_WEIGHTS_ON_MODEL_SAVE_DEFAULT
for key in [
ZERO_OPTIMIZATION_GATHER_16BIT_WEIGHTS_ON_MODEL_SAVE,
ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE
]:
if key in zero_config_dict:
self.gather_16bit_weights_on_model_save = zero_config_dict[key]
break
jfc4050 marked this conversation as resolved.
Show resolved Hide resolved

self.ignore_unused_parameters = get_scalar_param(
zero_config_dict,
Expand Down
7 changes: 4 additions & 3 deletions deepspeed/runtime/zero/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,8 @@

# gathers params for saving a model - inefficient but is required in certain situations

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could someone kindly help explain which situations were required by this 16bit parameters gathering (infeeicient) feature, given that there is zero_to_fp32.py script which can help save the parameters? Thanks.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is for those who can't be bothered with running zero_to_fp32.py and want the 16-bit model extracted on the fly - which is fine for tiny to small models but very slow for large models.

It's also the default in the HF Trainer integration of Deepspeed to make it easy for users to start and have things work transparently. But the documentation explains how to improve upon this default.
https://huggingface.co/docs/transformers/main/main_classes/deepspeed#getting-the-model-weights-out

ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE = 'stage3_gather_fp16_weights_on_model_save'
ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE_DEFAULT = False
ZERO_OPTIMIZATION_GATHER_16BIT_WEIGHTS_ON_MODEL_SAVE = 'stage3_gather_16bit_weights_on_model_save'
ZERO_OPTIMIZATION_GATHER_16BIT_WEIGHTS_ON_MODEL_SAVE_DEFAULT = False

# Now just used in stage2 complete_grad_norm_calculation_for_cpu_offload
# Enable this option to avoid:
Expand Down Expand Up @@ -164,8 +165,8 @@
ZERO_OPTIMIZATION_PREFETCH_BUCKET_SIZE_DEFAULT,
ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD:
ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD_DEFAULT,
ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE:
ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE_DEFAULT,
ZERO_OPTIMIZATION_GATHER_16BIT_WEIGHTS_ON_MODEL_SAVE:
ZERO_OPTIMIZATION_GATHER_16BIT_WEIGHTS_ON_MODEL_SAVE_DEFAULT,
ZERO_OPTIMIZATION_IGNORE_UNUSED_PARAMETERS:
ZERO_OPTIMIZATION_IGNORE_UNUSED_PARAMETERS_DEFAULT,
ZERO_OPTIMIZATION_LEGACY_STAGE1:
Expand Down
Loading