
Various ZeRO Stage3 Optimizations + Improvements (including bfloat16 support) #1453

Merged: 91 commits, merged on Jan 21, 2022

Commits (91)
fe26423
Changes for bfloat16 Zero2
raamjad Aug 14, 2021
8864f91
ZeRO stage3 optimizations, with some bug fixes
Sep 29, 2021
e66aedc
fix import in ut
Oct 12, 2021
350a7a0
ran yapf
Oct 12, 2021
b37a4f0
Merge branch 'master' into s3-pr
tjruwase Oct 13, 2021
f383947
improvements to cache flush warn log
Oct 13, 2021
b2a1c95
backwards compatibility with older versions of pytorch
Oct 14, 2021
d8678fa
handle edge case where reduced tensor smaller than world size
Oct 14, 2021
a0faca0
moved event synchronization to allgather handle wait() call
Oct 14, 2021
bf20c90
removed unnecessary barrier call
Oct 14, 2021
a353017
Merge branch 'master' into s3-pr
jfc4050 Oct 14, 2021
c51ba46
formatting fix after resolving merge conflict
Oct 14, 2021
ff01f5c
skip nvme prefetch when trace not complete
Oct 14, 2021
13093eb
opportunistically avoid memory allocation in allgather coalesced wher…
Oct 15, 2021
3cdcbdf
Merge branch 'master' into s3-pr
tjruwase Oct 20, 2021
64d74d1
Merge branch 'master' into s3-pr
tjruwase Oct 21, 2021
e30e6cc
Merge branch 'master' into s3-pr
tjruwase Oct 22, 2021
f19593d
fix indentation after merge
Oct 22, 2021
f72bc78
fixes to account for parameter offload
Oct 22, 2021
660df05
accounting for torch.cuda.memory_stats not being available
Oct 22, 2021
4f9477f
moved partition_all_params to optimizer step
Oct 22, 2021
818651c
Merge branch 'master' into s3-pr
jeffra Oct 26, 2021
f681201
Merge branch 'master' into s3-pr
jfc4050 Oct 26, 2021
bb34f90
allgathering on params before item gets called
Oct 25, 2021
9f3b504
fix param status checks
Oct 25, 2021
1772d41
fix grad accumulation with optimizer offload
Oct 25, 2021
5f213d8
grad norm computation fix for optimizer offload
Oct 26, 2021
3198805
change post divide in reduce-scatter to pre divide
Oct 26, 2021
2225659
fix gradient race condition w/ optimizer offload
Oct 26, 2021
5aa9bd5
improve inf/nan gradient tracking
Oct 26, 2021
a1a60ed
don't prefetch when not in training mode
Oct 26, 2021
df41659
format fix after merging
Oct 26, 2021
ab3a82a
fix prefetching issue when using NVME offload
Oct 27, 2021
025a41e
Merge branch 'master' into s3-pr
tjruwase Oct 29, 2021
6f9415b
Merge branch 'master' into s3-pr
jfc4050 Nov 1, 2021
8d12281
Merge branch 'master' into s3-pr
jfc4050 Nov 2, 2021
a26d1fb
improved defragmentation for fp16 parameters
Oct 31, 2021
937f04e
relative imports for bf16 tests
Nov 2, 2021
e74f509
changes for bwd compatibility with pytorch 1.2
Nov 2, 2021
6ee558d
remove buffered_reduce_fallback
Nov 2, 2021
14e22a2
removed unused parameter offset bookkeeping
Nov 3, 2021
16281df
fixed tracking for multiple param groups
Nov 3, 2021
38af6b1
Merge branch 'master' into s3-pr
tjruwase Nov 3, 2021
cc7011e
unbroke bfloat16 config after merge conflict
Nov 3, 2021
806b072
using base allgather params when only 1 param
Nov 3, 2021
bf0dd66
cleanup/fixes for fp16 partition defragmentation
Nov 3, 2021
73207ae
Merge branch 'master' into s3-pr
tjruwase Nov 5, 2021
d3ecb1f
Merge branch 'master' into s3-pr
tjruwase Nov 5, 2021
812fe67
Merge branch 'master' into s3-pr
tjruwase Nov 11, 2021
6dc21a6
switch to CRLF
jeffra Nov 18, 2021
2a38302
convert to same new-line style as master
jeffra Nov 18, 2021
16f1d21
align new line with master
jeffra Nov 18, 2021
11d590a
Merge branch 'master' into s3-pr
tjruwase Nov 23, 2021
2b5f6ea
Fix merge issues
tjruwase Nov 23, 2021
80b53d3
Merge branch 'master' into s3-pr
tjruwase Nov 24, 2021
6dfe693
Merge branch 'master' into s3-pr
tjruwase Nov 24, 2021
912e6f0
switch to CRLF
jeffra Nov 29, 2021
4b0133b
fix to LF line endings
jeffra Nov 30, 2021
b998206
minor merge fixes
jeffra Nov 30, 2021
d6deecb
remove extra bfloat16_enabled definition
Nov 30, 2021
2a4ef29
asserting params inflight for AllGatherHandle
Nov 30, 2021
90182b6
remove get_cuda_mem_allocated_str
Nov 30, 2021
ad847ed
Merge branch 'master' into s3-pr
tjruwase Dec 8, 2021
f590ba4
Format fixes
tjruwase Dec 8, 2021
9db815f
fix bfloat16 zero stage check (broken after merge commit)
Dec 8, 2021
259ec15
+self.communication_data_type, -self.allreduce_always_fp32; delete de…
tjruwase Dec 8, 2021
96d2247
Add self.reduce_scatter
tjruwase Dec 9, 2021
2630b75
Merge branch 'master' into s3-pr
tjruwase Dec 9, 2021
79fd42c
Merge branch 'master' into s3-pr
tjruwase Dec 11, 2021
8565e04
Merge branch 'master' into s3-pr
jeffra Dec 14, 2021
06eab1a
Merge branch 'master' into s3-pr
tjruwase Dec 30, 2021
0f8affe
Format fix
tjruwase Dec 30, 2021
3436422
Merge branch 'master' into s3-pr
tjruwase Dec 30, 2021
601d1f1
Fix merge issues
tjruwase Dec 30, 2021
5dcee36
Merge branch 's3-pr' of github.com:jfc4050/DeepSpeed into s3-pr
tjruwase Dec 30, 2021
580d25e
Merge branch 'master' into s3-pr
tjruwase Jan 3, 2022
872f451
Merge branch 'master' into s3-pr
jeffra Jan 7, 2022
e236293
Merge branch 'master' into s3-pr
tjruwase Jan 10, 2022
43b3b83
Merge branch 'master' into s3-pr
tjruwase Jan 11, 2022
83905ac
Merge branch 'master' into s3-pr
tjruwase Jan 12, 2022
31aecfc
iterate over params_to_fetch rather than make another iterator
Jan 12, 2022
8736700
add some TODOs
Jan 14, 2022
516379d
Merge branch 'master' into s3-pr
tjruwase Jan 14, 2022
0bf7bcd
remove unnecessary division by micro_step_id
Jan 19, 2022
43c00ff
rename config keys "bfloat16" -> "bf16"
Jan 19, 2022
4574bc7
rename stage3_gather_fp16_weights_on_model_save -> stage3_gather_16bi…
Jan 19, 2022
e04dc6a
add unit test to check backwards compatibility for gather_16bit_weights
Jan 19, 2022
391cecf
added test to confirm bf16 key bwd compatibility
Jan 19, 2022
3d26469
Merge branch 'master' into s3-pr
tjruwase Jan 19, 2022
536d171
Format fixes
tjruwase Jan 19, 2022
19f3538
Merge branch 'master' into s3-pr
tjruwase Jan 20, 2022
2 changes: 1 addition & 1 deletion deepspeed/autotuning/config_templates/template_zero3.json
@@ -11,7 +11,7 @@
"stage3_max_reuse_distance": 1e9,
"stage3_prefetch_bucket_size": 5e8,
"stage3_param_persistence_threshold": 1e6,
"stage3_gather_fp16_weights_on_model_save": false,
"stage3_gather_16bit_weights_on_model_save": false,
"sub_group_size": 1e12
}
}
23 changes: 14 additions & 9 deletions deepspeed/runtime/engine.py
@@ -708,8 +708,8 @@ def zero_prefetch_bucket_size(self):
def zero_param_persistence_threshold(self):
return self._config.zero_config.param_persistence_threshold

def zero_gather_fp16_weights_on_model_save(self):
return self._config.zero_config.gather_fp16_weights_on_model_save
def zero_gather_16bit_weights_on_model_save(self):
return self._config.zero_config.gather_16bit_weights_on_model_save

def zero_grad_hooks(self):
return self._config.zero_config.grad_hooks
@@ -2955,7 +2955,7 @@ def _save_zero_checkpoint(self, save_path, tag):
self._copy_recovery_script(save_path)
logger.info('zero checkpoint saved {}'.format(zero_checkpoint_name))

def _zero3_consolidated_fp16_state_dict(self):
def _zero3_consolidated_16bit_state_dict(self):
"""

Get a full non-partitioned state_dict with fp16 weights on cpu.
@@ -3024,17 +3024,22 @@ def get_layer_state_dict(module, prefix=""):
return state_dict

def save_fp16_model(self, save_dir, save_filename="pytorch_model.bin"):
r"""Save fp16 model weights
"""has been renamed to save_16bit_model, keeping this around for backwards
compatibility"""
return self.save_16bit_model(save_dir, save_filename)

This method saves the fp16 model weights at the desired destination.
def save_16bit_model(self, save_dir, save_filename="pytorch_model.bin"):
r"""Save 16bit model weights

This method saves the 16bit model weights at the desired destination.

Arguments:
save_dir: Required. Directory for saving the model
save_filename: Optional. Filename to save to. Defaults to ``pytorch_model.bin``

Returns:
``True`` when a model has been saved, ``False`` otherwise. It will not be saved if
stage3_gather_fp16_weights_on_model_save is ``False``.
stage3_gather_16bit_weights_on_model_save is ``False``.

Important: all processes must call this method and not just the process with rank 0. It is
because the processes need to work in sync to gather the weights. This method will hang
@@ -3045,13 +3050,13 @@ def save_fp16_model(self, save_dir, save_filename="pytorch_model.bin"):
path = os.path.join(save_dir, save_filename)

if self.zero_optimization_partition_weights():
if self.zero_gather_fp16_weights_on_model_save():
if self.zero_gather_16bit_weights_on_model_save():
# consolidation is expensive in time and memory and therefore isn't a default
state_dict = self._zero3_consolidated_fp16_state_dict()
state_dict = self._zero3_consolidated_16bit_state_dict()
else:
# the model will be bogus if not consolidated so don't confuse the user by saving it
logger.info(
f"Did not save the model {path} because `stage3_gather_fp16_weights_on_model_save` is False"
f"Did not save the model {path} because `stage3_gather_16bit_weights_on_model_save` is False"
)
return False
else:
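For orientation, here is a minimal usage sketch of the renamed API (not part of the diff). The dummy model, config values, and output paths are placeholders, and it assumes `deepspeed.initialize` accepts the config as a dict and that the script runs under the usual distributed launcher. The key points it illustrates are the ones stated in the docstring above: every rank must call the save method, and nothing is saved unless the gather option is enabled.

```python
# Illustrative sketch only; model, config values, and paths are placeholders.
import deepspeed
import torch

model = torch.nn.Linear(8, 8)  # stand-in for a real model

ds_config = {
    "train_batch_size": 8,
    "bf16": {"enabled": True},  # config key renamed from "bfloat16" in this PR
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
    "zero_optimization": {
        "stage": 3,
        "stage3_gather_16bit_weights_on_model_save": True,
    },
}

engine, _, _, _ = deepspeed.initialize(model=model,
                                       model_parameters=model.parameters(),
                                       config=ds_config)

# ... training loop ...

# Every rank must call this; ranks cooperate to gather the partitioned weights.
# Returns False (and saves nothing) if the gather option above is disabled.
saved = engine.save_16bit_model("checkpoints", "pytorch_model.bin")

# The old name is kept as a thin backwards-compatibility wrapper:
# engine.save_fp16_model("checkpoints", "pytorch_model.bin")
```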
16 changes: 11 additions & 5 deletions deepspeed/runtime/zero/config.py
@@ -36,7 +36,7 @@ def __init__(self, param_dict):
self.param_persistence_threshold = None
self.max_live_parameters = None
self.max_reuse_distance = None
self.gather_fp16_weights_on_model_save = None
self.gather_16bit_weights_on_model_save = None

self.ignore_unused_parameters = None
self.round_robin_gradients = None
@@ -171,10 +171,16 @@ def _initialize(self, zero_config_dict):
ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD,
ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD_DEFAULT)

self.gather_fp16_weights_on_model_save = get_scalar_param(
zero_config_dict,
ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE,
ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE_DEFAULT)
# config key has been renamed to use "16bit" instead of "fp16." falling back
# to old config name in order to preserve backwards compatibility
self.gather_16bit_weights_on_model_save = ZERO_OPTIMIZATION_GATHER_16BIT_WEIGHTS_ON_MODEL_SAVE_DEFAULT
for key in [
ZERO_OPTIMIZATION_GATHER_16BIT_WEIGHTS_ON_MODEL_SAVE,
ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE
]:
if key in zero_config_dict:
self.gather_16bit_weights_on_model_save = zero_config_dict[key]
break

self.ignore_unused_parameters = get_scalar_param(
zero_config_dict,
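The fallback above is small enough to restate in isolation. This sketch (illustrative, not the DeepSpeed source) shows the precedence it implements: the new key wins when both are present, the deprecated key is still honored, and the default applies otherwise.

```python
# Illustrative restatement of the config-key fallback (not the DeepSpeed source).
NEW_KEY = "stage3_gather_16bit_weights_on_model_save"
OLD_KEY = "stage3_gather_fp16_weights_on_model_save"  # deprecated spelling
DEFAULT = False

def resolve_gather_16bit(zero_config_dict):
    # New key first, then the deprecated key, otherwise the default.
    for key in (NEW_KEY, OLD_KEY):
        if key in zero_config_dict:
            return zero_config_dict[key]
    return DEFAULT

assert resolve_gather_16bit({OLD_KEY: True}) is True                   # old configs keep working
assert resolve_gather_16bit({NEW_KEY: False, OLD_KEY: True}) is False  # new key takes precedence
assert resolve_gather_16bit({}) is DEFAULT
```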
7 changes: 4 additions & 3 deletions deepspeed/runtime/zero/constants.py
@@ -113,7 +113,8 @@

# gathers params for saving a model - inefficient but is required in certain situations

Review comment on the line above:

Could someone kindly explain which situations require this 16-bit parameter gathering (inefficient) feature, given that the zero_to_fp32.py script can also save the parameters? Thanks.

Reply (collaborator):

This is for those who can't be bothered with running zero_to_fp32.py and want the 16-bit model extracted on the fly - which is fine for tiny to small models but very slow for large models.

It's also the default in the HF Trainer integration of Deepspeed to make it easy for users to start and have things work transparently. But the documentation explains how to improve upon this default.
https://huggingface.co/docs/transformers/main/main_classes/deepspeed#getting-the-model-weights-out

ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE = 'stage3_gather_fp16_weights_on_model_save'
ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE_DEFAULT = False
ZERO_OPTIMIZATION_GATHER_16BIT_WEIGHTS_ON_MODEL_SAVE = 'stage3_gather_16bit_weights_on_model_save'
ZERO_OPTIMIZATION_GATHER_16BIT_WEIGHTS_ON_MODEL_SAVE_DEFAULT = False

# Now just used in stage2 complete_grad_norm_calculation_for_cpu_offload
# Enable this option to avoid:
@@ -161,8 +162,8 @@
ZERO_OPTIMIZATION_PREFETCH_BUCKET_SIZE_DEFAULT,
ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD:
ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD_DEFAULT,
ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE:
ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE_DEFAULT,
ZERO_OPTIMIZATION_GATHER_16BIT_WEIGHTS_ON_MODEL_SAVE:
ZERO_OPTIMIZATION_GATHER_16BIT_WEIGHTS_ON_MODEL_SAVE_DEFAULT,
ZERO_OPTIMIZATION_IGNORE_UNUSED_PARAMETERS:
ZERO_OPTIMIZATION_IGNORE_UNUSED_PARAMETERS_DEFAULT,
ZERO_OPTIMIZATION_LEGACY_STAGE1:
6 changes: 3 additions & 3 deletions docs/_pages/config-json.md
@@ -329,7 +329,7 @@ Enabling and configuring ZeRO memory optimizations
"stage3_param_persistence_threshold" : 1e6,
"sub_group_size" : 1e12,
"elastic_checkpoint" : [true|false],
"stage3_gather_fp16_weights_on_model_save": [true|false],
"stage3_gather_16bit_weights_on_model_save": [true|false],
"ignore_unused_parameters": [true|false]
"round_robin_gradients": [true|false]
}
@@ -433,11 +433,11 @@ Enabling and configuring ZeRO memory optimizations
| Do not partition parameters smaller than this threshold. Smaller values use less memory, but can greatly increase communication (especially latency-bound messages). | `1e6` |


***stage3_gather_fp16_weights_on_model_save***: [boolean]
***stage3_gather_16bit_weights_on_model_save***: [boolean]

| Description | Default |
|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| ------- |
| Consolidate the weights before saving the model by `save_fp16_model()`. Since the weights are partitioned across GPUs, they aren't part of `state_dict`, so this function automatically gathers the weights when this option is enabled and then saves the fp16 model weights. | `False` |
| Consolidate the weights before saving the model by `save_16bit_model()`. Since the weights are partitioned across GPUs, they aren't part of `state_dict`, so this function automatically gathers the weights when this option is enabled and then saves the fp16 model weights. | `False` |


***cpu_offload***: [boolean]
6 changes: 3 additions & 3 deletions docs/_tutorials/zero.md
@@ -252,19 +252,19 @@ If you need to take the pretrained weights out of Deepspeed here is what you can

```json
"zero_optimization": {
"stage3_gather_fp16_weights_on_model_save": true
"stage3_gather_16bit_weights_on_model_save": true
},
```
And then save the model using:

```python
if self.deepspeed:
self.deepspeed.save_fp16_model(output_dir, output_file)
self.deepspeed.save_16bit_model(output_dir, output_file)
```

Because it requires consolidation of the weights on one GPU it can be slow and memory demanding, so only use this feature when needed.

Note that if `stage3_gather_fp16_weights_on_model_save` is `False`, no weights will be saved (again, because `state_dict` doesn't have them).
Note that if `stage3_gather_16bit_weights_on_model_save` is `False`, no weights will be saved (again, because `state_dict` doesn't have them).
You can use this method to save ZeRO-2 weights as well.

If you'd like to get the fp32 weights, we supply a special script that can do offline consolidation. It requires no configuration files or GPUs. Here is an example of its usage:
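The rendered hunk cuts off before the tutorial's command-line example. As a hedged sketch of the offline consolidation route mentioned in the last line, assuming DeepSpeed's zero_to_fp32 helpers are importable at this path in the installed release:

```python
# Hedged sketch; assumes this helper exists at this import path in the installed release.
import torch
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

# Consolidate a partitioned ZeRO checkpoint into a single fp32 state_dict on CPU;
# no GPUs or live engine are required. "checkpoints" is a placeholder directory.
state_dict = get_fp32_state_dict_from_zero_checkpoint("checkpoints")
torch.save(state_dict, "pytorch_model_fp32.bin")
```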
2 changes: 1 addition & 1 deletion docs/code-docs/source/training.rst
@@ -35,7 +35,7 @@ Gradient Accumulation

Model Saving
------------
.. autofunction:: deepspeed.DeepSpeedEngine.save_fp16_model
.. autofunction:: deepspeed.DeepSpeedEngine.save_16bit_model


Additionally when a DeepSpeed checkpoint is created, a script ``zero_to_fp32.py`` is added there which can be used to reconstruct fp32 master weights into a single pytorch ``state_dict`` file.