-
Notifications
You must be signed in to change notification settings - Fork 1k
/
Copy pathaccelerator.py
1206 lines (1063 loc) · 53.5 KB
/
accelerator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import gc
import math
import os
import sys
import warnings
from contextlib import contextmanager
from typing import List, Optional, Union
import torch
from .checkpointing import load_accelerator_state, load_custom_state, save_accelerator_state, save_custom_state
from .data_loader import prepare_data_loader
from .logging import get_logger
from .optimizer import AcceleratedOptimizer
from .scheduler import AcceleratedScheduler
from .state import AcceleratorState, GradientState
from .tracking import LOGGER_TYPE_TO_CLASS, GeneralTracker, filter_trackers
from .utils import (
DeepSpeedPlugin,
DistributedDataParallelKwargs,
DistributedType,
FullyShardedDataParallelPlugin,
GradScalerKwargs,
InitProcessGroupKwargs,
KwargsHandler,
LoggerType,
PrecisionType,
RNGType,
compare_versions,
convert_outputs_to_fp32,
extract_model_from_parallel,
gather,
get_pretty_name,
is_bf16_available,
is_deepspeed_available,
is_torch_version,
is_tpu_available,
pad_across_processes,
recursively_apply,
reduce,
save,
wait_for_everyone,
)
if is_deepspeed_available():
import deepspeed
from .utils import (
DeepSpeedEngineWrapper,
DeepSpeedOptimizerWrapper,
DeepSpeedSchedulerWrapper,
DummyOptim,
DummyScheduler,
)
if is_tpu_available(check_device=False):
import torch_xla.distributed.xla_multiprocessing as xmp
logger = get_logger(__name__)
class Accelerator:
"""
Creates an instance of an accelerator for distributed training (on multi-GPU, TPU) or mixed precision training.
Args:
device_placement (`bool`, *optional*, defaults to `True`):
Whether or not the accelerator should put objects on device (tensors yielded by the dataloader, model,
etc...).
split_batches (`bool`, *optional*, defaults to `False`):
Whether or not the accelerator should split the batches yielded by the dataloaders across the devices. If
`True` the actual batch size used will be the same on any kind of distributed processes, but it must be a
round multiple of the `num_processes` you are using. If `False`, actual batch size used will be the one set
in your script multiplied by the number of processes.
mixed_precision (`str`, *optional*):
Whether or not to use mixed precision training (fp16 or bfloat16). Choose from 'no','fp16','bf16'. Will
default to the value in the environment variable `MIXED_PRECISION`, which will use the default value in the
accelerate config of the current system or the flag passed with the `accelerate.launch` command. 'fp16'
requires pytorch 1.6 or higher. 'bf16' requires pytorch 1.10 or higher.
gradient_accumulation_steps (`int`, *optional*, default to 1):
The number of steps that should pass before gradients are accumulated. A number > 1 should be combined with
`Accelerator.accumulate`.
cpu (`bool`, *optional*):
Whether or not to force the script to execute on CPU. Will ignore GPU available if set to `True` and force
the execution on one process only.
deepspeed_plugin (`DeepSpeedPlugin`, *optional*):
Tweak your DeepSpeed related args using this argument. This argument is optional and can be configured
directly using *accelerate config*
fsdp_plugin (`FullyShardedDataParallelPlugin`, *optional*):
Tweak your FSDP related args using this argument. This argument is optional and can be configured directly
using *accelerate config*
rng_types (list of `str` or [`~utils.RNGType`]):
The list of random number generators to synchronize at the beginning of each iteration in your prepared
dataloaders. Should be one or several of:
- `"torch"`: the base torch random number generator
- `"cuda"`: the CUDA random number generator (GPU only)
- `"xla"`: the XLA random number generator (TPU only)
- `"generator"`: the `torch.Generator` of the sampler (or batch sampler if there is no sampler in your
dataloader) or of the iterable dataset (if it exists) if the underlying dataset is of that type.
Will default to `["torch"]` for PyTorch versions <=1.5.1 and `["generator"]` for PyTorch versions >= 1.6.
log_with (list of `str`, [`~utils.LoggerType`] or [`~tracking.GeneralTracker`], *optional*):
A list of loggers to be setup for experiment tracking. Should be one or several of:
- `"all"`
- `"tensorboard"`
- `"wandb"`
- `"comet_ml"`
If `"all`" is selected, will pick up all available trackers in the environment and intialize them. Can also
accept implementations of `GeneralTracker` for custom trackers, and can be combined with `"all"`.
logging_dir (`str`, `os.PathLike`, *optional*):
A path to a directory for storing logs of locally-compatible loggers.
dispatch_batches (`bool`, *optional*):
If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process
and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose
underlying dataset is an `IterableDataset`, `False` otherwise.
step_scheduler_with_optimizer (`bool`, *optional`, defaults to `True`):
Set `True` if the learning rate scheduler is stepped at the same time as the optimizer, `False` if only
done under certain circumstances (at the end of each epoch, for instance).
kwargs_handlers (`List[KwargHandler]`, *optional*)
A list of `KwargHandler` to customize how the objects related to distributed training or mixed precision
are created. See [kwargs](kwargs) for more information.
Attributes
- **device** (`torch.device`) -- The device to use.
- **state** ([`~state.AcceleratorState`]) -- The distributed setup state.
"""
def __init__(
    self,
    device_placement: bool = True,
    split_batches: bool = False,
    fp16: bool = None,
    mixed_precision: Union[PrecisionType, str] = None,
    gradient_accumulation_steps: int = 1,
    cpu: bool = False,
    deepspeed_plugin: DeepSpeedPlugin = None,
    fsdp_plugin: FullyShardedDataParallelPlugin = None,
    rng_types: Optional[List[Union[str, RNGType]]] = None,
    log_with: Optional[List[Union[str, LoggerType, GeneralTracker]]] = None,
    logging_dir: Optional[Union[str, os.PathLike]] = None,
    dispatch_batches: Optional[bool] = None,
    step_scheduler_with_optimizer: bool = True,
    kwargs_handlers: Optional[List[KwargsHandler]] = None,
):
    """
    Build the accelerator.

    Resolves the mixed precision mode and the DeepSpeed/FSDP plugins, registers the kwargs handlers, creates the
    shared [`~state.AcceleratorState`], sets up the gradient scaler and initializes the internal bookkeeping.
    The arguments are documented on the class; `fp16=True` is a deprecated alias for `mixed_precision="fp16"`.

    Raises:
        ValueError: for an unknown `mixed_precision`, FSDP on PyTorch < 1.12, a duplicated kwargs handler,
            `downcast_bf16` set while neither bf16 nor a TPU is used, or a mixed precision mode the current setup
            cannot provide.
        ImportError: when DeepSpeed is requested but missing or too old, or `dispatch_batches=True` is used with
            PyTorch < 1.8.
        TypeError: when `fsdp_plugin` is not a `FullyShardedDataParallelPlugin`.
        NotImplementedError: when gradient accumulation is requested on TPU.
    """
    self.logging_dir = logging_dir
    trackers = filter_trackers(log_with, self.logging_dir)
    if len(trackers) < 1 and log_with is not None:
        warnings.warn(f"`log_with={log_with}` was passed but no supported trackers are currently installed.")
    self.log_with = trackers

    if mixed_precision is not None:
        mixed_precision = str(mixed_precision)
        if mixed_precision not in PrecisionType:
            raise ValueError(
                f"Unknown mixed_precision mode: {mixed_precision}. Choose between {PrecisionType.list()}"
            )

    if fp16:
        warnings.warn('fp16=True is deprecated. Use mixed_precision="fp16" instead.', DeprecationWarning)
        mixed_precision = "fp16"

    if deepspeed_plugin is None:  # init from env variables
        deepspeed_plugin = DeepSpeedPlugin() if os.environ.get("USE_DEEPSPEED", "false") == "true" else None
    else:
        assert isinstance(
            deepspeed_plugin, DeepSpeedPlugin
        ), "`deepspeed_plugin` must be a DeepSpeedPlugin object."
        os.environ["USE_DEEPSPEED"] = "true"  # use DeepSpeed if plugin is provided
    if deepspeed_plugin:
        if not is_deepspeed_available():
            raise ImportError("DeepSpeed is not installed => run `pip install deepspeed` or build it from source.")
        if compare_versions("deepspeed", "<", "0.6.5"):
            raise ImportError("DeepSpeed version must be >= 0.6.5. Please update DeepSpeed.")

        # The DeepSpeed plugin needs a concrete mode: fall back to the MIXED_PRECISION environment variable.
        mixed_precision = os.environ.get("MIXED_PRECISION", "no") if mixed_precision is None else mixed_precision
        deepspeed_plugin.set_mixed_precision(mixed_precision)
        deepspeed_plugin.set_deepspeed_weakref()

    if os.environ.get("USE_FSDP", "false") == "true" or isinstance(fsdp_plugin, FullyShardedDataParallelPlugin):
        if is_torch_version("<", "1.12.0"):
            raise ValueError("FSDP requires PyTorch >= 1.12.0")

    if fsdp_plugin is None:  # init from env variables
        fsdp_plugin = FullyShardedDataParallelPlugin() if os.environ.get("USE_FSDP", "false") == "true" else None
    else:
        if not isinstance(fsdp_plugin, FullyShardedDataParallelPlugin):
            raise TypeError("`fsdp_plugin` must be a FullyShardedDataParallelPlugin object.")
        os.environ["USE_FSDP"] = "true"  # use FSDP if plugin is provided

    # Kwargs handlers: at most one handler of each supported kind.
    self.ddp_handler = None
    self.scaler_handler = None
    self.init_handler = None
    if kwargs_handlers is not None:
        for handler in kwargs_handlers:
            assert isinstance(handler, KwargsHandler), f"Unsupported kwargs handler passed: {handler}."
            if isinstance(handler, DistributedDataParallelKwargs):
                if self.ddp_handler is not None:
                    raise ValueError("You can only pass one `DistributedDataParallelKwargs` in `kwargs_handler`.")
                else:
                    self.ddp_handler = handler
            elif isinstance(handler, GradScalerKwargs):
                if self.scaler_handler is not None:
                    raise ValueError("You can only pass one `GradScalerKwargs` in `kwargs_handler`.")
                else:
                    self.scaler_handler = handler
            elif isinstance(handler, InitProcessGroupKwargs):
                if self.init_handler is not None:
                    raise ValueError("You can only pass one `InitProcessGroupKwargs` in `kwargs_handler`.")
                else:
                    self.init_handler = handler

    kwargs = self.init_handler.to_kwargs() if self.init_handler is not None else {}
    self.state = AcceleratorState(
        mixed_precision=mixed_precision,
        cpu=cpu,
        deepspeed_plugin=deepspeed_plugin,
        fsdp_plugin=fsdp_plugin,
        _from_accelerator=True,
        **kwargs,
    )

    # The state attribute is `distributed_type` (as used below); `distributedType` does not exist and made this
    # check raise AttributeError instead of the intended ValueError.
    if (
        (mixed_precision != "bf16")
        and getattr(self.state, "downcast_bfloat", False)
        and (self.state.distributed_type != DistributedType.TPU)
    ):
        raise ValueError("Can only use `downcast_bf16` when using `mixed_precision='bf16'` and on a TPU")

    if gradient_accumulation_steps > 1:
        if self.state.distributed_type == DistributedType.TPU:
            raise NotImplementedError(
                "Gradient accumulation on TPU is not supported. Pass in `gradient_accumulation_steps=1`"
            )

    self.gradient_accumulation_steps = gradient_accumulation_steps
    self.device_placement = device_placement
    self.split_batches = split_batches
    self.dispatch_batches = dispatch_batches
    if dispatch_batches is True and is_torch_version("<", "1.8.0"):
        raise ImportError(
            f"Using `DataLoaderDispatcher` requires PyTorch 1.8.0 minimum. You have {torch.__version__}."
        )
    self.step_scheduler_with_optimizer = step_scheduler_with_optimizer

    # Mixed precision attributes
    self.scaler = None
    self.native_amp = False
    err = "{mode} mixed precision requires {requirement}"
    if self.state.mixed_precision == "fp16":
        self.native_amp = is_torch_version(">=", "1.6")
        if not self.native_amp:
            raise ValueError(err.format(mode="fp16", requirement="PyTorch >= 1.6"))
        if not torch.cuda.is_available():
            raise ValueError(err.format(mode="fp16", requirement="a GPU"))
        kwargs = self.scaler_handler.to_kwargs() if self.scaler_handler is not None else {}
        if self.distributed_type == DistributedType.FSDP:
            from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

            self.scaler = ShardedGradScaler(**kwargs)
        else:
            self.scaler = torch.cuda.amp.GradScaler(**kwargs)
    elif self.state.mixed_precision == "bf16" and self.distributed_type != DistributedType.FSDP:
        self.native_amp = is_bf16_available(True)
        # This branch already requires the resolved mode to be bf16. Do not also test the `mixed_precision`
        # argument: it may be None when bf16 was selected through the environment/config, which would skip the check.
        if not self.native_amp and not is_tpu_available():
            raise ValueError(err.format(mode="bf16", requirement="PyTorch >= 1.10 and a supported device."))

        # Only on the GPU do we care about scaling the gradients
        if torch.cuda.is_available() and self.device.type != "cpu":
            kwargs = self.scaler_handler.to_kwargs() if self.scaler_handler is not None else {}
            self.scaler = torch.cuda.amp.GradScaler(**kwargs)

    # Start of internal step tracking
    self.step = 0
    self.gradient_state = GradientState()

    # Internal references to the training objects
    self._optimizers = []
    self._models = []
    self._schedulers = []
    self._custom_objects = []

    # RNG Types
    self.rng_types = rng_types
    if self.rng_types is None:
        self.rng_types = ["torch"] if is_torch_version("<=", "1.5.1") else ["generator"]
