From 0c6bdc2c237ac071be99ac6f93ddfbc8bbcb8441 Mon Sep 17 00:00:00 2001 From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com> Date: Tue, 26 Jul 2022 18:14:29 +0530 Subject: [PATCH] enhancements and fixes for FSDP and DeepSpeed (#532) * checkpointing enhancements and fixes for FSDP and DeepSpeed * resolving comments 1. Adding deprecation args and warnings in launcher for FSDP 2. Handling old configs to work with new launcher args wrt FSDP. 3. Reverting changes to public methods in `checkpointing.py` and handling it in `Accelerator` 4. Explicitly writing the defaults of various FSDP options in `dataclasses` for readability. * fixes 1. FSDP wrapped model being added to the `_models`. 2. Not passing the env variables when args are None. * resolving comments * adding FSDP for all the collective operations * adding deepspeed and fsdp tests 1. Removes mrpc datafiles and directly relies on HF datasets as it was throwing `file not found` error when running from within `tests` folder. Updating `moke_dataloaders` as a result. 2. adding `test_performance.py`, `test_memory.py` and `test_checkpointing.py` for multi-gpu FSDP and DeepSpeed tests * reverting `mocked_dataloader` changes * adding FSDP tests * data files revert * excluding fsdp tests from `tests_core` * try 2 * adding time delay to avoid `torchrun` from crashing at times leading which causing flaky behaviour * reducing the time of tests * fixes * fix * fixes and reduce time further * reduce time further and minor fixes * adding a deepspeed basic e2e test for single gpu setup --- .github/workflows/test.yml | 1 + Makefile | 6 +- src/accelerate/accelerator.py | 84 ++++- src/accelerate/commands/config/cluster.py | 22 +- src/accelerate/commands/launch.py | 93 ++++- .../test_utils/scripts/test_checkpointing.py | 269 ++++++++++++++ .../scripts/test_peak_memory_usage.py | 258 ++++++++++++++ .../test_utils/scripts/test_performance.py | 231 ++++++++++++ src/accelerate/test_utils/testing.py | 8 + src/accelerate/utils/constants.py | 1 + src/accelerate/utils/dataclasses.py | 120 ++++++- src/accelerate/utils/operations.py | 24 +- src/accelerate/utils/other.py | 1 + tests/deepspeed/test_deepspeed.py | 231 +++++++++++- tests/fsdp/test_fsdp.py | 332 ++++++++++++++++++ 15 files changed, 1643 insertions(+), 38 deletions(-) create mode 100644 src/accelerate/test_utils/scripts/test_checkpointing.py create mode 100644 src/accelerate/test_utils/scripts/test_peak_memory_usage.py create mode 100644 src/accelerate/test_utils/scripts/test_performance.py create mode 100644 tests/fsdp/test_fsdp.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 09330eb63a2..806f0094d6a 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -16,6 +16,7 @@ jobs: test_core, test_big_modeling, test_deepspeed, + test_fsdp, test_example_differences, test_checkpoint_step, test_checkpoint_epoch, diff --git a/Makefile b/Makefile index 4104f1bba02..9cae9d74d68 100644 --- a/Makefile +++ b/Makefile @@ -31,11 +31,15 @@ test_big_modeling: python -m pytest -s -v ./tests/test_big_modeling.py test_core: - python -m pytest -s -v ./tests/ --ignore=./tests/test_examples.py --ignore=./tests/deepspeed --ignore=./tests/test_big_modeling.py + python -m pytest -s -v ./tests/ --ignore=./tests/test_examples.py --ignore=./tests/deepspeed --ignore=./tests/test_big_modeling.py \ + --ignore=./tests/fsdp test_deepspeed: python -m pytest -s -v ./tests/deepspeed +test_fsdp: + python -m pytest -s -v ./tests/fsdp + test_examples: python -m pytest -s -v ./tests/test_examples.py diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index b47a2dc1539..acf960026b2 100644 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -32,6 +32,7 @@ from .state import AcceleratorState, GradientState from .tracking import LOGGER_TYPE_TO_CLASS, GeneralTracker, filter_trackers from .utils import ( + MODEL_NAME, DeepSpeedPlugin, DistributedDataParallelKwargs, DistributedType, @@ -572,7 +573,8 @@ def prepare(self, *args): if model_count > 1 and optimizer_present: raise ValueError( "For FSDP to work with multiple models (>1), " - "prepare must be called for all the models before optimizers are created" + "prepare must be called for all the models before optimizers are created. " + "Then pass the optimizers to the prepare call in the same order as corresponding models." ) elif model_count == 1 and optimizer_present: logger.warn( @@ -649,6 +651,7 @@ def prepare_model(self, model): ) if not fsdp_plugin.cpu_offload.offload_params: model.to(self.device) + self._models[-1] = model elif self.distributed_type == DistributedType.MULTI_CPU: kwargs = self.ddp_handler.to_kwargs() if self.ddp_handler is not None else {} model = torch.nn.parallel.DistributedDataParallel(model, **kwargs) @@ -1098,9 +1101,44 @@ def save_state(self, output_dir: str): output_dir = os.path.expanduser(output_dir) os.makedirs(output_dir, exist_ok=True) logger.info(f"Saving current state to {output_dir}") - weights = [self.get_state_dict(m, unwrap=False) for m in self._models] + + # Save the models taking care of FSDP and DeepSpeed nuances + weights = [] + for i, model in enumerate(self._models): + if self.distributed_type == DistributedType.FSDP: + logger.info("Saving FSDP model") + self.state.fsdp_plugin.save_model(self, model, output_dir, i) + logger.info(f"FSDP Model saved to output dir {output_dir}") + elif self.distributed_type == DistributedType.DEEPSPEED: + logger.info("Saving DeepSpeed Model and Optimizer") + ckpt_id = f"{MODEL_NAME}" if i == 0 else f"{MODEL_NAME}_{i}" + model.save_checkpoint(output_dir, ckpt_id) + logger.info(f"DeepSpeed Model and Optimizer saved to output dir {os.path.join(output_dir, ckpt_id)}") + else: + weights.append(self.get_state_dict(model, unwrap=False)) + + # Save the optimizers taking care of FSDP and DeepSpeed nuances + optimizers = [] + if self.distributed_type == DistributedType.FSDP: + for opt in self._optimizers: + logger.info("Saving FSDP Optimizer") + self.state.fsdp_plugin.save_optimizer(self, opt, self._models[i], output_dir, i) + logger.info(f"FSDP Optimizer saved to output dir {output_dir}") + elif self.distributed_type != DistributedType.DEEPSPEED: + optimizers = self._optimizers + + # Save the lr schedulers taking care of DeepSpeed nuances + schedulers = [] + if self.distributed_type == DistributedType.DEEPSPEED: + for i, scheduler in enumerate(self._schedulers): + if isinstance(scheduler, DeepSpeedSchedulerWrapper): + continue + schedulers.append(scheduler) + else: + schedulers = self._schedulers + save_location = save_accelerator_state( - output_dir, weights, self._optimizers, self._schedulers, self.state.process_index, self.scaler + output_dir, weights, optimizers, schedulers, self.state.process_index, self.scaler ) for i, obj in enumerate(self._custom_objects): save_custom_state(obj, output_dir, i) @@ -1119,9 +1157,43 @@ def load_state(self, input_dir: str): if not os.path.isdir(input_dir): raise ValueError(f"Tried to find {input_dir} but folder does not exist") logger.info(f"Loading states from {input_dir}") - load_accelerator_state( - input_dir, self._models, self._optimizers, self._schedulers, self.state.process_index, self.scaler - ) + + # Load the models taking care of FSDP and DeepSpeed nuances + models = [] + for i, model in enumerate(self._models): + if self.distributed_type == DistributedType.FSDP: + logger.info("Loading FSDP model") + self.state.fsdp_plugin.load_model(self, model, input_dir, i) + logger.info(f"FSDP Model loaded from input dir {input_dir}") + elif self.distributed_type == DistributedType.DEEPSPEED: + logger.info("Loading DeepSpeed Model and Optimizer") + ckpt_id = f"{MODEL_NAME}" if i == 0 else f"{MODEL_NAME}_{i}" + model.load_checkpoint(input_dir, ckpt_id) + logger.info(f"DeepSpeed Model and Optimizer loaded from input dir {os.path.join(input_dir, ckpt_id)}") + else: + models.append(model) + + # Load the optimizers taking care of FSDP and DeepSpeed nuances + optimizers = [] + if self.distributed_type == DistributedType.FSDP: + for i, opt in enumerate(self._optimizers): + logger.info("Loading FSDP Optimizer") + self.state.fsdp_plugin.load_optimizer(self, opt, self._models[i], input_dir, i) + logger.info(f"FSDP Optimizer loaded from input dir {input_dir}") + elif self.distributed_type != DistributedType.DEEPSPEED: + optimizers = self._optimizers + + # Load the lr schedulers taking care of DeepSpeed nuances + schedulers = [] + if self.distributed_type == DistributedType.DEEPSPEED: + for i, scheduler in enumerate(self._schedulers): + if isinstance(scheduler, DeepSpeedSchedulerWrapper): + continue + schedulers.append(scheduler) + else: + schedulers = self._schedulers + + load_accelerator_state(input_dir, models, optimizers, schedulers, self.state.process_index, self.scaler) custom_checkpoints = [f for f in os.listdir(input_dir) if "custom_checkpoint" in f] if len(custom_checkpoints) != len(self._custom_objects): err = "Warning! Number of found checkpoints does not match the number of registered objects:" diff --git a/src/accelerate/commands/config/cluster.py b/src/accelerate/commands/config/cluster.py index fd345fa1ccc..ba677e3eb9b 100644 --- a/src/accelerate/commands/config/cluster.py +++ b/src/accelerate/commands/config/cluster.py @@ -20,6 +20,7 @@ FSDP_AUTO_WRAP_POLICY, FSDP_BACKWARD_PREFETCH, FSDP_SHARDING_STRATEGY, + FSDP_STATE_DICT_TYPE, ) from .config_args import ClusterConfig from .config_utils import _ask_field, _convert_distributed_mode, _convert_yes_no_to_bool @@ -210,12 +211,12 @@ def get_cluster_input(): for i, strategy in enumerate(FSDP_SHARDING_STRATEGY): sharding_strategy_query += f"[{i+1}] {strategy}, " sharding_strategy_query = sharding_strategy_query[:-2] + ")? [1]: " - fsdp_config["sharding_strategy"] = _ask_field( + fsdp_config["fsdp_sharding_strategy"] = _ask_field( sharding_strategy_query, lambda x: int(x), default=1, ) - fsdp_config["offload_params"] = _ask_field( + fsdp_config["fsdp_offload_params"] = _ask_field( "Do you want to offload parameters and gradients to CPU? [yes/NO]: ", _convert_yes_no_to_bool, default=False, @@ -228,15 +229,15 @@ def get_cluster_input(): fsdp_config["fsdp_auto_wrap_policy"] = _ask_field( fsdp_wrap_query, lambda x: FSDP_AUTO_WRAP_POLICY[int(x)], - default=FSDP_AUTO_WRAP_POLICY[0], + default="TRANSFORMER_BASED_WRAP", ) if fsdp_config["fsdp_auto_wrap_policy"] == FSDP_AUTO_WRAP_POLICY[0]: - fsdp_config["transformer_layer_cls_to_wrap"] = _ask_field( + fsdp_config["fsdp_transformer_layer_cls_to_wrap"] = _ask_field( "What is the transformer layer class name (case-sensitive) to wrap ,e.g, `BertLayer`, `GPTJBlock`, `T5Block` ...? : ", lambda x: str(x), ) elif fsdp_config["fsdp_auto_wrap_policy"] == FSDP_AUTO_WRAP_POLICY[1]: - fsdp_config["min_num_params"] = _ask_field( + fsdp_config["fsdp_min_num_params"] = _ask_field( "What should be your FSDP's minimum number of parameters for Default Auto Wrapping Policy? [1e8]: ", lambda x: int(x), default=1e8, @@ -248,7 +249,16 @@ def get_cluster_input(): fsdp_config["fsdp_backward_prefetch_policy"] = _ask_field( fsdp_backward_prefetch_query, lambda x: FSDP_BACKWARD_PREFETCH[int(x)], - default=FSDP_BACKWARD_PREFETCH[0], + default="BACKWARD_PRE", + ) + fsdp_state_dict_type_query = "What should be your FSDP's state dict type (" + for i, state_dict_type in enumerate(FSDP_STATE_DICT_TYPE): + fsdp_state_dict_type_query += f"[{i}] {state_dict_type}, " + fsdp_state_dict_type_query = fsdp_state_dict_type_query[:-2] + ")? [0]: " + fsdp_config["fsdp_state_dict_type"] = _ask_field( + fsdp_state_dict_type_query, + lambda x: FSDP_STATE_DICT_TYPE[int(x)], + default="FULL_STATE_DICT", ) if distributed_type == DistributedType.TPU: diff --git a/src/accelerate/commands/launch.py b/src/accelerate/commands/launch.py index e6a23e46df9..514e77d0517 100644 --- a/src/accelerate/commands/launch.py +++ b/src/accelerate/commands/launch.py @@ -148,19 +148,19 @@ def launch_command_parser(subparsers=None): help="Whether to use fsdp.", ) parser.add_argument( - "--offload_params", + "--fsdp_offload_params", default="false", type=str, help="Decides Whether (true|false) to offload parameters and gradients to CPU. (useful only when `use_fsdp` flag is passed).", ) parser.add_argument( - "--min_num_params", + "--fsdp_min_num_params", type=int, default=1e8, help="FSDP's minimum number of parameters for Default Auto Wrapping. (useful only when `use_fsdp` flag is passed).", ) parser.add_argument( - "--sharding_strategy", + "--fsdp_sharding_strategy", type=int, default=1, help="FSDP's Sharding Strategy. (useful only when `use_fsdp` flag is passed).", @@ -172,7 +172,7 @@ def launch_command_parser(subparsers=None): help="FSDP's auto wrap policy. (useful only when `use_fsdp` flag is passed).", ) parser.add_argument( - "--transformer_layer_cls_to_wrap", + "--fsdp_transformer_layer_cls_to_wrap", default=None, type=str, help="Transformer layer class name (case-sensitive) to wrap ,e.g, `BertLayer`, `GPTJBlock`, `T5Block` .... " @@ -184,6 +184,36 @@ def launch_command_parser(subparsers=None): type=str, help="FSDP's backward prefetch policy. (useful only when `use_fsdp` flag is passed).", ) + parser.add_argument( + "--fsdp_state_dict_type", + default=None, + type=str, + help="FSDP's state dict type. (useful only when `use_fsdp` flag is passed).", + ) + parser.add_argument( + "--offload_params", + default=None, + type=str, + help="This argument is deprecated. Use `fsdp_offload_params` instead.", + ) + parser.add_argument( + "--min_num_params", + type=int, + default=None, + help="This argument is deprecated. Use `fsdp_min_num_params` instead.", + ) + parser.add_argument( + "--sharding_strategy", + type=int, + default=None, + help="This argument is deprecated. Use `fsdp_sharding_strategy` instead.", + ) + parser.add_argument( + "--transformer_layer_cls_to_wrap", + default=None, + type=str, + help="This argument is deprecated. Use `fsdp_transformer_layer_cls_to_wrap` instead.", + ) parser.add_argument( "--tpu", default=False, action="store_true", help="Whether or not this should launch a TPU training." ) @@ -360,13 +390,51 @@ def multi_gpu_launcher(args): current_env["MIXED_PRECISION"] = str(mixed_precision) if args.use_fsdp: + if args.sharding_strategy is not None: + warnings.warn( + "`sharding_strategy` is deprecated and will be removed in version 0.13.0 of 🤗 Accelerate. Use" + " `fsdp_sharding_strategy` instead", + FutureWarning, + ) + args.fsdp_sharding_strategy = args.sharding_strategy + + if args.offload_params is not None: + warnings.warn( + "`offload_params` is deprecated and will be removed in version 0.13.0 of 🤗 Accelerate. Use" + " `fsdp_offload_params` instead", + FutureWarning, + ) + args.fsdp_offload_params = args.offload_params + + if args.min_num_params is not None: + warnings.warn( + "`min_num_params` is deprecated and will be removed in version 0.13.0 of 🤗 Accelerate. Use" + " `fsdp_min_num_params` instead", + FutureWarning, + ) + args.fsdp_min_num_params = args.min_num_params + + if args.transformer_layer_cls_to_wrap is not None: + warnings.warn( + "`transformer_layer_cls_to_wrap` is deprecated and will be removed in version 0.13.0 of 🤗 Accelerate. Use" + " `fsdp_transformer_layer_cls_to_wrap` instead", + FutureWarning, + ) + args.fsdp_transformer_layer_cls_to_wrap = args.transformer_layer_cls_to_wrap + current_env["USE_FSDP"] = "true" - current_env["FSDP_AUTO_WRAP_POLICY"] = str(args.fsdp_auto_wrap_policy) - current_env["FSDP_TRANSFORMER_CLS_TO_WRAP"] = str(args.transformer_layer_cls_to_wrap) - current_env["FSDP_OFFLOAD_PARAMS"] = str(args.offload_params).lower() - current_env["FSDP_MIN_NUM_PARAMS"] = str(args.min_num_params) - current_env["FSDP_SHARDING_STRATEGY"] = str(args.sharding_strategy) - current_env["FSDP_BACKWARD_PREFETCH"] = str(args.fsdp_backward_prefetch_policy) + current_env["FSDP_SHARDING_STRATEGY"] = str(args.fsdp_sharding_strategy) + current_env["FSDP_OFFLOAD_PARAMS"] = str(args.fsdp_offload_params).lower() + current_env["FSDP_MIN_NUM_PARAMS"] = str(args.fsdp_min_num_params) + if args.fsdp_auto_wrap_policy is not None: + current_env["FSDP_AUTO_WRAP_POLICY"] = str(args.fsdp_auto_wrap_policy) + if args.fsdp_transformer_layer_cls_to_wrap is not None: + current_env["FSDP_TRANSFORMER_CLS_TO_WRAP"] = str(args.fsdp_transformer_layer_cls_to_wrap) + if args.fsdp_backward_prefetch_policy is not None: + current_env["FSDP_BACKWARD_PREFETCH"] = str(args.fsdp_backward_prefetch_policy) + if args.fsdp_state_dict_type is not None: + current_env["FSDP_STATE_DICT_TYPE"] = str(args.fsdp_state_dict_type) + current_env["OMP_NUM_THREADS"] = str(args.num_cpu_threads_per_process) process = subprocess.Popen(cmd, env=current_env) process.wait() @@ -682,7 +750,10 @@ def launch_command(args): if getattr(args, k) is None: setattr(args, k, defaults.deepspeed_config[k]) for k in defaults.fsdp_config: - setattr(args, k, defaults.fsdp_config[k]) + arg_to_set = k + if "fsdp" not in arg_to_set: + arg_to_set = "fsdp_" + arg_to_set + setattr(args, arg_to_set, defaults.fsdp_config[k]) continue # Those args are handled separately diff --git a/src/accelerate/test_utils/scripts/test_checkpointing.py b/src/accelerate/test_utils/scripts/test_checkpointing.py new file mode 100644 index 00000000000..cde602dfa63 --- /dev/null +++ b/src/accelerate/test_utils/scripts/test_checkpointing.py @@ -0,0 +1,269 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import json +import os + +import torch +from torch.optim import AdamW +from torch.utils.data import DataLoader + +import evaluate +from accelerate import Accelerator, DistributedType +from accelerate.utils.deepspeed import DummyOptim, DummyScheduler +from datasets import load_dataset +from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed + + +MAX_GPU_BATCH_SIZE = 16 +EVAL_BATCH_SIZE = 32 + + +def get_dataloaders(accelerator: Accelerator, batch_size: int = 16, model_name: str = "bert-base-cased"): + """ + Creates a set of `DataLoader`s for the `glue` dataset. + + Args: + accelerator (`Accelerator`): + An `Accelerator` object + batch_size (`int`, *optional*): + The batch size for the train and validation DataLoaders. + model_name (`str`, *optional*): + """ + tokenizer = AutoTokenizer.from_pretrained(model_name) + datasets = load_dataset("glue", "mrpc") + + def tokenize_function(examples): + # max_length=None => use the model max length (it's actually the default) + outputs = tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None) + return outputs + + # Apply the method we just defined to all the examples in all the splits of the dataset + tokenized_datasets = datasets.map( + tokenize_function, batched=True, remove_columns=["idx", "sentence1", "sentence2"], load_from_cache_file=False + ) + + # We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the + # transformers library + tokenized_datasets = tokenized_datasets.rename_column("label", "labels") + + def collate_fn(examples): + # On TPU it's best to pad everything to the same length or training will be very slow. + if accelerator.distributed_type == DistributedType.TPU: + return tokenizer.pad(examples, padding="max_length", max_length=128, return_tensors="pt") + return tokenizer.pad(examples, padding="longest", return_tensors="pt") + + # Instantiate dataloaders. + train_dataloader = DataLoader( + tokenized_datasets["train"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size + ) + eval_dataloader = DataLoader( + tokenized_datasets["validation"], shuffle=False, collate_fn=collate_fn, batch_size=EVAL_BATCH_SIZE + ) + + return train_dataloader, eval_dataloader + + +def evaluation_loop(accelerator, model, eval_dataloader, metric): + model.eval() + samples_seen = 0 + for step, batch in enumerate(eval_dataloader): + # We could avoid this line since we set the accelerator with `device_placement=True`. + batch.to(accelerator.device) + with torch.no_grad(): + outputs = model(**batch) + predictions = outputs.logits.argmax(dim=-1) + # It is slightly faster to call this once, than multiple times + predictions, references = accelerator.gather( + (predictions, batch["labels"]) + ) # If we are in a multiprocess environment, the last batch has duplicates + if accelerator.use_distributed: + if step == len(eval_dataloader) - 1: + predictions = predictions[: len(eval_dataloader.dataset) - samples_seen] + references = references[: len(eval_dataloader.dataset) - samples_seen] + else: + samples_seen += references.shape[0] + metric.add_batch( + predictions=predictions, + references=references, + ) + + eval_metric = metric.compute() + return eval_metric["accuracy"] + + +def training_function(config, args): + # Initialize accelerator + accelerator = Accelerator() + + # Sample hyper-parameters for learning rate, batch size, seed and a few other HPs + lr = config["lr"] + num_epochs = int(config["num_epochs"]) + seed = int(config["seed"]) + batch_size = int(config["batch_size"]) + model_name = args.model_name_or_path + + set_seed(seed) + train_dataloader, eval_dataloader = get_dataloaders(accelerator, batch_size, model_name) + + # Instantiate the model (we build the model here so that the seed also control new weights initialization) + model = AutoModelForSequenceClassification.from_pretrained(model_name, return_dict=True) + + # Instantiate optimizer + optimizer_cls = ( + AdamW + if accelerator.state.deepspeed_plugin is None + or "optimizer" not in accelerator.state.deepspeed_plugin.deepspeed_config + else DummyOptim + ) + optimizer = optimizer_cls(params=model.parameters(), lr=lr) + + if accelerator.state.deepspeed_plugin is not None: + gradient_accumulation_steps = accelerator.state.deepspeed_plugin.deepspeed_config[ + "gradient_accumulation_steps" + ] + else: + gradient_accumulation_steps = 1 + max_training_steps = (len(train_dataloader) * num_epochs) // gradient_accumulation_steps + + # Instantiate scheduler + if ( + accelerator.state.deepspeed_plugin is None + or "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config + ): + lr_scheduler = get_linear_schedule_with_warmup( + optimizer=optimizer, + num_warmup_steps=0, + num_training_steps=max_training_steps, + ) + else: + lr_scheduler = DummyScheduler(optimizer, total_num_steps=max_training_steps, warmup_num_steps=0) + + # Prepare everything + # There is no specific order to remember, we just need to unpack the objects in the same order we gave them to the + # prepare method. + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare( + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler + ) + + # We need to keep track of how many total steps we have iterated over + overall_step = 0 + # We also need to keep track of the stating epoch so files are named properly + starting_epoch = 0 + metric = evaluate.load("glue", "mrpc") + ending_epoch = num_epochs + + if args.partial_train_epoch is not None: + ending_epoch = args.partial_train_epoch + + if args.resume_from_checkpoint: + accelerator.load_state(args.resume_from_checkpoint) + epoch_string = args.resume_from_checkpoint.split("epoch_")[1] + state_epoch_num = "" + for char in epoch_string: + if char.isdigit(): + state_epoch_num += char + else: + break + starting_epoch = int(state_epoch_num) + 1 + accuracy = evaluation_loop(accelerator, model, eval_dataloader, metric) + accelerator.print("resumed checkpoint performance:", accuracy) + accelerator.print("resumed checkpoint's scheduler's lr:", lr_scheduler.get_lr()[0]) + accelerator.print("resumed optimizers's lr:", optimizer.param_groups[0]["lr"]) + with open(os.path.join(args.output_dir, f"state_{starting_epoch-1}.json"), "r") as f: + resumed_state = json.load(f) + assert resumed_state["accuracy"] == accuracy, "Accuracy mismatch, loading from checkpoint failed" + assert ( + resumed_state["lr"] == lr_scheduler.get_lr()[0] + ), "Scheduler learning rate mismatch, loading from checkpoint failed" + assert ( + resumed_state["optimizer_lr"] == optimizer.param_groups[0]["lr"] + ), "Optimizer learning rate mismatch, loading from checkpoint failed" + assert resumed_state["epoch"] == starting_epoch - 1, "Epoch mismatch, loading from checkpoint failed" + return + + # Now we train the model + state = {} + for epoch in range(starting_epoch, ending_epoch): + model.train() + for step, batch in enumerate(train_dataloader): + outputs = model(**batch) + loss = outputs.loss + loss = loss / gradient_accumulation_steps + accelerator.backward(loss) + if step % gradient_accumulation_steps == 0: + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + overall_step += 1 + output_dir = f"epoch_{epoch}" + output_dir = os.path.join(args.output_dir, output_dir) + accelerator.save_state(output_dir) + accuracy = evaluation_loop(accelerator, model, eval_dataloader, metric) + state["accuracy"] = accuracy + state["lr"] = lr_scheduler.get_lr()[0] + state["optimizer_lr"] = optimizer.param_groups[0]["lr"] + state["epoch"] = epoch + state["step"] = overall_step + accelerator.print(f"epoch {epoch}:", state) + + accelerator.wait_for_everyone() + if accelerator.is_main_process: + with open(os.path.join(args.output_dir, f"state_{epoch}.json"), "w") as f: + json.dump(state, f) + + +def main(): + parser = argparse.ArgumentParser(description="Simple example of training script tracking peak GPU memory usage.") + parser.add_argument( + "--model_name_or_path", + type=str, + default="bert-base-cased", + help="Path to pretrained model or model identifier from huggingface.co/models.", + required=False, + ) + parser.add_argument( + "--output_dir", + type=str, + default=".", + help="Optional save directory where all checkpoint folders will be stored. Default is the current working directory.", + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help="If the training should continue from a checkpoint folder.", + ) + parser.add_argument( + "--partial_train_epoch", + type=int, + default=None, + help="If passed, the training will stop after this number of epochs.", + ) + parser.add_argument( + "--num_epochs", + type=int, + default=2, + help="Number of train epochs.", + ) + args = parser.parse_args() + config = {"lr": 2e-5, "num_epochs": args.num_epochs, "seed": 42, "batch_size": 16} + + training_function(config, args) + + +if __name__ == "__main__": + main() diff --git a/src/accelerate/test_utils/scripts/test_peak_memory_usage.py b/src/accelerate/test_utils/scripts/test_peak_memory_usage.py new file mode 100644 index 00000000000..7bb5ca3bf41 --- /dev/null +++ b/src/accelerate/test_utils/scripts/test_peak_memory_usage.py @@ -0,0 +1,258 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import gc +import json +import os + +import torch +from torch.optim import AdamW +from torch.utils.data import DataLoader + +from accelerate import Accelerator, DistributedType +from accelerate.utils.deepspeed import DummyOptim, DummyScheduler +from datasets import load_dataset +from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed + + +MAX_GPU_BATCH_SIZE = 16 +EVAL_BATCH_SIZE = 32 + + +# Converting Bytes to Megabytes +def b2mb(x): + return int(x / 2**20) + + +# This context manager is used to track the peak memory usage of the process +class TorchTracemalloc: + def __enter__(self): + gc.collect() + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() # reset the peak gauge to zero + self.begin = torch.cuda.memory_allocated() + return self + + def __exit__(self, *exc): + gc.collect() + torch.cuda.empty_cache() + self.end = torch.cuda.memory_allocated() + self.peak = torch.cuda.max_memory_allocated() + self.used = b2mb(self.end - self.begin) + self.peaked = b2mb(self.peak - self.begin) + # print(f"delta used/peak {self.used:4d}/{self.peaked:4d}") + + +def get_dataloaders( + accelerator: Accelerator, + batch_size: int = 16, + model_name: str = "bert-base-cased", + n_train: int = 320, + n_val: int = 160, +): + """ + Creates a set of `DataLoader`s for the `glue` dataset. + + Args: + accelerator (`Accelerator`): + An `Accelerator` object + batch_size (`int`, *optional*): + The batch size for the train and validation DataLoaders. + model_name (`str`, *optional*): + The name of the model to use. + n_train (`int`, *optional*): + The number of training examples to use. + n_val (`int`, *optional*): + The number of validation examples to use. + """ + tokenizer = AutoTokenizer.from_pretrained(model_name) + datasets = load_dataset( + "glue", "mrpc", split={"train": f"train[:{n_train}]", "validation": f"validation[:{n_val}]"} + ) + + def tokenize_function(examples): + # max_length=None => use the model max length (it's actually the default) + outputs = tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None) + return outputs + + # Apply the method we just defined to all the examples in all the splits of the dataset + tokenized_datasets = datasets.map( + tokenize_function, batched=True, remove_columns=["idx", "sentence1", "sentence2"], load_from_cache_file=False + ) + + # We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the + # transformers library + tokenized_datasets = tokenized_datasets.rename_column("label", "labels") + + def collate_fn(examples): + # On TPU it's best to pad everything to the same length or training will be very slow. + if accelerator.distributed_type == DistributedType.TPU: + return tokenizer.pad(examples, padding="max_length", max_length=128, return_tensors="pt") + return tokenizer.pad(examples, padding="longest", return_tensors="pt") + + # Instantiate dataloaders. + train_dataloader = DataLoader( + tokenized_datasets["train"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size + ) + eval_dataloader = DataLoader( + tokenized_datasets["validation"], shuffle=False, collate_fn=collate_fn, batch_size=EVAL_BATCH_SIZE + ) + + return train_dataloader, eval_dataloader + + +def training_function(config, args): + # Initialize accelerator + accelerator = Accelerator() + + # Sample hyper-parameters for learning rate, batch size, seed and a few other HPs + lr = config["lr"] + num_epochs = int(config["num_epochs"]) + seed = int(config["seed"]) + batch_size = int(config["batch_size"]) + model_name = args.model_name_or_path + + set_seed(seed) + train_dataloader, eval_dataloader = get_dataloaders(accelerator, batch_size, model_name, args.n_train, args.n_val) + + # Instantiate the model (we build the model here so that the seed also control new weights initialization) + model = AutoModelForSequenceClassification.from_pretrained(model_name, return_dict=True) + + # Instantiate optimizer + optimizer_cls = ( + AdamW + if accelerator.state.deepspeed_plugin is None + or "optimizer" not in accelerator.state.deepspeed_plugin.deepspeed_config + else DummyOptim + ) + optimizer = optimizer_cls(params=model.parameters(), lr=lr) + + if accelerator.state.deepspeed_plugin is not None: + gradient_accumulation_steps = accelerator.state.deepspeed_plugin.deepspeed_config[ + "gradient_accumulation_steps" + ] + else: + gradient_accumulation_steps = 1 + max_training_steps = (len(train_dataloader) * num_epochs) // gradient_accumulation_steps + + # Instantiate scheduler + if ( + accelerator.state.deepspeed_plugin is None + or "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config + ): + lr_scheduler = get_linear_schedule_with_warmup( + optimizer=optimizer, + num_warmup_steps=0, + num_training_steps=max_training_steps, + ) + else: + lr_scheduler = DummyScheduler(optimizer, total_num_steps=max_training_steps, warmup_num_steps=0) + + # Prepare everything + # There is no specific order to remember, we just need to unpack the objects in the same order we gave them to the + # prepare method. + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare( + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler + ) + + # We need to keep track of how many total steps we have iterated over + overall_step = 0 + # We also need to keep track of the stating epoch so files are named properly + starting_epoch = 0 + + # Now we train the model + train_total_peak_memory = {} + for epoch in range(starting_epoch, num_epochs): + with TorchTracemalloc() as tracemalloc: + model.train() + for step, batch in enumerate(train_dataloader): + outputs = model(**batch) + loss = outputs.loss + loss = loss / gradient_accumulation_steps + accelerator.backward(loss) + if step % gradient_accumulation_steps == 0: + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + overall_step += 1 + + # Printing the GPU memory usage details such as allocated memory, peak memory, and total memory usage + accelerator.print("Memory before entering the train : {}".format(b2mb(tracemalloc.begin))) + accelerator.print("Memory consumed at the end of the train (end-begin): {}".format(tracemalloc.used)) + accelerator.print("Peak Memory consumed during the train (max-begin): {}".format(tracemalloc.peaked)) + accelerator.print( + "Total Peak Memory consumed during the train (max): {}".format( + tracemalloc.peaked + b2mb(tracemalloc.begin) + ) + ) + train_total_peak_memory[f"epoch-{epoch}"] = tracemalloc.peaked + b2mb(tracemalloc.begin) + if args.peak_memory_upper_bound is not None: + assert ( + train_total_peak_memory[f"epoch-{epoch}"] <= args.peak_memory_upper_bound + ), "Peak memory usage exceeded the upper bound" + + accelerator.wait_for_everyone() + if accelerator.is_main_process: + with open(os.path.join(args.output_dir, "peak_memory_utilization.json"), "w") as f: + json.dump(train_total_peak_memory, f) + + +def main(): + parser = argparse.ArgumentParser(description="Simple example of training script tracking peak GPU memory usage.") + parser.add_argument( + "--model_name_or_path", + type=str, + default="bert-base-cased", + help="Path to pretrained model or model identifier from huggingface.co/models.", + required=False, + ) + parser.add_argument( + "--output_dir", + type=str, + default=".", + help="Optional save directory where all checkpoint folders will be stored. Default is the current working directory.", + ) + parser.add_argument( + "--peak_memory_upper_bound", + type=float, + default=None, + help="The upper bound of peak memory usage in MB. If set, the training will throw an error if the peak memory usage exceeds this value.", + ) + parser.add_argument( + "--n_train", + type=int, + default=320, + help="Number of training examples to use.", + ) + parser.add_argument( + "--n_val", + type=int, + default=160, + help="Number of validation examples to use.", + ) + parser.add_argument( + "--num_epochs", + type=int, + default=1, + help="Number of train epochs.", + ) + args = parser.parse_args() + config = {"lr": 2e-5, "num_epochs": args.num_epochs, "seed": 42, "batch_size": 16} + training_function(config, args) + + +if __name__ == "__main__": + main() diff --git a/src/accelerate/test_utils/scripts/test_performance.py b/src/accelerate/test_utils/scripts/test_performance.py new file mode 100644 index 00000000000..324a1854ecb --- /dev/null +++ b/src/accelerate/test_utils/scripts/test_performance.py @@ -0,0 +1,231 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import json +import os + +import torch +from torch.optim import AdamW +from torch.utils.data import DataLoader + +import evaluate +from accelerate import Accelerator, DistributedType +from accelerate.utils.deepspeed import DummyOptim, DummyScheduler +from datasets import load_dataset +from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed + + +MAX_GPU_BATCH_SIZE = 16 +EVAL_BATCH_SIZE = 32 + + +def get_dataloaders(accelerator: Accelerator, batch_size: int = 16, model_name: str = "bert-base-cased"): + """ + Creates a set of `DataLoader`s for the `glue` dataset. + + Args: + accelerator (`Accelerator`): + An `Accelerator` object + batch_size (`int`, *optional*): + The batch size for the train and validation DataLoaders. + model_name (`str`, *optional*): + """ + tokenizer = AutoTokenizer.from_pretrained(model_name) + datasets = load_dataset("glue", "mrpc") + + def tokenize_function(examples): + # max_length=None => use the model max length (it's actually the default) + outputs = tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None) + return outputs + + # Apply the method we just defined to all the examples in all the splits of the dataset + tokenized_datasets = datasets.map( + tokenize_function, batched=True, remove_columns=["idx", "sentence1", "sentence2"], load_from_cache_file=False + ) + + # We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the + # transformers library + tokenized_datasets = tokenized_datasets.rename_column("label", "labels") + + def collate_fn(examples): + # On TPU it's best to pad everything to the same length or training will be very slow. + if accelerator.distributed_type == DistributedType.TPU: + return tokenizer.pad(examples, padding="max_length", max_length=128, return_tensors="pt") + return tokenizer.pad(examples, padding="longest", return_tensors="pt") + + # Instantiate dataloaders. + train_dataloader = DataLoader( + tokenized_datasets["train"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size + ) + eval_dataloader = DataLoader( + tokenized_datasets["validation"], shuffle=False, collate_fn=collate_fn, batch_size=EVAL_BATCH_SIZE + ) + + return train_dataloader, eval_dataloader + + +def training_function(config, args): + # Initialize accelerator + accelerator = Accelerator() + + # Sample hyper-parameters for learning rate, batch size, seed and a few other HPs + lr = config["lr"] + num_epochs = int(config["num_epochs"]) + seed = int(config["seed"]) + batch_size = int(config["batch_size"]) + model_name = args.model_name_or_path + + set_seed(seed) + train_dataloader, eval_dataloader = get_dataloaders(accelerator, batch_size, model_name) + + # Instantiate the model (we build the model here so that the seed also control new weights initialization) + model = AutoModelForSequenceClassification.from_pretrained(model_name, return_dict=True) + + # Instantiate optimizer + optimizer_cls = ( + AdamW + if accelerator.state.deepspeed_plugin is None + or "optimizer" not in accelerator.state.deepspeed_plugin.deepspeed_config + else DummyOptim + ) + optimizer = optimizer_cls(params=model.parameters(), lr=lr) + + if accelerator.state.deepspeed_plugin is not None: + gradient_accumulation_steps = accelerator.state.deepspeed_plugin.deepspeed_config[ + "gradient_accumulation_steps" + ] + else: + gradient_accumulation_steps = 1 + max_training_steps = (len(train_dataloader) * num_epochs) // gradient_accumulation_steps + + # Instantiate scheduler + if ( + accelerator.state.deepspeed_plugin is None + or "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config + ): + lr_scheduler = get_linear_schedule_with_warmup( + optimizer=optimizer, + num_warmup_steps=0, + num_training_steps=max_training_steps, + ) + else: + lr_scheduler = DummyScheduler(optimizer, total_num_steps=max_training_steps, warmup_num_steps=0) + + # Prepare everything + # There is no specific order to remember, we just need to unpack the objects in the same order we gave them to the + # prepare method. + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare( + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler + ) + + # We need to keep track of how many total steps we have iterated over + overall_step = 0 + # We also need to keep track of the stating epoch so files are named properly + starting_epoch = 0 + + # Now we train the model + metric = evaluate.load("glue", "mrpc") + best_performance = 0 + performance_metric = {} + for epoch in range(starting_epoch, num_epochs): + model.train() + for step, batch in enumerate(train_dataloader): + outputs = model(**batch) + loss = outputs.loss + loss = loss / gradient_accumulation_steps + accelerator.backward(loss) + if step % gradient_accumulation_steps == 0: + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + overall_step += 1 + + model.eval() + samples_seen = 0 + for step, batch in enumerate(eval_dataloader): + # We could avoid this line since we set the accelerator with `device_placement=True`. + batch.to(accelerator.device) + with torch.no_grad(): + outputs = model(**batch) + predictions = outputs.logits.argmax(dim=-1) + # It is slightly faster to call this once, than multiple times + predictions, references = accelerator.gather( + (predictions, batch["labels"]) + ) # If we are in a multiprocess environment, the last batch has duplicates + if accelerator.use_distributed: + if step == len(eval_dataloader) - 1: + predictions = predictions[: len(eval_dataloader.dataset) - samples_seen] + references = references[: len(eval_dataloader.dataset) - samples_seen] + else: + samples_seen += references.shape[0] + metric.add_batch( + predictions=predictions, + references=references, + ) + + eval_metric = metric.compute() + # Use accelerator.print to print only on the main process. + accelerator.print(f"epoch {epoch}:", eval_metric) + performance_metric[f"epoch-{epoch}"] = eval_metric["accuracy"] + + if best_performance < eval_metric["accuracy"]: + best_performance = eval_metric["accuracy"] + + if args.performance_lower_bound is not None: + assert ( + args.performance_lower_bound <= best_performance + ), f"Best performance metric {best_performance} is lower than the lower bound {args.performance_lower_bound}" + + accelerator.wait_for_everyone() + if accelerator.is_main_process: + with open(os.path.join(args.output_dir, "all_results.json"), "w") as f: + json.dump(performance_metric, f) + + +def main(): + parser = argparse.ArgumentParser(description="Simple example of training script tracking peak GPU memory usage.") + parser.add_argument( + "--model_name_or_path", + type=str, + default="bert-base-cased", + help="Path to pretrained model or model identifier from huggingface.co/models.", + required=False, + ) + parser.add_argument( + "--output_dir", + type=str, + default=".", + help="Optional save directory where all checkpoint folders will be stored. Default is the current working directory.", + ) + parser.add_argument( + "--performance_lower_bound", + type=float, + default=None, + help="Optional lower bound for the performance metric. If set, the training will throw error when the performance metric drops below this value.", + ) + parser.add_argument( + "--num_epochs", + type=int, + default=3, + help="Number of train epochs.", + ) + args = parser.parse_args() + config = {"lr": 2e-5, "num_epochs": args.num_epochs, "seed": 42, "batch_size": 16} + training_function(config, args) + + +if __name__ == "__main__": + main() diff --git a/src/accelerate/test_utils/testing.py b/src/accelerate/test_utils/testing.py index 0b3145070a7..b5cccc1ff6a 100644 --- a/src/accelerate/test_utils/testing.py +++ b/src/accelerate/test_utils/testing.py @@ -32,6 +32,7 @@ is_comet_ml_available, is_deepspeed_available, is_tensorboard_available, + is_torch_version, is_tpu_available, is_wandb_available, ) @@ -108,6 +109,13 @@ def require_deepspeed(test_case): return unittest.skipUnless(is_deepspeed_available(), "test requires DeepSpeed")(test_case) +def require_fsdp(test_case): + """ + Decorator marking a test that requires FSDP installed. These tests are skipped when FSDP isn't installed + """ + return unittest.skipUnless(is_torch_version(">=", "1.12.0"), "test requires torch version >= 1.12.0")(test_case) + + def require_tensorboard(test_case): """ Decorator marking a test that requires tensorboard installed. These tests are skipped when tensorboard isn't diff --git a/src/accelerate/utils/constants.py b/src/accelerate/utils/constants.py index 00556cad6d5..4e7c71853d2 100644 --- a/src/accelerate/utils/constants.py +++ b/src/accelerate/utils/constants.py @@ -27,6 +27,7 @@ FSDP_SHARDING_STRATEGY = ["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD"] FSDP_AUTO_WRAP_POLICY = ["TRANSFORMER_BASED_WRAP", "SIZE_BASED_WRAP", "NO_WRAP"] FSDP_BACKWARD_PREFETCH = ["BACKWARD_PRE", "BACKWARD_POST", "NO_PREFETCH"] +FSDP_STATE_DICT_TYPE = ["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"] DEEPSPEED_MULTINODE_LAUNCHERS = ["pdsh", "standard", "openmpi", "mvapich"] STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt} diff --git a/src/accelerate/utils/dataclasses.py b/src/accelerate/utils/dataclasses.py index 8c8bdf10616..350d953207e 100644 --- a/src/accelerate/utils/dataclasses.py +++ b/src/accelerate/utils/dataclasses.py @@ -28,7 +28,7 @@ import torch -from .constants import FSDP_AUTO_WRAP_POLICY, FSDP_BACKWARD_PREFETCH +from .constants import FSDP_AUTO_WRAP_POLICY, FSDP_BACKWARD_PREFETCH, FSDP_STATE_DICT_TYPE, MODEL_NAME, OPTIMIZER_NAME class KwargsHandler: @@ -455,8 +455,28 @@ class FullyShardedDataParallelPlugin: metadata={"help": "A list of modules to ignore for FSDP."}, ) + state_dict_type: "typing.Any" = field( + default=None, + metadata={ + "help": "FSDP State Dict Type of type `torch.distributed.fsdp.fully_sharded_data_parallel.StateDictType`" + }, + ) + + state_dict_config: "typing.Any" = field( + default=None, + metadata={ + "help": "FSDP State Dict Config of type `torch.distributed.fsdp.fully_sharded_data_parallel.StateDictConfig`" + }, + ) + def __post_init__(self): - from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch, CPUOffload, ShardingStrategy + from torch.distributed.fsdp.fully_sharded_data_parallel import ( + BackwardPrefetch, + CPUOffload, + ShardingStrategy, + StateDictType, + _state_dict_type_to_config, + ) if self.sharding_strategy is None: self.sharding_strategy = ShardingStrategy(int(os.environ.get("FSDP_SHARDING_STRATEGY", 1))) @@ -468,10 +488,21 @@ def __post_init__(self): self.cpu_offload = CPUOffload(offload_params=False) if self.backward_prefetch is None: - prefetch_policy = os.environ.get("FSDP_BACKWARD_PREFETCH", FSDP_BACKWARD_PREFETCH[-1]) + prefetch_policy = os.environ.get("FSDP_BACKWARD_PREFETCH", "NO_PREFETCH") if prefetch_policy != FSDP_BACKWARD_PREFETCH[-1]: self.backward_prefetch = BackwardPrefetch(FSDP_BACKWARD_PREFETCH.index(prefetch_policy) + 1) + if self.state_dict_type is None: + state_dict_type_policy = os.environ.get("FSDP_STATE_DICT_TYPE", "FULL_STATE_DICT") + self.state_dict_type = StateDictType(FSDP_STATE_DICT_TYPE.index(state_dict_type_policy) + 1) + + if self.state_dict_type == StateDictType.FULL_STATE_DICT: + self.state_dict_config = _state_dict_type_to_config[self.state_dict_type]( + offload_to_cpu=True, rank0_only=True + ) + else: + self.state_dict_config = _state_dict_type_to_config[self.state_dict_type]() + @staticmethod def get_module_class_from_name(module, name): """ @@ -496,7 +527,7 @@ def set_auto_wrap_policy(self, model): from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy if self.auto_wrap_policy is None: - auto_wrap_policy = os.environ.get("FSDP_AUTO_WRAP_POLICY", FSDP_AUTO_WRAP_POLICY[-1]) + auto_wrap_policy = os.environ.get("FSDP_AUTO_WRAP_POLICY", "NO_WRAP") if auto_wrap_policy == FSDP_AUTO_WRAP_POLICY[0]: transformer_cls_to_wrap = os.environ.get("FSDP_TRANSFORMER_CLS_TO_WRAP", "") transformer_cls_to_wrap = FullyShardedDataParallelPlugin.get_module_class_from_name( @@ -527,3 +558,84 @@ def set_mixed_precision(self, mixed_precision): if self.mixed_precision_policy is None: self.mixed_precision_policy = MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype) + + def save_model(self, accelerator, model, output_dir, model_index=0): + from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP + from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType + + if self.state_dict_type == StateDictType.FULL_STATE_DICT: + with FSDP.state_dict_type(model, self.state_dict_type, self.state_dict_config): + state_dict = model.state_dict() + weights_name = f"{MODEL_NAME}.bin" if model_index == 0 else f"{MODEL_NAME}_{model_index}.bin" + output_model_file = os.path.join(output_dir, weights_name) + if accelerator.process_index == 0: + print(f"Saving model to {output_model_file}") + torch.save(state_dict, output_model_file) + print(f"Model saved to {output_model_file}") + else: + with FSDP.state_dict_type(model, self.state_dict_type, self.state_dict_config): + state_dict = model.state_dict() + weights_name = ( + f"{MODEL_NAME}_rank{accelerator.process_index}.bin" + if model_index == 0 + else f"{MODEL_NAME}_{model_index}_rank{accelerator.process_index}.bin" + ) + output_model_file = os.path.join(output_dir, weights_name) + print(f"Saving model to {output_model_file}") + torch.save(state_dict, output_model_file) + print(f"Model saved to {output_model_file}") + + def load_model(self, accelerator, model, input_dir, model_index=0): + from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP + from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType + + accelerator.wait_for_everyone() + + if self.state_dict_type == StateDictType.FULL_STATE_DICT: + weights_name = f"{MODEL_NAME}.bin" if model_index == 0 else f"{MODEL_NAME}_{model_index}.bin" + input_model_file = os.path.join(input_dir, weights_name) + accelerator.print(f"Loading model from {input_model_file}") + state_dict = torch.load(input_model_file) + accelerator.print(f"Model loaded from {input_model_file}") + else: + weights_name = ( + f"{MODEL_NAME}_rank{accelerator.process_index}.bin" + if model_index == 0 + else f"{MODEL_NAME}_{model_index}_rank{accelerator.process_index}.bin" + ) + input_model_file = os.path.join(input_dir, weights_name) + print(f"Loading model from {input_model_file}") + state_dict = torch.load(input_model_file) + print(f"Model loaded from {input_model_file}") + with FSDP.state_dict_type(model, self.state_dict_type, self.state_dict_config): + model.load_state_dict(state_dict) + + def save_optimizer(self, accelerator, optimizer, model, output_dir, optimizer_index=0, optim_input=None): + from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP + + optim_state = FSDP.full_optim_state_dict(model, optimizer, optim_input=optim_input) + if accelerator.process_index == 0: + optim_state_name = ( + f"{OPTIMIZER_NAME}.bin" if optimizer_index == 0 else f"{OPTIMIZER_NAME}_{optimizer_index}.bin" + ) + output_optimizer_file = os.path.join(output_dir, optim_state_name) + print(f"Saving Optimizer state to {output_optimizer_file}") + torch.save(optim_state, output_optimizer_file) + print(f"Optimizer state saved in {output_optimizer_file}") + + def load_optimizer(self, accelerator, optimizer, model, input_dir, optimizer_index=0): + from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP + + accelerator.wait_for_everyone() + full_osd = None + if accelerator.process_index == 0: + optimizer_name = ( + f"{OPTIMIZER_NAME}.bin" if optimizer_index == 0 else f"{OPTIMIZER_NAME}_{optimizer_index}.bin" + ) + input_optimizer_file = os.path.join(input_dir, optimizer_name) + print(f"Loading Optimizer state from {input_optimizer_file}") + full_osd = torch.load(input_optimizer_file) + print(f"Optimizer state loaded from {input_optimizer_file}") + # called from all ranks, though only rank0 has a valid param for full_osd + sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, model) + optimizer.load_state_dict(sharded_osd) diff --git a/src/accelerate/utils/operations.py b/src/accelerate/utils/operations.py index 7b3c8de5a7a..42868a0a5c9 100644 --- a/src/accelerate/utils/operations.py +++ b/src/accelerate/utils/operations.py @@ -254,7 +254,11 @@ def gather_object(object: Any): """ if AcceleratorState().distributed_type == DistributedType.TPU: raise NotImplementedError("gather objects in TPU is not supported") - elif AcceleratorState().distributed_type in [DistributedType.DEEPSPEED, DistributedType.MULTI_GPU]: + elif AcceleratorState().distributed_type in [ + DistributedType.DEEPSPEED, + DistributedType.MULTI_GPU, + DistributedType.FSDP, + ]: return _gpu_gather_object(object) elif AcceleratorState().distributed_type == DistributedType.MULTI_CPU: return _cpu_gather_object(object) @@ -293,7 +297,11 @@ def broadcast(tensor, from_process: int = 0): """ if AcceleratorState().distributed_type == DistributedType.TPU: return _tpu_broadcast(tensor, src=from_process, name="accelerate.utils.broadcast") - elif AcceleratorState().distributed_type in [DistributedType.DEEPSPEED, DistributedType.MULTI_GPU]: + elif AcceleratorState().distributed_type in [ + DistributedType.DEEPSPEED, + DistributedType.MULTI_GPU, + DistributedType.FSDP, + ]: return _gpu_broadcast(tensor, src=from_process) elif AcceleratorState().distributed_type == DistributedType.MULTI_CPU: return _gpu_broadcast(tensor, src=from_process) @@ -317,7 +325,11 @@ def broadcast_object_list(object_list, from_process: int = 0): if AcceleratorState().distributed_type == DistributedType.TPU: for i, obj in enumerate(object_list): object_list[i] = xm.mesh_reduce("accelerate.utils.broadcast_object_list", obj, lambda x: x[from_process]) - elif AcceleratorState().distributed_type in [DistributedType.DEEPSPEED, DistributedType.MULTI_GPU]: + elif AcceleratorState().distributed_type in [ + DistributedType.DEEPSPEED, + DistributedType.MULTI_GPU, + DistributedType.FSDP, + ]: torch.distributed.broadcast_object_list(object_list, src=from_process) elif AcceleratorState().distributed_type == DistributedType.MULTI_CPU: torch.distributed.broadcast_object_list(object_list, src=from_process) @@ -433,7 +445,11 @@ def _reduce_across_processes(tensor, reduction="mean"): if state.distributed_type == DistributedType.TPU: xm.all_reduce("sum", cloned_tensor) return cloned_tensor - elif state.distributed_type in [DistributedType.DEEPSPEED, DistributedType.MULTI_GPU]: + elif state.distributed_type in [ + DistributedType.DEEPSPEED, + DistributedType.MULTI_GPU, + DistributedType.FSDP, + ]: torch.distributed.all_reduce(cloned_tensor, ReduceOp.SUM) return cloned_tensor else: diff --git a/src/accelerate/utils/other.py b/src/accelerate/utils/other.py index 206c7058996..ff360038d11 100644 --- a/src/accelerate/utils/other.py +++ b/src/accelerate/utils/other.py @@ -65,6 +65,7 @@ def wait_for_everyone(): AcceleratorState().distributed_type == DistributedType.MULTI_GPU or AcceleratorState().distributed_type == DistributedType.MULTI_CPU or AcceleratorState().distributed_type == DistributedType.DEEPSPEED + or AcceleratorState().distributed_type == DistributedType.FSDP ): torch.distributed.barrier() elif AcceleratorState().distributed_type == DistributedType.TPU: diff --git a/tests/deepspeed/test_deepspeed.py b/tests/deepspeed/test_deepspeed.py index 12cca415c7a..6b37eb93e43 100644 --- a/tests/deepspeed/test_deepspeed.py +++ b/tests/deepspeed/test_deepspeed.py @@ -25,10 +25,18 @@ import torch from torch.utils.data import DataLoader +import accelerate from accelerate.accelerator import Accelerator from accelerate.scheduler import AcceleratedScheduler from accelerate.state import AcceleratorState -from accelerate.test_utils.testing import require_cuda, require_deepspeed +from accelerate.test_utils.testing import ( + TempDirTestCase, + execute_subprocess_async, + require_cuda, + require_deepspeed, + require_multi_gpu, + slow, +) from accelerate.test_utils.training import RegressionDataset from accelerate.utils.dataclasses import DeepSpeedPlugin from accelerate.utils.deepspeed import ( @@ -38,6 +46,7 @@ DummyOptim, DummyScheduler, ) +from accelerate.utils.other import patch_environment from parameterized import parameterized from transformers import AutoModel, AutoModelForCausalLM, get_scheduler from transformers.testing_utils import mockenv_context @@ -118,6 +127,10 @@ def setUp(self): WORLD_SIZE="1", ) + def tearDown(self): + super().tearDown() + AcceleratorState._reset_state() + def get_config_dict(self, stage): # As some tests modify the dict, always make a copy return deepcopy(self.ds_config_dict[stage]) @@ -260,11 +273,10 @@ def test_init_zero3(self): ) with mockenv_context(**self.dist_env): - accelerator = Accelerator(deepspeed_plugin=deepspeed_plugin) + accelerator = Accelerator(deepspeed_plugin=deepspeed_plugin) # noqa: F841 from transformers.deepspeed import is_deepspeed_zero3_enabled self.assertTrue(is_deepspeed_zero3_enabled()) - accelerator.state.initialized = False @parameterized.expand(optim_scheduler_params, name_func=parameterized_custom_name_func) def test_prepare_deepspeed(self, optim_type, scheduler_type): @@ -479,7 +491,6 @@ def test_prepare_deepspeed(self, optim_type, scheduler_type): "You can only specify `accelerate.utils.DummyScheduler` in the code when using `accelerate.utils.DummyOptim`." in str(cm.exception) ) - accelerator.state.initialized = False def test_save_checkpoints(self): deepspeed_plugin = DeepSpeedPlugin( @@ -533,7 +544,6 @@ def test_save_checkpoints(self): "To save the full checkpoint, run `model.save_checkpoint(save_dir)` and use `zero_to_fp32.py` to recover weights." ) self.assertTrue(msg in str(cm.exception)) - accelerator.state.initialized = False def test_autofill_dsconfig(self): deepspeed_plugin = DeepSpeedPlugin( @@ -581,4 +591,213 @@ def test_autofill_dsconfig(self): self.assertFalse( accelerator.deepspeed_config["zero_optimization"]["stage3_gather_16bit_weights_on_model_save"] ) - accelerator.state.initialized = False + + def test_basic_run(self): + mod_file = inspect.getfile(accelerate.test_utils) + test_file_path = os.path.sep.join(mod_file.split(os.path.sep)[:-1] + ["scripts", "test_performance.py"]) + with tempfile.TemporaryDirectory() as dirpath: + cmd = [ + "accelerate", + "launch", + "--num_processes=1", + "--num_machines=1", + "--machine_rank=0", + "--mixed_precision=fp16", + "--use_deepspeed", + "--gradient_accumulation_steps=1", + "--zero_stage=2", + "--offload_optimizer_device=none", + "--offload_param_device=none", + test_file_path, + "--model_name_or_path=distilbert-base-uncased", + "--num_epochs=1", + f"--output_dir={dirpath}", + ] + with patch_environment(omp_num_threads=1): + execute_subprocess_async(cmd, env=os.environ.copy()) + + +@require_deepspeed +@require_multi_gpu +@slow +class DeepSpeedIntegrationTest(TempDirTestCase): + def setUp(self): + super().setUp() + self._test_file_path = inspect.getfile(self.__class__) + path = Path(self._test_file_path).resolve() + self.test_file_dir_str = str(path.parents[0]) + + self.ds_config_file = dict( + zero2=f"{self.test_file_dir_str}/ds_config_zero2.json", + zero3=f"{self.test_file_dir_str}/ds_config_zero3.json", + ) + + self.stages = [1, 2, 3] + self.zero3_offload_config = False + self.performance_lower_bound = 0.83 + self.peak_memory_usage_upper_bound = { + "multi_gpu_fp16": 3200, + "deepspeed_stage_1_fp16": 1600, + "deepspeed_stage_2_fp16": 2500, + "deepspeed_stage_3_zero_init_fp16": 2800, + "deepspeed_stage_3_cpu_offload_fp16": 1900, + } + self.n_train = 160 + self.n_val = 160 + + mod_file = inspect.getfile(accelerate.test_utils) + self.test_scripts_folder = os.path.sep.join(mod_file.split(os.path.sep)[:-1] + ["scripts"]) + + def test_performance(self): + self.test_file_path = os.path.join(self.test_scripts_folder, "test_performance.py") + cmd = [ + "accelerate", + "launch", + "--num_processes=2", + "--num_machines=1", + "--machine_rank=0", + "--mixed_precision=fp16", + "--use_deepspeed", + "--gradient_accumulation_steps=1", + "--gradient_clipping=1", + "--zero3_init_flag=True", + "--zero3_save_16bit_model=True", + ] + for stage in self.stages: + if stage == 1: + continue + cmd_stage = cmd.copy() + cmd_stage.extend([f"--zero_stage={stage}"]) + cmd_stage.extend(["--offload_optimizer_device=none", "--offload_param_device=none"]) + if self.zero3_offload_config: + with io.open(self.ds_config_file[ZERO3], "r", encoding="utf-8") as f: + ds_config = json.load(f) + del ds_config["bf16"] + del ds_config["optimizer"]["params"]["torch_adam"] + del ds_config["optimizer"]["params"]["adam_w_mode"] + ds_config["fp16"]["enabled"] = True + ds_config_path = os.path.join(self.tmpdir, "ds_config.json") + with open(ds_config_path, "w") as out_file: + json.dump(ds_config, out_file) + + cmd_stage.extend([f"--deepspeed_config_file={ds_config_path}"]) + + cmd_stage.extend( + [ + self.test_file_path, + f"--output_dir={self.tmpdir}", + f"--performance_lower_bound={self.performance_lower_bound}", + ] + ) + with patch_environment(omp_num_threads=1): + execute_subprocess_async(cmd_stage, env=os.environ.copy()) + + def test_checkpointing(self): + self.test_file_path = os.path.join(self.test_scripts_folder, "test_checkpointing.py") + cmd = [ + "accelerate", + "launch", + "--num_processes=2", + "--num_machines=1", + "--machine_rank=0", + "--mixed_precision=fp16", + "--use_deepspeed", + "--gradient_accumulation_steps=1", + "--gradient_clipping=1", + "--zero3_init_flag=True", + "--zero3_save_16bit_model=True", + ] + for stage in self.stages: + if stage == 1: + continue + cmd_stage = cmd.copy() + cmd_stage.extend([f"--zero_stage={stage}"]) + cmd_stage.extend(["--offload_optimizer_device=none", "--offload_param_device=none"]) + if self.zero3_offload_config: + with io.open(self.ds_config_file[ZERO3], "r", encoding="utf-8") as f: + ds_config = json.load(f) + del ds_config["bf16"] + del ds_config["optimizer"]["params"]["torch_adam"] + del ds_config["optimizer"]["params"]["adam_w_mode"] + ds_config["fp16"]["enabled"] = True + ds_config_path = os.path.join(self.tmpdir, "ds_config.json") + with open(ds_config_path, "w") as out_file: + json.dump(ds_config, out_file) + + cmd_stage.extend([f"--deepspeed_config_file={ds_config_path}"]) + + cmd_stage.extend( + [ + self.test_file_path, + f"--output_dir={self.tmpdir}", + "--partial_train_epoch=1", + ] + ) + with patch_environment(omp_num_threads=1): + execute_subprocess_async(cmd_stage, env=os.environ.copy()) + + cmd_stage = cmd_stage[:-1] + resume_from_checkpoint = os.path.join(self.tmpdir, "epoch_0") + cmd_stage.extend( + [ + f"--resume_from_checkpoint={resume_from_checkpoint}", + ] + ) + with patch_environment(omp_num_threads=1): + execute_subprocess_async(cmd_stage, env=os.environ.copy()) + + def test_peak_memory_usage(self): + self.test_file_path = os.path.join(self.test_scripts_folder, "test_peak_memory_usage.py") + cmd = [ + "accelerate", + "launch", + "--num_processes=2", + "--num_machines=1", + "--machine_rank=0", + ] + for spec, peak_mem_upper_bound in self.peak_memory_usage_upper_bound.items(): + cmd_stage = cmd.copy() + if "fp16" in spec: + cmd_stage.extend(["--mixed_precision=fp16"]) + + if "multi_gpu" in spec: + continue + else: + cmd_stage.extend( + [ + "--use_deepspeed", + "--gradient_accumulation_steps=1", + "--gradient_clipping=1", + "--zero3_init_flag=True", + "--zero3_save_16bit_model=True", + ] + ) + for i in range(3): + if f"stage_{i+1}" in spec: + cmd_stage.extend([f"--zero_stage={i+1}"]) + break + cmd_stage.extend(["--offload_optimizer_device=none", "--offload_param_device=none"]) + if "cpu_offload" in spec: + with io.open(self.ds_config_file[ZERO3], "r", encoding="utf-8") as f: + ds_config = json.load(f) + del ds_config["bf16"] + del ds_config["fp16"] + del ds_config["optimizer"]["params"]["torch_adam"] + del ds_config["optimizer"]["params"]["adam_w_mode"] + ds_config_path = os.path.join(self.tmpdir, "ds_config.json") + with open(ds_config_path, "w") as out_file: + json.dump(ds_config, out_file) + + cmd_stage.extend([f"--deepspeed_config_file={ds_config_path}"]) + + cmd_stage.extend( + [ + self.test_file_path, + f"--output_dir={self.tmpdir}", + f"--peak_memory_upper_bound={peak_mem_upper_bound}", + f"--n_train={self.n_train}", + f"--n_val={self.n_val}", + ] + ) + with patch_environment(omp_num_threads=1): + execute_subprocess_async(cmd_stage, env=os.environ.copy()) diff --git a/tests/fsdp/test_fsdp.py b/tests/fsdp/test_fsdp.py new file mode 100644 index 00000000000..8ad088c042c --- /dev/null +++ b/tests/fsdp/test_fsdp.py @@ -0,0 +1,332 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import inspect +import os +import unittest + +import torch + +import accelerate +from accelerate.accelerator import Accelerator +from accelerate.state import AcceleratorState +from accelerate.test_utils.testing import ( + TempDirTestCase, + execute_subprocess_async, + require_cuda, + require_fsdp, + require_multi_gpu, + slow, +) +from accelerate.utils.constants import ( + FSDP_AUTO_WRAP_POLICY, + FSDP_BACKWARD_PREFETCH, + FSDP_SHARDING_STRATEGY, + FSDP_STATE_DICT_TYPE, +) +from accelerate.utils.dataclasses import FullyShardedDataParallelPlugin +from accelerate.utils.other import patch_environment +from transformers import AutoModel +from transformers.testing_utils import mockenv_context +from transformers.trainer_utils import set_seed + + +set_seed(42) + +BERT_BASE_CASED = "bert-base-cased" +FP16 = "fp16" +BF16 = "bf16" +dtypes = [FP16, BF16] + + +@require_fsdp +@require_cuda +class FSDPPluginIntegration(unittest.TestCase): + def setUp(self): + super().setUp() + + self.dist_env = dict( + USE_FSDP="true", + MASTER_ADDR="localhost", + MASTER_PORT="10999", + RANK="0", + LOCAL_RANK="0", + WORLD_SIZE="1", + ) + + def tearDown(self): + super().tearDown() + AcceleratorState._reset_state() + + def test_sharding_strategy(self): + from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy + + for i, strategy in enumerate(FSDP_SHARDING_STRATEGY): + env = self.dist_env.copy() + env["FSDP_SHARDING_STRATEGY"] = f"{i + 1}" + env["FSDP_SHARDING_STRATEGY_NAME"] = strategy + with mockenv_context(**env): + fsdp_plugin = FullyShardedDataParallelPlugin() + self.assertEqual(fsdp_plugin.sharding_strategy, ShardingStrategy(i + 1)) + + def test_backward_prefetch(self): + from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch + + for i, prefetch_policy in enumerate(FSDP_BACKWARD_PREFETCH): + env = self.dist_env.copy() + env["FSDP_BACKWARD_PREFETCH"] = prefetch_policy + with mockenv_context(**env): + fsdp_plugin = FullyShardedDataParallelPlugin() + if prefetch_policy == "NO_PREFETCH": + self.assertIsNone(fsdp_plugin.backward_prefetch) + else: + self.assertEqual(fsdp_plugin.backward_prefetch, BackwardPrefetch(i + 1)) + + def test_state_dict_type(self): + from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType, _state_dict_type_to_config + + for i, state_dict_type in enumerate(FSDP_STATE_DICT_TYPE): + env = self.dist_env.copy() + env["FSDP_STATE_DICT_TYPE"] = state_dict_type + with mockenv_context(**env): + fsdp_plugin = FullyShardedDataParallelPlugin() + self.assertEqual(fsdp_plugin.state_dict_type, StateDictType(i + 1)) + self.assertEqual( + type(fsdp_plugin.state_dict_config), type(_state_dict_type_to_config[StateDictType(i + 1)]()) + ) + if state_dict_type == "FULL_STATE_DICT": + self.assertTrue(fsdp_plugin.state_dict_config.offload_to_cpu) + self.assertTrue(fsdp_plugin.state_dict_config.rank0_only) + + def test_auto_wrap_policy(self): + model = AutoModel.from_pretrained(BERT_BASE_CASED) + for policy in FSDP_AUTO_WRAP_POLICY: + env = self.dist_env.copy() + env["FSDP_AUTO_WRAP_POLICY"] = policy + if policy == "TRANSFORMER_BASED_WRAP": + env["FSDP_TRANSFORMER_CLS_TO_WRAP"] = "BertLayer" + elif policy == "SIZE_BASED_WRAP": + env["FSDP_MIN_NUM_PARAMS"] = "2000" + with mockenv_context(**env): + fsdp_plugin = FullyShardedDataParallelPlugin() + fsdp_plugin.set_auto_wrap_policy(model) + if policy == "NO_WRAP": + self.assertIsNone(fsdp_plugin.auto_wrap_policy) + else: + self.assertIsNotNone(fsdp_plugin.auto_wrap_policy) + + env = self.dist_env.copy() + env["FSDP_AUTO_WRAP_POLICY"] = "TRANSFORMER_BASED_WRAP" + env["FSDP_TRANSFORMER_CLS_TO_WRAP"] = "T5Layer" + with mockenv_context(**env): + fsdp_plugin = FullyShardedDataParallelPlugin() + with self.assertRaises(Exception) as cm: + fsdp_plugin.set_auto_wrap_policy(model) + self.assertTrue("Could not find the transformer layer class to wrap in the model." in str(cm.exception)) + + env = self.dist_env.copy() + env["FSDP_AUTO_WRAP_POLICY"] = "SIZE_BASED_WRAP" + env["FSDP_MIN_NUM_PARAMS"] = "0" + with mockenv_context(**env): + fsdp_plugin = FullyShardedDataParallelPlugin() + fsdp_plugin.set_auto_wrap_policy(model) + self.assertIsNone(fsdp_plugin.auto_wrap_policy) + + def test_mixed_precision(self): + from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision + from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler + + for mp_dtype in dtypes: + env = self.dist_env.copy() + env["MIXED_PRECISION"] = mp_dtype + with mockenv_context(**env): + accelerator = Accelerator() + if mp_dtype == "fp16": + dtype = torch.float16 + elif mp_dtype == "bf16": + dtype = torch.bfloat16 + mp_policy = MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype) + self.assertEqual(accelerator.state.fsdp_plugin.mixed_precision_policy, mp_policy) + if mp_dtype == FP16: + self.assertTrue(isinstance(accelerator.scaler, ShardedGradScaler)) + elif mp_dtype == BF16: + self.assertIsNone(accelerator.scaler) + AcceleratorState._reset_state() + + def test_cpu_offload(self): + from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload + + for flag in [True, False]: + env = self.dist_env.copy() + env["FSDP_OFFLOAD_PARAMS"] = str(flag).lower() + with mockenv_context(**env): + fsdp_plugin = FullyShardedDataParallelPlugin() + self.assertEqual(fsdp_plugin.cpu_offload, CPUOffload(offload_params=flag)) + + +@require_fsdp +@require_multi_gpu +@slow +class FSDPIntegrationTest(TempDirTestCase): + def setUp(self): + super().setUp() + self.performance_lower_bound = 0.83 + self.performance_configs = [ + "fsdp_shard_grad_op_transformer_based_wrap", + "fsdp_full_shard_transformer_based_wrap", + ] + self.peak_memory_usage_upper_bound = { + "multi_gpu_fp16": 3200, + "fsdp_shard_grad_op_transformer_based_wrap_fp16": 2000, + "fsdp_full_shard_transformer_based_wrap_fp16": 1900, + "fsdp_full_shard_cpu_offload_transformer_based_wrap_fp32": 1500, # fp16 was leading to indefinite hang + } + self.n_train = 160 + self.n_val = 160 + + mod_file = inspect.getfile(accelerate.test_utils) + self.test_scripts_folder = os.path.sep.join(mod_file.split(os.path.sep)[:-1] + ["scripts"]) + + def test_performance(self): + self.test_file_path = os.path.join(self.test_scripts_folder, "test_performance.py") + cmd = ["accelerate", "launch", "--num_processes=2", "--num_machines=1", "--machine_rank=0", "--use_fsdp"] + for config in self.performance_configs: + cmd_config = cmd.copy() + for i, strategy in enumerate(FSDP_SHARDING_STRATEGY): + if strategy.lower() in config: + cmd_config.append(f"--fsdp_sharding_strategy={i+1}") + break + + if "fp32" in config: + cmd_config.append("--mixed_precision=no") + else: + cmd_config.append("--mixed_precision=fp16") + + if "cpu_offload" in config: + cmd_config.append("--fsdp_offload_params=True") + + for policy in FSDP_AUTO_WRAP_POLICY: + if policy.lower() in config: + cmd_config.append(f"--fsdp_auto_wrap_policy={policy}") + break + + if policy == "TRANSFORMER_BASED_WRAP": + cmd_config.append("--fsdp_transformer_layer_cls_to_wrap=BertLayer") + elif policy == "SIZE_BASED_WRAP": + cmd_config.append("--fsdp_min_num_params=2000") + + cmd_config.extend( + [ + self.test_file_path, + f"--output_dir={self.tmpdir}", + f"--performance_lower_bound={self.performance_lower_bound}", + ] + ) + with patch_environment(omp_num_threads=1): + execute_subprocess_async(cmd_config, env=os.environ.copy()) + + def test_checkpointing(self): + self.test_file_path = os.path.join(self.test_scripts_folder, "test_checkpointing.py") + cmd = [ + "accelerate", + "launch", + "--num_processes=2", + "--num_machines=1", + "--machine_rank=0", + "--use_fsdp", + "--mixed_precision=fp16", + "--fsdp_transformer_layer_cls_to_wrap=BertLayer", + ] + + for i, strategy in enumerate(FSDP_SHARDING_STRATEGY): + cmd_config = cmd.copy() + cmd_config.append(f"--fsdp_sharding_strategy={i+1}") + if strategy != "FULL_SHARD": + continue + state_dict_config_index = len(cmd_config) + for state_dict_type in FSDP_STATE_DICT_TYPE: + cmd_config = cmd_config[:state_dict_config_index] + if state_dict_type == "SHARDED_STATE_DICT": + continue + cmd_config.append(f"--fsdp_state_dict_type={state_dict_type}") + cmd_config.extend( + [ + self.test_file_path, + f"--output_dir={self.tmpdir}", + "--partial_train_epoch=1", + ] + ) + with patch_environment(omp_num_threads=1): + execute_subprocess_async(cmd_config, env=os.environ.copy()) + + cmd_config = cmd_config[:-1] + resume_from_checkpoint = os.path.join(self.tmpdir, "epoch_0") + cmd_config.extend( + [ + f"--resume_from_checkpoint={resume_from_checkpoint}", + ] + ) + with patch_environment(omp_num_threads=1): + execute_subprocess_async(cmd_config, env=os.environ.copy()) + + def test_peak_memory_usage(self): + self.test_file_path = os.path.join(self.test_scripts_folder, "test_peak_memory_usage.py") + cmd = [ + "accelerate", + "launch", + "--num_processes=2", + "--num_machines=1", + "--machine_rank=0", + ] + for spec, peak_mem_upper_bound in self.peak_memory_usage_upper_bound.items(): + cmd_config = cmd.copy() + if "fp16" in spec: + cmd_config.extend(["--mixed_precision=fp16"]) + else: + cmd_config.extend(["--mixed_precision=no"]) + + if "multi_gpu" in spec: + continue + else: + cmd_config.extend(["--use_fsdp"]) + for i, strategy in enumerate(FSDP_SHARDING_STRATEGY): + if strategy.lower() in spec: + cmd_config.append(f"--fsdp_sharding_strategy={i+1}") + break + + if "cpu_offload" in spec: + cmd_config.append("--fsdp_offload_params=True") + + for policy in FSDP_AUTO_WRAP_POLICY: + if policy.lower() in spec: + cmd_config.append(f"--fsdp_auto_wrap_policy={policy}") + break + + if policy == "TRANSFORMER_BASED_WRAP": + cmd_config.append("--fsdp_transformer_layer_cls_to_wrap=BertLayer") + elif policy == "SIZE_BASED_WRAP": + cmd_config.append("--fsdp_min_num_params=2000") + + cmd_config.extend( + [ + self.test_file_path, + f"--output_dir={self.tmpdir}", + f"--peak_memory_upper_bound={peak_mem_upper_bound}", + f"--n_train={self.n_train}", + f"--n_val={self.n_val}", + ] + ) + with patch_environment(omp_num_threads=1): + execute_subprocess_async(cmd_config, env=os.environ.copy())