[Deepspeed] add support for bf16 mode (#14569)
* [WIP] add support for bf16 mode

* prep for bf16

* prep for bf16

* fix; zero2/bf16 is ok

* check bf16 is available

* test fixes

* enable zero3_bf16

* config files

* docs

* split stage_dtype; merge back to non-dtype-specific config file

* fix doc

* cleanup

* cleanup

* bfloat16 => bf16 to match the PR changes

* s/zero_gather_fp16_weights_on_model_save/zero_gather_16bit_weights_on_model_save/; s/save_fp16_model/save_16bit_model/

* test fixes/skipping

* move

* fix

* Update docs/source/main_classes/deepspeed.mdx

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* backticks

* cleanup

* cleanup

* cleanup

* new version

* add note about grad accum in bf16

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
stas00 and sgugger authored Mar 12, 2022
1 parent c1f209d commit 580dd87
Showing 10 changed files with 214 additions and 113 deletions.
65 changes: 52 additions & 13 deletions docs/source/main_classes/deepspeed.mdx
@@ -367,7 +367,7 @@ cat <<'EOT' > ds_config_zero3.json
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_fp16_weights_on_model_save": true
"stage3_gather_16bit_weights_on_model_save": true
},

"gradient_accumulation_steps": "auto",
@@ -652,7 +652,7 @@ The following is an example of configuration for ZeRO stage 3:
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_fp16_weights_on_model_save": true
"stage3_gather_16bit_weights_on_model_save": true
}
}
```
@@ -691,7 +691,7 @@ The following configuration values depend on the model's hidden size:
therefore set these values to `auto` and the [`Trainer`] will automatically assign the recommended
values. But, of course, feel free to set these explicitly as well.

- `stage3_gather_fp16_weights_on_model_save` enables model fp16 weights consolidation when model gets saved. With large
+ `stage3_gather_16bit_weights_on_model_save` enables consolidation of the fp16 weights when the model gets saved. With large
models and multiple GPUs this is an expensive operation both in terms of memory and speed. It's currently required if
you plan to resume the training. Watch out for future updates that will remove this limitation and make things more
flexible.
@@ -760,8 +760,8 @@ The following configuration example enables NVMe to offload both optimizer state
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_fp16_weights_on_model_save": true
}
"stage3_gather_16bit_weights_on_model_save": true
},
}
```

@@ -966,7 +966,7 @@ Here is a full ZeRO-3 auto-configuration file `ds_config_zero3.json`:
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_fp16_weights_on_model_save": true
"stage3_gather_16bit_weights_on_model_save": true
},

"gradient_accumulation_steps": "auto",
@@ -1029,7 +1029,7 @@ values look like, but we highly recommend using the one with multiple `auto` set
"stage3_param_persistence_threshold": 1e4,
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_fp16_weights_on_model_save": true
"stage3_gather_16bit_weights_on_model_save": true
},

"steps_per_print": 2000,
@@ -1232,6 +1232,7 @@ the much more efficient tf32 format for some operations, but the results will st
benchmarks, please, see [TensorFloat-32(TF32) on Ampere devices](https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices). The document includes
instructions on how to disable this automatic conversion if for some reason you prefer not to use it.

With the 🤗 Trainer you can use `--tf32` to enable it, or `--tf32 0` or `--no_tf32` to disable it. If none of these flags is passed, the PyTorch default is used.
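
For reference, outside of the Trainer the same toggle is available directly in PyTorch. A minimal sketch using the standard PyTorch flags:

```python
import torch

# Allow TF32 on Ampere+ GPUs: matmuls and cuDNN convolutions may then
# use the faster TF32 math internally while tensors stay in fp32.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Setting both flags to False forces full fp32 precision instead.
```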



@@ -1241,7 +1242,9 @@ instructions on how to disable this automatic conversion if for some reason you

You can use automatic mixed precision with either a pytorch-like AMP way or the apex-like way:

- To configure pytorch AMP-like mode set:
+ ### fp16
+
+ To configure pytorch AMP-like mode with fp16 (float16) set:

```json
{
@@ -1259,7 +1262,7 @@ To configure pytorch AMP-like mode set:
and the [`Trainer`] will automatically enable or disable it based on the value of
`args.fp16_backend`. The rest of config values are up to you.

- This mode gets enabled when `--fp16 --fp16_backend amp` command line args are passed.
+ This mode gets enabled when `--fp16 --fp16_backend amp` or `--fp16_full_eval` command line args are passed.

You can also enable/disable this mode explicitly:

@@ -1281,6 +1284,43 @@ configuration.

Here is the [documentation](https://www.deepspeed.ai/docs/config-json/#fp16-training-options).

### bf16

If bf16 (bfloat16) is desired instead of fp16, use the following configuration section:

```json
{
"bf16": {
"enabled": "auto"
}
}
```

bf16 has the same dynamic range as fp32 and thus doesn't require loss scaling.
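
A quick way to see this is to compare the format limits with `torch.finfo`; a minimal sketch:

```python
import torch

# bf16 keeps fp32's 8 exponent bits, so its representable range is
# essentially that of fp32; fp16 has only 5 exponent bits and tops out
# at 65504, which is why fp16 training needs loss scaling.
print(torch.finfo(torch.float32).max)   # ~3.40e+38
print(torch.finfo(torch.bfloat16).max)  # ~3.39e+38
print(torch.finfo(torch.float16).max)   # 65504.0
```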

This mode gets enabled when `--bf16` or `--bf16_full_eval` command line args are passed.

You can also enable/disable this mode explicitly:

```json
{
"bf16": {
"enabled": true
}
}
```

<Tip>

As of `deepspeed==0.6.0` the bf16 support is new and experimental.

If you use [gradient accumulation](#gradient-accumulation) with bf16 enabled, be aware that gradients are accumulated in bf16; due to this format's low precision the accumulation can be lossy, which may not be what you want (see the sketch right after this tip).

</Tip>
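
A minimal sketch of the effect, assuming nothing beyond stock PyTorch (the printed values come from bf16 rounding, not from measured training runs): with only 8 bits of mantissa precision, small increments stop registering once the accumulator grows.

```python
import torch

# Accumulate 1000 gradient contributions of 0.01 each.
grad = torch.tensor(0.01)
acc_bf16 = torch.zeros((), dtype=torch.bfloat16)
acc_fp32 = torch.zeros((), dtype=torch.float32)

for _ in range(1000):
    acc_bf16 += grad.to(torch.bfloat16)
    acc_fp32 += grad

print(acc_fp32.item())  # ~10.0
print(acc_bf16.item())  # stalls at 4.0: adding 0.01 to a bf16 value >= 4.0
                        # rounds away to nothing
```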


### apex

To configure apex AMP-like mode set:

```json
@@ -1411,15 +1451,14 @@ When a model is saved under ZeRO-2, you end up having the normal `pytorch_model.
they are only the fp16 version of the weights.

Under ZeRO-3, things are much more complicated, since the model weights are partitioned out over multiple GPUs,
therefore `"stage3_gather_fp16_weights_on_model_save": true` is required to get the `Trainer` to save the fp16
version of the weights. If this setting is `False` ``pytorch_model.bin` won't be created. This is because by default DeepSpeed's `state_dict` contains a placeholder and not the real weights. If we were to save this `state_dict`` it
won't be possible to load it back.
therefore `"stage3_gather_16bit_weights_on_model_save": true` is required to get the `Trainer` to save the fp16
version of the weights. If this setting is `False` `pytorch_model.bin` won't be created. This is because by default DeepSpeed's `state_dict` contains a placeholder and not the real weights. If we were to save this `state_dict` it won't be possible to load it back.


```json
{
"zero_optimization": {
"stage3_gather_fp16_weights_on_model_save": true
"stage3_gather_16bit_weights_on_model_save": true
}
}
```
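
If the gather is disabled and you only have the sharded ZeRO checkpoint, the fp32 weights can still be reconstructed offline with DeepSpeed's `zero_to_fp32` helpers. A hedged sketch, assuming a checkpoint saved under a hypothetical `output_dir/checkpoint-100`:

```python
from deepspeed.utils.zero_to_fp32 import (
    get_fp32_state_dict_from_zero_checkpoint,
    load_state_dict_from_zero_checkpoint,
)

# Consolidate the sharded ZeRO checkpoint into a single fp32 state_dict on CPU.
state_dict = get_fp32_state_dict_from_zero_checkpoint("output_dir/checkpoint-100")

# Or load the consolidated fp32 weights straight into an existing model:
# model = load_state_dict_from_zero_checkpoint(model, "output_dir/checkpoint-100")
```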
@@ -45,7 +45,7 @@
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_fp16_weights_on_model_save": true
"stage3_gather_16bit_weights_on_model_save": true
},

"gradient_accumulation_steps": "auto",
2 changes: 1 addition & 1 deletion setup.py
@@ -98,7 +98,7 @@
"cookiecutter==1.7.2",
"dataclasses",
"datasets",
"deepspeed>=0.5.9",
"deepspeed>=0.6.0",
"fairscale>0.3",
"faiss-cpu",
"fastapi",
27 changes: 18 additions & 9 deletions src/transformers/deepspeed.py
@@ -73,7 +73,7 @@ def __init__(self, config_file_or_dict):

# zero stage - this is done as early as possible, before model is created, to allow
# ``is_deepspeed_zero3_enabled`` query and getting to the early deepspeed config object
- # during ``zero.Init()`` which needs whether fp16 is enabled, dtype, etc.
+ # during ``zero.Init()`` which needs to know the dtype, and some other hparams.
self._stage = self.get_value("zero_optimization.stage", -1)

# offload
@@ -169,10 +169,12 @@ class HfTrainerDeepSpeedConfig(HfDeepSpeedConfig):

def __init__(self, config_file_or_dict):
super().__init__(config_file_or_dict)
- self._dtype = torch.float16
+ self._dtype = None
self.mismatches = []

def dtype(self):
if self._dtype is None:
raise ValueError("trainer_config_process() wasn't called yet to tell dtype")
return self._dtype

def fill_match(self, ds_key_long, hf_val, hf_key=None, must_match=True):
@@ -228,26 +230,33 @@ def trainer_config_process(self, args):
# total_num_steps - will get set in trainer_config_finalize

# fp16
- if args.fp16:
+ if args.fp16 or args.fp16_full_eval:
fp16_backend = "apex" if args.fp16_backend == "apex" else "amp"
else:
fp16_backend = None

# amp: similar to the pytorch native amp - it has a bunch of optional params but we won't set
# any here unless the user did the work
self.fill_match("fp16.enabled", fp16_backend == "amp", "fp16+fp16_backend(amp)")
self.fill_match(
"fp16.enabled",
((args.fp16 or args.fp16_full_eval) and fp16_backend == "amp"),
"fp16|fp16_full_eval+fp16_backend(amp)",
)

# apex: delegates amp work to apex (which needs to be available), but it cannot be used with any
# ZeRO features
self.fill_match("amp.enabled", fp16_backend == "apex", "fp16+fp16_backend(apex)")
self.fill_match("amp.opt_level", args.fp16_opt_level, "fp16_opt_level")

- # only if we have an explicit fp16.enabled = False then it's fp32, if it's True or this
- # whole config section is missing then the fallback is fp16
- if self.is_false("fp16.enabled"):
-     self._dtype = torch.float32
- # later there will be other dtypes besides just fp16 and fp32
+ self.fill_match("bf16.enabled", (args.bf16 or args.bf16_full_eval), "bf16|bf16_full_eval")
+
+ # deepspeed's default mode is fp16 unless there is a config that says differently
+ if self.is_true("bf16.enabled"):
+     self._dtype = torch.bfloat16
+ elif self.is_false("fp16.enabled"):
+     self._dtype = torch.float32
+ # also not quite sure what dtype should be under apex, defaulting to fp16 for now
+ else:
+     self._dtype = torch.float16
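
Distilled from the hunk above, the dtype precedence is: bf16 wins, an explicit `fp16.enabled: false` means fp32, and everything else (apex included) falls back to fp16. A standalone sketch with a hypothetical helper, not part of the actual class:

```python
import torch

def resolve_dtype(fp16_enabled, bf16_enabled):
    # Mirrors the precedence in trainer_config_process above.
    if bf16_enabled is True:
        return torch.bfloat16
    if fp16_enabled is False:
        return torch.float32
    return torch.float16
```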

def trainer_config_finalize(self, args, model, num_training_steps):
"""
2 changes: 1 addition & 1 deletion src/transformers/dependency_versions_table.py
@@ -8,7 +8,7 @@
"cookiecutter": "cookiecutter==1.7.2",
"dataclasses": "dataclasses",
"datasets": "datasets",
"deepspeed": "deepspeed>=0.5.9",
"deepspeed": "deepspeed>=0.6.0",
"fairscale": "fairscale>0.3",
"faiss-cpu": "faiss-cpu",
"fastapi": "fastapi",
8 changes: 4 additions & 4 deletions src/transformers/trainer.py
@@ -1687,7 +1687,7 @@ def _save_checkpoint(self, model, trial, metrics=None):
self.save_model(output_dir, _internal_call=True)
if self.deepspeed:
# under zero3 model file itself doesn't get saved since it's bogus! Unless deepspeed
- # config `stage3_gather_fp16_weights_on_model_save` is True
+ # config `stage3_gather_16bit_weights_on_model_save` is True
self.deepspeed.save_checkpoint(output_dir)

# Save optimizer and scheduler
@@ -2101,12 +2101,12 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = Fa
# logger.info(f"deepspeed zero3: removing {file}, see zero_to_fp32.py to recover weights")
os.remove(file)

- # now save the real model if stage3_gather_fp16_weights_on_model_save=True
+ # now save the real model if stage3_gather_16bit_weights_on_model_save=True
# if false it will not be saved.
# This must be called on all ranks
- if not self.deepspeed.save_fp16_model(output_dir, WEIGHTS_NAME):
+ if not self.deepspeed.save_16bit_model(output_dir, WEIGHTS_NAME):
logger.warning(
"deepspeed.save_fp16_model didn't save the model, since stage3_gather_fp16_weights_on_model_save=false. "
"deepspeed.save_16bit_model didn't save the model, since stage3_gather_16bit_weights_on_model_save=false. "
"Saving the full checkpoint instead, use zero_to_fp32.py to recover weights"
)
self.deepspeed.save_checkpoint(output_dir)
4 changes: 4 additions & 0 deletions tests/deepspeed/ds_config_zero2.json
@@ -8,6 +8,10 @@
"min_loss_scale": 1
},

"bf16": {
"enabled": "auto"
},

"optimizer": {
"type": "AdamW",
"params": {
6 changes: 5 additions & 1 deletion tests/deepspeed/ds_config_zero3.json
@@ -8,6 +8,10 @@
"min_loss_scale": 1
},

"bf16": {
"enabled": "auto"
},

"optimizer": {
"type": "AdamW",
"params": {
@@ -45,7 +49,7 @@
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_fp16_weights_on_model_save": true
"stage3_gather_16bit_weights_on_model_save": true
},

"gradient_accumulation_steps": "auto",