@property
def use_distributed(self):
    """Whether a distributed backend is active *and* more than one process is involved."""
    if self.distributed_type == DistributedType.NO:
        return False
    return self.num_processes > 1
@property
def distributed_type(self):
    """The distributed setup in use, as reported by the accelerator state."""
    state = self.state
    return state.distributed_type
@property
def num_processes(self):
    """The number of processes, as reported by the accelerator state."""
    state = self.state
    return state.num_processes
@property
def process_index(self):
    """The index of the current process, as reported by the accelerator state."""
    state = self.state
    return state.process_index
@property
def local_process_index(self):
    """The index of the current process on its own machine, as reported by the accelerator state."""
    state = self.state
    return state.local_process_index
@property
def device(self):
    """The `torch.device` to use, as reported by the accelerator state."""
    state = self.state
    return state.device
@property
def is_main_process(self):
    """True for one process only: the one whose global process index is 0."""
    index = self.process_index
    return index == 0
@property
def is_local_main_process(self):
    """True for one process per server: the one whose local process index is 0."""
    local_index = self.local_process_index
    return local_index == 0
@property
def use_fp16(self):
    """
    Whether any mixed precision mode is active.

    Despite the name, this is True for every mode other than `"no"`, so bf16 counts as well.
    """
    precision = self.mixed_precision
    return precision != "no"
@property
def mixed_precision(self):
    """
    The mixed precision mode in use: `"no"`, `"fp16"` or `"bf16"`.

    With DeepSpeed, the mode is read from the DeepSpeed config; `fp16.enabled` is checked before `bf16.enabled`.
    Otherwise it comes from the accelerator state.
    """
    if self.distributed_type != DistributedType.DEEPSPEED:
        return self.state.mixed_precision
    ds_config = self.state.deepspeed_plugin.deepspeed_config
    for mode in ("fp16", "bf16"):
        if ds_config.get(mode, {}).get("enabled", False):
            return mode
    return "no"
@contextmanager
def local_main_process_first(self):
    """
    Let the local main process run the body of a `with` block first.

    Every other process on the same server enters the block only after the local main process has left it.
    """
    ordering = self._goes_first(self.is_local_main_process)
    yield from ordering
@contextmanager
def main_process_first(self):
    """
    Let the global main process run the body of a `with` block first.

    Every other process enters the block only after the main process has left it.
    """
    ordering = self._goes_first(self.is_main_process)
    yield from ordering
def _goes_first(self, is_main):
if not is_main:
self.wait_for_everyone()
yield
if is_main:
self.wait_for_everyone()
@contextmanager
def no_sync(self, model):
    """
    A context manager to disable gradient synchronizations across DDP processes by calling
    `torch.nn.parallel.DistributedDataParallel.no_sync`.

    When training is not distributed, or `model` has no `no_sync` method, this context manager does nothing.

    Args:
        model (`torch.nn.Module`):
            PyTorch Module that was prepared with `Accelerator.prepare`
    """
    if self.use_distributed:
        context = getattr(model, "no_sync", contextlib.nullcontext)
    else:
        context = contextlib.nullcontext
    with context():
        yield
def _do_sync(self):
"Sets the right `sync_gradients` context and either resets or increases `self.step`"
if self.gradient_state.end_of_dataloader:
self.step = 0
self.gradient_state._set_sync_gradients(True)
else:
self.step += 1
self.gradient_state._set_sync_gradients((self.step % self.gradient_accumulation_steps) == 0)
@property
def sync_gradients(self):
    """Whether gradients should be synchronized on the current step, as tracked by the gradient state."""
    gradient_state = self.gradient_state
    return gradient_state.sync_gradients
@contextmanager
def accumulate(self, model):
    """
    A context manager that will lightly wrap around and perform gradient accumulation automatically.

    On entry it advances the step counter. If gradients must not be synchronized on this step, the body runs under
    `self.no_sync(model)`; otherwise it runs without any extra context.

    Args:
        model (`torch.nn.Module`):
            PyTorch Module that was prepared with `Accelerator.prepare`
    """
    self._do_sync()
    if not self.sync_gradients:
        ctx = self.no_sync(model)
    else:
        ctx = contextlib.nullcontext(model)
    with ctx:
        yield
def print(self, *args, **kwargs):
    """
    Use in replacement of `print()` to only print once per server.

    All arguments are forwarded to the builtin `print`, and only the local main process actually prints.
    """
    if not self.is_local_main_process:
        return
    print(*args, **kwargs)
def _prepare_one(self, obj, first_pass=False):
# First pass of preparation: DataLoader, model, optimizer
if first_pass:
if isinstance(obj, torch.utils.data.DataLoader):
return self.prepare_data_loader(obj)
elif isinstance(obj, torch.nn.Module):
self._models.append(obj)
return self.prepare_model(obj)
elif isinstance(obj, torch.optim.Optimizer):
optimizer = self.prepare_optimizer(obj)
self._optimizers.append(optimizer)
return optimizer
# Second pass of preparation: LR scheduler (which need the full list of optimizers)
elif isinstance(obj, torch.optim.lr_scheduler._LRScheduler):
scheduler = self.prepare_scheduler(obj)
self._schedulers.append(scheduler)
return scheduler
# Return the unprocessed object if previous criteria was not met
return obj
def _prepare_fsdp(self, *args):
    """
    Re-create the optimizers so they point at the parameters of the FSDP-wrapped model.

    [`~Accelerator.prepare`] calls this with the already-prepared objects when exactly one model and at least one
    optimizer were passed together. FSDP flattens the wrapped model's parameters, so each optimizer is rebuilt
    with the same class and default hyper-parameters around `model.parameters()`. Prepared schedulers are
    re-pointed to the new optimizers. `self._models`, `self._optimizers` and `self._schedulers` are reset from
    the objects passed here.

    Returns:
        `tuple`: the objects in the same order, with optimizers replaced by their re-created versions.
    """
    # `prepare` only calls this when exactly one model is present; use the first module found.
    model = next((obj for obj in args if isinstance(obj, torch.nn.Module)), None)

    optimizers = []
    self._schedulers = []
    self._models = []
    intermediate_result = []
    for obj in args:
        if isinstance(obj, torch.optim.Optimizer):
            if len(obj.param_groups) > 1:
                logger.warning(
                    "FSDP Warning: When using FSDP, several parameter groups will be conflated into "
                    "a single one due to nested module wrapping and parameter flattening."
                )
            # Rebuild the underlying optimizer with the same class/defaults around the wrapped model's parameters.
            optimizer = obj.optimizer.__class__(model.parameters(), **obj.optimizer.defaults)
            obj = self.prepare_optimizer(optimizer)
            optimizers.append(obj)
        elif isinstance(obj, torch.nn.Module):
            self._models.append(obj)
        intermediate_result.append(obj)

    result = []
    for obj in intermediate_result:
        if isinstance(obj, AcceleratedScheduler):
            # NOTE(review): this sets `optimizer` (singular) while the matching branch below sets `optimizers`;
            # confirm which attribute `AcceleratedScheduler` reads. If it is `optimizers`, a scheduler that
            # matches no optimizer keeps pointing at the pre-FSDP optimizers.
            obj.optimizer = optimizers
            # Assumes `self._optimizers` (old) and `optimizers` (new) are in the same order, which holds when both
            # were built from the same `args`.
            for i, opt in enumerate(self._optimizers):
                if getattr(obj.scheduler, "optimizer", None) == opt.optimizer:
                    obj.scheduler.optimizer = optimizers[i]
                    obj.optimizers = [optimizers[i]]
                    break
            self._schedulers.append(obj)
        result.append(obj)
    self._optimizers = optimizers
    return tuple(result)
def prepare(self, *args):
    """
    Prepare all objects passed in `args` for distributed training and mixed precision, then return them in the same
    order.

    Accepts the following type of objects:

        - `torch.utils.data.DataLoader`: PyTorch Dataloader
        - `torch.nn.Module`: PyTorch Module
        - `torch.optim.Optimizer`: PyTorch Optimizer
        - learning rate schedulers, which are prepared after the optimizers so they can be attached to the right one

    Objects of any other type are returned unchanged.

    Returns:
        The prepared object when exactly one object is passed, otherwise a tuple of the prepared objects (an empty
        tuple when nothing is passed).

    Raises:
        ValueError: with FSDP, when several models are passed together with an optimizer; on TPU, when the model and
            the optimizer parameters live on different devices.
    """
    model_count = 0
    optimizer_present = False
    if self.distributed_type == DistributedType.FSDP:
        for obj in args:
            if isinstance(obj, torch.nn.Module):
                model_count += 1
            if isinstance(obj, torch.optim.Optimizer):
                optimizer_present = True
        if model_count > 1 and optimizer_present:
            raise ValueError(
                "For FSDP to work with multiple models (>1), "
                "prepare must be called for all the models before optimizers are created"
            )
        elif model_count == 1 and optimizer_present:
            logger.warning(
                "FSDP Warning: When using FSDP, "
                "it is efficient and recommended to call prepare for the model before creating the optimizer"
            )

    # On TPUs, putting the model on the XLA device will create new parameters, so the corresponding optimizer will
    # have parameters disconnected from the model (so no training :-( ).
    # If the model and optimizer have parameters on different devices we raise an error.
    if self.distributed_type == DistributedType.TPU:
        # NOTE(review): the objects being prepared are not passed here. If `_get_devices` inspects its positional
        # arguments (as `_get_named_parameters(*args)` below does), this check can never fire. Confirm.
        model_device, optimizer_device = self._get_devices()
        if model_device is not None and optimizer_device is not None and model_device != optimizer_device:
            raise ValueError(
                "The model and the optimizer parameters are not on the same device, which probably means you "
                "created an optimizer around your model **before** putting on the device. Make sure the line "
                "model.to(device) is before the optimizer creation in your script or remove it entirely and use "
                "the flag default value for `device_placement` in your `Accelerator` to let it handle that "
                "part for you."
            )

    # If we're dealing with device placement, this deals with that by...
    tpu_should_fix_optimizer = self.device_placement and self.distributed_type == DistributedType.TPU
    if tpu_should_fix_optimizer:
        # 1. grabbing old model parameters
        old_named_params = self._get_named_parameters(*args)

    if self.distributed_type == DistributedType.DEEPSPEED:
        result = self._prepare_deepspeed(*args)
    else:
        # Two passes: schedulers need the optimizers prepared in the first pass.
        result = tuple(self._prepare_one(obj, first_pass=True) for obj in args)
        result = tuple(self._prepare_one(obj) for obj in result)

    if tpu_should_fix_optimizer:
        # 2. grabbing new model parameters
        new_named_params = self._get_named_parameters(*result)
        # 3. building a map from the first to the second
        mapping = {p: new_named_params[n] for n, p in old_named_params.items()}
        # 4. using that map to update the parameters of the optimizer
        for obj in result:
            if isinstance(obj, torch.optim.Optimizer):
                obj._switch_parameters(mapping)

    if self.distributed_type == DistributedType.FSDP and model_count == 1 and optimizer_present:
        result = self._prepare_fsdp(*result)

    # A single object is returned bare; zero or several objects are returned as a tuple.
    return result[0] if len(result) == 1 else result
def prepare_model(self, model):
    """
    Prepare a `torch.nn.Module` for the current distributed / mixed precision setup.

    Depending on `self.distributed_type`, the model is:

    - moved to `self.device` when `device_placement` is enabled (FSDP handles placement in its own branch);
    - wrapped in `DistributedDataParallel` (multi-GPU or multi-CPU) or `FullyShardedDataParallel` (FSDP, unless it
      already is an FSDP instance);
    - given a `forward` wrapped in an autocast context plus `convert_outputs_to_fp32` when native AMP is enabled;
    - wrapped in `xmp.MpModelWrapper` on a fork-launched TPU run.

    Args:
        model (`torch.nn.Module`): The model to prepare.

    Returns:
        The (possibly wrapped) model.
    """
    # FSDP is excluded here; its branch below decides whether to move the wrapped model.
    if self.device_placement and self.distributed_type != DistributedType.FSDP:
        model = model.to(self.device)
    if self.distributed_type == DistributedType.MULTI_GPU:
        kwargs = self.ddp_handler.to_kwargs() if self.ddp_handler is not None else {}
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[self.local_process_index], output_device=self.local_process_index, **kwargs
        )
    elif self.distributed_type == DistributedType.FSDP:
        from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP

        # Check if the model is already a FSDP model due to `Manual Wrapping` and if so,
        # don't wrap it again
        if type(model) != FSDP:
            self.state.fsdp_plugin.set_auto_wrap_policy(model)
            fsdp_plugin = self.state.fsdp_plugin
            model = FSDP(
                model,
                sharding_strategy=fsdp_plugin.sharding_strategy,
                cpu_offload=fsdp_plugin.cpu_offload,
                auto_wrap_policy=fsdp_plugin.auto_wrap_policy,
                backward_prefetch=fsdp_plugin.backward_prefetch,
                mixed_precision=fsdp_plugin.mixed_precision_policy,
                ignored_modules=fsdp_plugin.ignored_modules,
            )
            # Only move the wrapped model to the device when its parameters are not offloaded to CPU.
            if not fsdp_plugin.cpu_offload.offload_params:
                model.to(self.device)
    elif self.distributed_type == DistributedType.MULTI_CPU:
        kwargs = self.ddp_handler.to_kwargs() if self.ddp_handler is not None else {}
        model = torch.nn.parallel.DistributedDataParallel(model, **kwargs)
    if self.native_amp:
        # Run the forward pass under autocast, then convert its outputs back to fp32.
        if self.mixed_precision == "fp16" and is_torch_version(">=", "1.10"):
            model.forward = torch.cuda.amp.autocast(dtype=torch.float16)(model.forward)
        elif self.mixed_precision == "bf16" and self.distributed_type != DistributedType.TPU:
            device_type = "cuda" if torch.cuda.is_available() else "cpu"
            model.forward = torch.autocast(device_type=device_type, dtype=torch.bfloat16)(model.forward)
        else:
            # Fallback: default CUDA autocast (e.g. fp16 on PyTorch < 1.10).
            # NOTE(review): bf16 on TPU would also land here if `native_amp` is True; confirm this is intended.
            model.forward = torch.cuda.amp.autocast()(model.forward)
        model.forward = convert_outputs_to_fp32(model.forward)
    if self.distributed_type == DistributedType.TPU and self.state.fork_launched:
        model = xmp.MpModelWrapper(model).to(self.device)
    return model
def _prepare_deepspeed(self, *args):
    """
    Prepare the objects passed to [`~Accelerator.prepare`] when DeepSpeed is used.

    The steps are:

    1. Dataloaders are prepared first, and their batch sizes are used to fill in the DeepSpeed config.
    2. The model, optimizer and scheduler found in `args` are validated against the DeepSpeed config (real objects
       vs. `DummyOptim` / `DummyScheduler` placeholders).
    3. When a model is present, these objects are handed to `deepspeed.initialize`.
    4. In the result, they are replaced by the engine and the wrapped optimizer/scheduler.

    Returns:
        `tuple`: the objects in the same order as `args`. When a model was passed, the model/optimizer/scheduler
        are replaced by their DeepSpeed counterparts.

    Raises:
        ValueError: if no dataloader is passed, or if the optimizer/scheduler given in code conflict with the
            DeepSpeed config.
        AssertionError: if this `Accelerator` ends up holding more than one DeepSpeed engine.
    """
    deepspeed_plugin = self.state.deepspeed_plugin

    # Only dataloaders are prepared up front; the other objects are handled by DeepSpeed below.
    result = [
        self._prepare_one(obj, first_pass=True) if isinstance(obj, torch.utils.data.DataLoader) else obj
        for obj in args
    ]

    # Batch sizes are read from the original (unprepared) objects.
    batch_sizes = [obj.batch_size for obj in args if hasattr(obj, "batch_size")]
    if self.split_batches:
        # With split batches, each process only sees its share of the dataloader batch.
        batch_sizes = [batch_size // self.num_processes for batch_size in batch_sizes]
    if len(batch_sizes) == 0:
        raise ValueError(
            "You must specify a training or evaluation dataloader in `accelerate.prepare()` when using DeepSpeed."
        )

    batch_size_per_device = min(batch_sizes) if deepspeed_plugin.is_train_batch_min else max(batch_sizes)
    if len(batch_sizes) > 1:
        logger.info(
            "Since you passed both train and evaluation dataloader, `is_train_batch_min` (here "
            f"{deepspeed_plugin.is_train_batch_min} will decide the `train_batch_size` ({batch_size_per_device})."
        )

    # Values fed to `deepspeed_plugin.deepspeed_config_process(must_match=False, ...)` below. The global train
    # batch size is micro batch * gradient accumulation steps (from the DeepSpeed config) * number of processes.
    config_kwargs = {
        "train_micro_batch_size_per_gpu": batch_size_per_device,
        "train_batch_size": batch_size_per_device
        * deepspeed_plugin.deepspeed_config["gradient_accumulation_steps"]
        * self.num_processes,
        "gradient_clipping": 1.0,
        "zero_optimization.stage3_gather_16bit_weights_on_model_save": False,
    }

    # Pick out the objects DeepSpeed needs; if several of one kind are passed, the last one wins.
    model = None
    optimizer = None
    scheduler = None
    for obj in result:
        if isinstance(obj, torch.nn.Module):
            model = obj
        elif isinstance(obj, (torch.optim.Optimizer, DummyOptim)):
            optimizer = obj
        elif (isinstance(obj, (torch.optim.lr_scheduler._LRScheduler, DummyScheduler))) or (
            type(obj).__name__ in deepspeed.runtime.lr_schedules.VALID_LR_SCHEDULES
        ):
            scheduler = obj

    # An optimizer must come either from the config (with a `DummyOptim` placeholder in code) or from code.
    if optimizer is not None:
        if "optimizer" in deepspeed_plugin.deepspeed_config and not isinstance(optimizer, (DummyOptim)):
            raise ValueError(
                "You cannot specify an optimizer in the config file and in the code at the same time. "
                "Please remove the optimizer from the config file or "
                "create `accelerate.utils.DummyOptim` in the code."
            )
        elif "optimizer" not in deepspeed_plugin.deepspeed_config and isinstance(optimizer, (DummyOptim)):
            raise ValueError(
                "You cannot create a `DummyOptim` without specifying an optimizer in the config file."
            )

        if isinstance(optimizer, (torch.optim.Optimizer)):
            deepspeed_plugin.deepspeed_config["zero_allow_untested_optimizer"] = True

    # Same rule for the scheduler (with a `DummyScheduler` placeholder).
    if scheduler is not None:
        if "scheduler" in deepspeed_plugin.deepspeed_config and not isinstance(scheduler, (DummyScheduler)):
            raise ValueError(
                "You cannot specify a scheduler in the config file and in the code at the same time. "
                "Please remove the scheduler from the config file or "
                "create `accelerate.utils.DummyScheduler` in the code."
            )
        elif "scheduler" not in deepspeed_plugin.deepspeed_config and isinstance(scheduler, (DummyScheduler)):
            raise ValueError(
                "You cannot create a `DummyScheduler` without specifying a scheduler in the config file."
            )

    if optimizer is not None and scheduler is not None:
        if isinstance(optimizer, (DummyOptim)) and not isinstance(scheduler, (DummyScheduler)):
            raise ValueError(
                "You can only specify `accelerate.utils.DummyScheduler` in the code when using "
                "`accelerate.utils.DummyOptim`."
            )

    # DeepSpeed is only initialized when a model was passed; otherwise only the dataloaders are prepared.
    if model is not None:
        # When the model exposes `config.hidden_size`, derive the ZeRO bucket/threshold sizes from it.
        if hasattr(model, "config") and hasattr(model.config, "hidden_size"):
            hidden_size = model.config.hidden_size
            config_kwargs.update(
                {
                    "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
                    "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
                    "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
                }
            )

        # Copy the hyper-parameters held by the placeholders into the config.
        if isinstance(optimizer, (DummyOptim)):
            config_kwargs.update(
                {"optimizer.params.lr": optimizer.lr, "optimizer.params.weight_decay": optimizer.weight_decay}
            )
        if isinstance(scheduler, (DummyScheduler)):
            config_kwargs.update(
                {
                    "scheduler.params.warmup_min_lr": 0,
                    "scheduler.params.warmup_max_lr": scheduler.optimizer.lr,
                    "scheduler.params.warmup_num_steps": scheduler.warmup_num_steps,
                }
            )
            # Without split batches the step count is divided across processes (presumably `total_num_steps`
            # counts steps over all processes; confirm with the DummyScheduler contract).
            if scheduler.total_num_steps is not None:
                config_kwargs["scheduler.params.total_num_steps"] = (
                    math.ceil(scheduler.total_num_steps / self.num_processes)
                    if not self.split_batches
                    else scheduler.total_num_steps
                )
        deepspeed_plugin.deepspeed_config_process(must_match=False, **config_kwargs)
        self.deepspeed_config = deepspeed_plugin.deepspeed_config
        kwargs = dict(model=model, config_params=self.deepspeed_config)
        if optimizer is not None:
            if isinstance(optimizer, (DummyOptim)):
                kwargs["model_parameters"] = optimizer.params
            else:
                kwargs["optimizer"] = optimizer
            # Only schedulers DeepSpeed knows by name are handed to `deepspeed.initialize`.
            if scheduler is not None:
                if type(scheduler).__name__ in deepspeed.runtime.lr_schedules.VALID_LR_SCHEDULES:
                    kwargs["lr_scheduler"] = scheduler

        engine, optimizer, _, lr_scheduler = deepspeed.initialize(**kwargs)
        if optimizer is not None:
            optimizer = DeepSpeedOptimizerWrapper(optimizer)
        if scheduler is not None:
            # DeepSpeed returned no scheduler: keep driving the user's scheduler through `AcceleratedScheduler`.
            if lr_scheduler is None:
                scheduler = AcceleratedScheduler(
                    scheduler,
                    optimizer,
                    step_with_optimizer=self.step_scheduler_with_optimizer,
                    split_batches=self.split_batches,
                )
            else:
                scheduler = DeepSpeedSchedulerWrapper(lr_scheduler, optimizer)

        # Swap the original objects for their DeepSpeed counterparts, keeping the caller's order.
        for i in range(len(result)):
            if isinstance(result[i], torch.nn.Module):
                result[i] = engine
            elif isinstance(result[i], (torch.optim.Optimizer, DummyOptim)):
                result[i] = optimizer
            elif (isinstance(result[i], (torch.optim.lr_scheduler._LRScheduler, DummyScheduler))) or (
                type(result[i]).__name__ in deepspeed.runtime.lr_schedules.VALID_LR_SCHEDULES
            ):
                result[i] = scheduler
        # pointing for deepspeed_engine_wrapped.backward()
        self.deepspeed_engine_wrapped = DeepSpeedEngineWrapper(engine)
        self._models.append(engine)
        if optimizer is not None:
            self._optimizers.append(optimizer)
        if scheduler is not None:
            self._schedulers.append(scheduler)
        if len(self._models) > 1:
            raise AssertionError(
                "You can't use same `Accelerator()` instance with multiple models when using DeepSpeed"
            )
    return tuple(result)
def prepare_data_loader(self, data_loader):
    """
    Forward `data_loader` to the `prepare_data_loader` utility together with this accelerator's device, process
    layout (number of processes and index of the current one) and batching options.

    Args:
        data_loader (`torch.utils.data.DataLoader`):
            The dataloader to prepare.

    Returns:
        Whatever the `prepare_data_loader` utility returns for `data_loader`.
    """
    # Automatic device placement is never requested when running on TPU.
    on_tpu = self.distributed_type == DistributedType.TPU
    move_to_device = False if on_tpu else self.device_placement
    return prepare_data_loader(
        data_loader,
        self.device,
        num_processes=self.num_processes,
        process_index=self.process_index,
        split_batches=self.split_batches,
        put_on_device=move_to_device,
        # A copy, so the utility cannot mutate this accelerator's own list.
        rng_types=self.rng_types.copy(),
        dispatch_batches=self.dispatch_batches,
    )
def prepare_optimizer(self, optimizer):
    """
    Wrap `optimizer` in an `AcceleratedOptimizer` configured with this accelerator's device placement flag and
    gradient scaler.

    Args:
        optimizer (`torch.optim.Optimizer`):
            The optimizer to wrap.
    """
    wrapped = AcceleratedOptimizer(
        optimizer,
        device_placement=self.device_placement,
        scaler=self.scaler,
    )
    return wrapped
def prepare_scheduler(self, scheduler):
    """
    Wrap `scheduler` in an `AcceleratedScheduler`.

    The scheduler is paired with the first prepared optimizer whose wrapped `.optimizer` equals the scheduler's own
    `optimizer` attribute. When no prepared optimizer matches, the whole list of prepared optimizers is passed.

    Args:
        scheduler:
            The learning-rate scheduler to wrap.
    """
    scheduler_optimizer = getattr(scheduler, "optimizer", None)
    candidates = (opt for opt in self._optimizers if scheduler_optimizer == opt.optimizer)
    # Fall back to the full list when nothing matches.
    optimizer = next(candidates, self._optimizers)
    return AcceleratedScheduler(
        scheduler,
        optimizer,
        step_with_optimizer=self.step_scheduler_with_optimizer,
        split_batches=self.split_batches,
    )
def backward(self, loss, **kwargs):
    """
    Use `accelerator.backward(loss)` in lieu of `loss.backward()`.

    The loss is divided by `gradient_accumulation_steps`. The backward pass is then dispatched to the wrapped
    DeepSpeed engine, to the AMP grad scaler when one is configured, or to plain `loss.backward()`.

    Args:
        loss (`torch.Tensor`):
            The loss to backpropagate. The caller's tensor is not modified.
        kwargs:
            Forwarded unchanged to the underlying backward call.
    """
    # Out-of-place division. An in-place `loss /= n` would mutate the caller's tensor, so e.g. logging
    # `loss.item()` afterwards would report the scaled value. It also raises on a leaf tensor that requires grad.
    loss = loss / self.gradient_accumulation_steps
    if self.distributed_type == DistributedType.DEEPSPEED:
        # NOTE(review): DeepSpeed engines may apply their own gradient-accumulation loss scaling inside
        # `backward`; confirm the loss is not scaled twice on this path.
        self.deepspeed_engine_wrapped.backward(loss, **kwargs)
    elif self.scaler is not None:
        self.scaler.scale(loss).backward(**kwargs)
    else:
        loss.backward(**kwargs)
def unscale_gradients(self, optimizer=None):
    """
    Unscale the gradients in mixed precision training with AMP. This is a noop in all other settings.

    Only acts when both `use_fp16` and `native_amp` are set. Any `AcceleratedOptimizer` wrappers are peeled off
    before the grad scaler's `unscale_` is called on the innermost optimizer.

    Args:
        optimizer (`torch.optim.Optimizer` or `List[torch.optim.Optimizer]`, *optional*):
            The optimizer(s) for which to unscale gradients. If not set, will unscale gradients on all optimizers
            that were passed to [`~Accelerator.prepare`].
    """
    if not (self.use_fp16 and self.native_amp):
        return
    if optimizer is None:
        # TODO: this unscales all optimizers where we should only unscale the one where parameters are.
        targets = self._optimizers
    elif isinstance(optimizer, (tuple, list)):
        targets = optimizer
    else:
        targets = [optimizer]
    for wrapped in targets:
        inner = wrapped
        # Unwrap nested AcceleratedOptimizer layers down to the innermost optimizer.
        while isinstance(inner, AcceleratedOptimizer):
            inner = inner.optimizer
        self.scaler.unscale_(inner)
def clip_grad_norm_(self, parameters, max_norm, norm_type=2):
    """
    Should be used in place of `torch.nn.utils.clip_grad_norm_`.

    Gradients are unscaled exactly once, then clipped:
      - under DeepSpeed nothing is done here;
      - under FSDP, if `parameters` are exactly the parameters of a prepared model, that model's own
        `clip_grad_norm_` is used;
      - otherwise `torch.nn.utils.clip_grad_norm_` is applied to `parameters`.

    Args:
        parameters (iterable of `torch.Tensor`):
            The parameters whose gradients should be clipped.
        max_norm (`float` or `int`):
            Max norm of the gradients.
        norm_type (`float` or `int`, *optional*, defaults to 2):
            Type of the p-norm to use.

    Returns:
        The value returned by the underlying clipping call (the total gradient norm for
        `torch.nn.utils.clip_grad_norm_`), or `None` under DeepSpeed.
    """
    if self.distributed_type == DistributedType.DEEPSPEED:
        # `accelerator.backward(loss)` is doing that automatically. Therefore, its implementation is not needed
        return None
    # Unscale only once. `GradScaler.unscale_` may be called just once per optimizer between `update()` calls,
    # so the FSDP fall-through below must not unscale a second time.
    self.unscale_gradients()
    if self.distributed_type == DistributedType.FSDP:
        parameters = list(parameters)
        for model in self._models:
            model_parameters = list(model.parameters())
            # Compare by identity. `==` on parameter tensors is elementwise and cannot be used as a truth value.
            if len(model_parameters) == len(parameters) and all(
                param is model_param for param, model_param in zip(parameters, model_parameters)
            ):
                return model.clip_grad_norm_(max_norm, norm_type)
    return torch.nn.utils.clip_grad_norm_(parameters, max_norm, norm_type=norm_type)
def clip_grad_value_(self, parameters, clip_value):
    """
    Should be used in place of `torch.nn.utils.clip_grad_value_`.

    Gradients are passed through `unscale_gradients` before being clamped to `clip_value`.

    Args:
        parameters (iterable of `torch.Tensor`):
            The parameters whose gradients should be clipped.
        clip_value (`float` or `int`):
            Maximum allowed absolute value of the gradients.

    Raises:
        Exception: when running with DeepSpeed or FSDP, which do not support value clipping.
    """
    unsupported = (DistributedType.DEEPSPEED, DistributedType.FSDP)
    if self.distributed_type in unsupported:
        raise Exception("DeepSpeed and FSDP do not support `clip_grad_value_`. Use `clip_grad_norm_` instead.")
    self.unscale_gradients()
    torch.nn.utils.clip_grad_value_(parameters, clip_value)
def gather(self, tensor):
    """
    Collect the values in *tensor* from every process and concatenate them along the first dimension. Handy for
    regrouping the per-process predictions during evaluation.

    Note:
        The gather is performed on all processes.

    Args:
        tensor (`torch.Tensor`, or a nested tuple/list/dictionary of `torch.Tensor`):
            The tensors to gather across all processes.

    Returns:
        `torch.Tensor`, or a nested tuple/list/dictionary of `torch.Tensor`: The gathered tensor(s). The first
        dimension of the result is *num_processes* times the first dimension of the input tensors.
    """
    gathered = gather(tensor)
    return gathered
def gather_for_metrics(self, tensor, dataloader):
    """
    Gathers `tensor` and potentially drops duplicates in the last batch if on a distributed system. Should be used
    for gathering the inputs and targets for metric calculation.

    On a distributed setup, once the eval dataloader reports its last batch, every gathered tensor is truncated to
    `dataloader.total_dataset_length - gradient_state.samples_seen` rows. This is best effort: if that information
    is unavailable (e.g. the dataset has no length), the gathered tensors are returned unchanged.

    Args:
        tensor (`torch.Tensor`, or a nested tuple/list/dictionary of `torch.Tensor`):
            The tensors for calculating metrics across all processes.
        dataloader (`torch.utils.data.DataLoader`):
            A dataloader prepared with `Accelerator.prepare`

    Returns:
        The gathered tensor(s), truncated on the last batch when possible.
    """
    tensor = self.gather(tensor)
    if not self.use_distributed:
        return tensor
    try:
        if not self.gradient_state.end_of_dataloader:
            # Not at the end of the dataloader, no need to adjust the tensors
            return tensor
        # Last batch needs to be truncated on distributed systems as it contains additional samples
        remainder = dataloader.total_dataset_length - self.gradient_state.samples_seen

        def _adjust_samples(t):
            return t[:remainder]

        return recursively_apply(_adjust_samples, tensor)
    except Exception:
        # Dataset had no length or raised an error: keep the best-effort fallback. Catching `Exception` (rather
        # than a bare `except:`) lets KeyboardInterrupt/SystemExit propagate.
        return tensor
def reduce(self, tensor, reduction="sum"):
    """
    Combine the values in *tensor* across all processes using the *reduction* strategy.

    Note:
        Every process receives the reduced value.

    Args:
        tensor (`torch.Tensor`, or a nested tuple/list/dictionary of `torch.Tensor`):
            The tensors to reduce across all processes.
        reduction (`str`, *optional*, defaults to "sum"):
            One of 'sum', 'mean', or 'none'. With 'none', no operation is performed.

    Returns:
        `torch.Tensor`, or a nested tuple/list/dictionary of `torch.Tensor`: The reduced tensor(s).
    """
    reduced = reduce(tensor, reduction)
    return reduced
def pad_across_processes(self, tensor, dim=0, pad_index=0, pad_first=False):
    """
    Recursively pad every tensor in a nested list/tuple/dictionary so that the tensors from all devices share the
    same size and can therefore be gathered safely.

    Args:
        tensor (nested list/tuple/dictionary of `torch.Tensor`):
            The data to gather.
        dim (`int`, *optional*, defaults to 0):
            The dimension along which to pad.
        pad_index (`int`, *optional*, defaults to 0):
            The value used for padding.
        pad_first (`bool`, *optional*, defaults to `False`):
            Pad at the beginning when `True`, at the end otherwise.
    """
    padded = pad_across_processes(tensor, dim=dim, pad_index=pad_index, pad_first=pad_first)
    return padded
def unwrap_model(self, model):
    """
    Remove the extra wrapping layer that [`~Accelerator.prepare`] may have added around `model`. Useful before
    saving the model.

    Args:
        model (`torch.nn.Module`):
            The model to unwrap.

    Returns:
        The model as returned by `extract_model_from_parallel`.
    """
    bare_model = extract_model_from_parallel(model)
    return bare_model
def wait_for_everyone(self):
    """
    Block the current process until every other process reaches this same point. With a single process this
    does nothing. Typically called right before saving a model.
    """
    # Delegates to the module-level `wait_for_everyone` utility (not this method).
    wait_for_everyone()
def init_trackers(self, project_name: str, config: Optional[dict] = None, init_kwargs: Optional[dict] = None):
    """
    Initializes a run for all trackers stored in `self.log_with`, potentially with starting configurations

    Tracker instances that are already `GeneralTracker`s are used as-is. Any other entry is looked up by its string
    name in `LOGGER_TYPE_TO_CLASS` and constructed. The resulting list replaces `self.trackers`.

    Args:
        project_name (`str`):
            The name of the project. All trackers will save their data based on this
        config (`dict`, *optional*):
            Optional starting configuration to be logged.
        init_kwargs (`dict`, *optional*):
            A nested dictionary of kwargs to be passed to a specific tracker's `__init__` function. Should be
            formatted like this:

            ```python
            {"wandb": {"tags": ["tag_a", "tag_b"]}}
            ```
    """
    # `None` (allowed by the `Optional` annotation) means "no extra kwargs for any tracker". A fresh dict per call
    # replaces the former shared mutable default.
    if init_kwargs is None:
        init_kwargs = {}
    self.trackers = []
    for tracker in self.log_with:
        if isinstance(tracker, GeneralTracker):
            # Custom trackers are already initialized
            self.trackers.append(tracker)
            continue
        tracker_name = str(tracker)
        tracker_init = LOGGER_TYPE_TO_CLASS[tracker_name]
        tracker_kwargs = init_kwargs.get(tracker_name, {})
        if getattr(tracker_init, "requires_logging_directory"):
            # We can skip this check since it was done in `__init__`
            self.trackers.append(tracker_init(project_name, self.logging_dir, **tracker_kwargs))
        else:
            self.trackers.append(tracker_init(project_name, **tracker_kwargs))
    if config is not None:
        for tracker in self.trackers:
            tracker.store_init_configuration(config)
def log(self, values: dict, step: Optional[int] = None, log_kwargs: Optional[dict] = {}):
"""
Logs `values` to all stored trackers in `self.trackers`.
Args:
values (`dict`):
Values should be a dictionary-like object containing only types `int`, `float`, or `str`.