enhancements and fixes for FSDP and DeepSpeed (#532)
* checkpointing enhancements and fixes for FSDP and DeepSpeed (see the usage sketch after this list)

* resolving comments

1. Adding deprecation args and warnings in the launcher for FSDP
2. Handling old configs so they keep working with the new FSDP launcher args.
3. Reverting changes to the public methods in `checkpointing.py` and handling it in `Accelerator` instead.
4. Explicitly writing the defaults of the various FSDP options in `dataclasses` for readability.

* fixes

1. The FSDP-wrapped model is now added to `_models`.
2. Env variables are no longer passed when the corresponding args are None.

* resolving comments

* adding FSDP for all the collective operations

* adding DeepSpeed and FSDP tests

1. Removes the MRPC data files and relies directly on HF datasets, as a `file not found` error was thrown when running from within the `tests` folder. Updates `mocked_dataloaders` as a result.
2. Adds `test_performance.py`, `test_memory.py` and `test_checkpointing.py` for multi-GPU FSDP and DeepSpeed tests

* reverting `mocked_dataloaders` changes

* adding FSDP tests

* data files revert

* excluding FSDP tests from `test_core`

* try 2

* adding a time delay to keep `torchrun` from occasionally crashing and causing flaky behaviour

* reducing the time of tests

* fixes

* fix

* fixes and reduce time further

* reduce time further and minor fixes

* adding a basic DeepSpeed e2e test for the single-GPU setup
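
A minimal usage sketch of the checkpointing flow these changes target (illustrative only, not part of this commit): the same `save_state`/`load_state` calls now branch internally to FSDP- and DeepSpeed-specific saving and loading. The model, learning rate, and checkpoint directory below are placeholders, and FSDP or DeepSpeed is assumed to be enabled via `accelerate config` or launcher flags rather than shown here.

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()  # FSDP/DeepSpeed selected via config or launcher flags

model = torch.nn.Linear(8, 2)  # placeholder model

# With FSDP, prepare the model first, then build and prepare the optimizer on
# the wrapped parameters (matching the ordering this commit enforces/warns about).
model = accelerator.prepare(model)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
optimizer = accelerator.prepare(optimizer)

# ... training steps ...

accelerator.save_state("ckpt")  # dispatches to FSDP plugin / DeepSpeed engine saving internally
accelerator.load_state("ckpt")  # mirrors the same per-backend branching on restore
```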
pacman100 authored Jul 26, 2022
1 parent 91ff425 commit 0c6bdc2
Showing 15 changed files with 1,643 additions and 38 deletions.
1 change: 1 addition & 0 deletions .github/workflows/test.yml
@@ -16,6 +16,7 @@ jobs:
test_core,
test_big_modeling,
test_deepspeed,
test_fsdp,
test_example_differences,
test_checkpoint_step,
test_checkpoint_epoch,
6 changes: 5 additions & 1 deletion Makefile
@@ -31,11 +31,15 @@ test_big_modeling:
python -m pytest -s -v ./tests/test_big_modeling.py

test_core:
python -m pytest -s -v ./tests/ --ignore=./tests/test_examples.py --ignore=./tests/deepspeed --ignore=./tests/test_big_modeling.py
python -m pytest -s -v ./tests/ --ignore=./tests/test_examples.py --ignore=./tests/deepspeed --ignore=./tests/test_big_modeling.py \
--ignore=./tests/fsdp

test_deepspeed:
python -m pytest -s -v ./tests/deepspeed

test_fsdp:
python -m pytest -s -v ./tests/fsdp

test_examples:
python -m pytest -s -v ./tests/test_examples.py

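For reference, a small sketch of what the new `test_fsdp` target runs, invoked programmatically through pytest's own API rather than `make` (assuming pytest is installed and the repository root is the working directory):

```python
import sys
import pytest

# Programmatic equivalent of `make test_fsdp`: run the FSDP test folder verbosely.
sys.exit(pytest.main(["-s", "-v", "./tests/fsdp"]))
```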
84 changes: 78 additions & 6 deletions src/accelerate/accelerator.py
@@ -32,6 +32,7 @@
from .state import AcceleratorState, GradientState
from .tracking import LOGGER_TYPE_TO_CLASS, GeneralTracker, filter_trackers
from .utils import (
MODEL_NAME,
DeepSpeedPlugin,
DistributedDataParallelKwargs,
DistributedType,
@@ -572,7 +573,8 @@ def prepare(self, *args):
if model_count > 1 and optimizer_present:
raise ValueError(
"For FSDP to work with multiple models (>1), "
"prepare must be called for all the models before optimizers are created"
"prepare must be called for all the models before optimizers are created. "
"Then pass the optimizers to the prepare call in the same order as corresponding models."
)
elif model_count == 1 and optimizer_present:
logger.warn(
@@ -649,6 +651,7 @@ def prepare_model(self, model):
)
if not fsdp_plugin.cpu_offload.offload_params:
model.to(self.device)
self._models[-1] = model
elif self.distributed_type == DistributedType.MULTI_CPU:
kwargs = self.ddp_handler.to_kwargs() if self.ddp_handler is not None else {}
model = torch.nn.parallel.DistributedDataParallel(model, **kwargs)
@@ -1098,9 +1101,44 @@ def save_state(self, output_dir: str):
output_dir = os.path.expanduser(output_dir)
os.makedirs(output_dir, exist_ok=True)
logger.info(f"Saving current state to {output_dir}")
weights = [self.get_state_dict(m, unwrap=False) for m in self._models]

# Save the models taking care of FSDP and DeepSpeed nuances
weights = []
for i, model in enumerate(self._models):
if self.distributed_type == DistributedType.FSDP:
logger.info("Saving FSDP model")
self.state.fsdp_plugin.save_model(self, model, output_dir, i)
logger.info(f"FSDP Model saved to output dir {output_dir}")
elif self.distributed_type == DistributedType.DEEPSPEED:
logger.info("Saving DeepSpeed Model and Optimizer")
ckpt_id = f"{MODEL_NAME}" if i == 0 else f"{MODEL_NAME}_{i}"
model.save_checkpoint(output_dir, ckpt_id)
logger.info(f"DeepSpeed Model and Optimizer saved to output dir {os.path.join(output_dir, ckpt_id)}")
else:
weights.append(self.get_state_dict(model, unwrap=False))

# Save the optimizers taking care of FSDP and DeepSpeed nuances
optimizers = []
if self.distributed_type == DistributedType.FSDP:
for opt in self._optimizers:
logger.info("Saving FSDP Optimizer")
self.state.fsdp_plugin.save_optimizer(self, opt, self._models[i], output_dir, i)
logger.info(f"FSDP Optimizer saved to output dir {output_dir}")
elif self.distributed_type != DistributedType.DEEPSPEED:
optimizers = self._optimizers

# Save the lr schedulers taking care of DeepSpeed nuances
schedulers = []
if self.distributed_type == DistributedType.DEEPSPEED:
for i, scheduler in enumerate(self._schedulers):
if isinstance(scheduler, DeepSpeedSchedulerWrapper):
continue
schedulers.append(scheduler)
else:
schedulers = self._schedulers

save_location = save_accelerator_state(
output_dir, weights, self._optimizers, self._schedulers, self.state.process_index, self.scaler
output_dir, weights, optimizers, schedulers, self.state.process_index, self.scaler
)
for i, obj in enumerate(self._custom_objects):
save_custom_state(obj, output_dir, i)
@@ -1119,9 +1157,43 @@ def load_state(self, input_dir: str):
if not os.path.isdir(input_dir):
raise ValueError(f"Tried to find {input_dir} but folder does not exist")
logger.info(f"Loading states from {input_dir}")
load_accelerator_state(
input_dir, self._models, self._optimizers, self._schedulers, self.state.process_index, self.scaler
)

# Load the models taking care of FSDP and DeepSpeed nuances
models = []
for i, model in enumerate(self._models):
if self.distributed_type == DistributedType.FSDP:
logger.info("Loading FSDP model")
self.state.fsdp_plugin.load_model(self, model, input_dir, i)
logger.info(f"FSDP Model loaded from input dir {input_dir}")
elif self.distributed_type == DistributedType.DEEPSPEED:
logger.info("Loading DeepSpeed Model and Optimizer")
ckpt_id = f"{MODEL_NAME}" if i == 0 else f"{MODEL_NAME}_{i}"
model.load_checkpoint(input_dir, ckpt_id)
logger.info(f"DeepSpeed Model and Optimizer loaded from input dir {os.path.join(input_dir, ckpt_id)}")
else:
models.append(model)

# Load the optimizers taking care of FSDP and DeepSpeed nuances
optimizers = []
if self.distributed_type == DistributedType.FSDP:
for i, opt in enumerate(self._optimizers):
logger.info("Loading FSDP Optimizer")
self.state.fsdp_plugin.load_optimizer(self, opt, self._models[i], input_dir, i)
logger.info(f"FSDP Optimizer loaded from input dir {input_dir}")
elif self.distributed_type != DistributedType.DEEPSPEED:
optimizers = self._optimizers

# Load the lr schedulers taking care of DeepSpeed nuances
schedulers = []
if self.distributed_type == DistributedType.DEEPSPEED:
for i, scheduler in enumerate(self._schedulers):
if isinstance(scheduler, DeepSpeedSchedulerWrapper):
continue
schedulers.append(scheduler)
else:
schedulers = self._schedulers

load_accelerator_state(input_dir, models, optimizers, schedulers, self.state.process_index, self.scaler)
custom_checkpoints = [f for f in os.listdir(input_dir) if "custom_checkpoint" in f]
if len(custom_checkpoints) != len(self._custom_objects):
err = "Warning! Number of found checkpoints does not match the number of registered objects:"
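To make the multi-model FSDP constraint raised in `prepare` above concrete, here is a hedged sketch of the expected call order; the models, layer sizes, and optimizers are illustrative placeholders:

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()

model_a = torch.nn.Linear(8, 8)  # placeholder models
model_b = torch.nn.Linear(8, 2)

# With FSDP and more than one model, prepare all models first...
model_a, model_b = accelerator.prepare(model_a, model_b)

# ...then create the optimizers from the wrapped models and pass them to
# prepare in the same order as their corresponding models.
opt_a = torch.optim.AdamW(model_a.parameters(), lr=1e-3)
opt_b = torch.optim.AdamW(model_b.parameters(), lr=1e-3)
opt_a, opt_b = accelerator.prepare(opt_a, opt_b)
```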
22 changes: 16 additions & 6 deletions src/accelerate/commands/config/cluster.py
@@ -20,6 +20,7 @@
FSDP_AUTO_WRAP_POLICY,
FSDP_BACKWARD_PREFETCH,
FSDP_SHARDING_STRATEGY,
FSDP_STATE_DICT_TYPE,
)
from .config_args import ClusterConfig
from .config_utils import _ask_field, _convert_distributed_mode, _convert_yes_no_to_bool
@@ -210,12 +211,12 @@ def get_cluster_input():
for i, strategy in enumerate(FSDP_SHARDING_STRATEGY):
sharding_strategy_query += f"[{i+1}] {strategy}, "
sharding_strategy_query = sharding_strategy_query[:-2] + ")? [1]: "
fsdp_config["sharding_strategy"] = _ask_field(
fsdp_config["fsdp_sharding_strategy"] = _ask_field(
sharding_strategy_query,
lambda x: int(x),
default=1,
)
fsdp_config["offload_params"] = _ask_field(
fsdp_config["fsdp_offload_params"] = _ask_field(
"Do you want to offload parameters and gradients to CPU? [yes/NO]: ",
_convert_yes_no_to_bool,
default=False,
@@ -228,15 +229,15 @@
fsdp_config["fsdp_auto_wrap_policy"] = _ask_field(
fsdp_wrap_query,
lambda x: FSDP_AUTO_WRAP_POLICY[int(x)],
default=FSDP_AUTO_WRAP_POLICY[0],
default="TRANSFORMER_BASED_WRAP",
)
if fsdp_config["fsdp_auto_wrap_policy"] == FSDP_AUTO_WRAP_POLICY[0]:
fsdp_config["transformer_layer_cls_to_wrap"] = _ask_field(
fsdp_config["fsdp_transformer_layer_cls_to_wrap"] = _ask_field(
"What is the transformer layer class name (case-sensitive) to wrap ,e.g, `BertLayer`, `GPTJBlock`, `T5Block` ...? : ",
lambda x: str(x),
)
elif fsdp_config["fsdp_auto_wrap_policy"] == FSDP_AUTO_WRAP_POLICY[1]:
fsdp_config["min_num_params"] = _ask_field(
fsdp_config["fsdp_min_num_params"] = _ask_field(
"What should be your FSDP's minimum number of parameters for Default Auto Wrapping Policy? [1e8]: ",
lambda x: int(x),
default=1e8,
@@ -248,7 +249,16 @@
fsdp_config["fsdp_backward_prefetch_policy"] = _ask_field(
fsdp_backward_prefetch_query,
lambda x: FSDP_BACKWARD_PREFETCH[int(x)],
default=FSDP_BACKWARD_PREFETCH[0],
default="BACKWARD_PRE",
)
fsdp_state_dict_type_query = "What should be your FSDP's state dict type ("
for i, state_dict_type in enumerate(FSDP_STATE_DICT_TYPE):
fsdp_state_dict_type_query += f"[{i}] {state_dict_type}, "
fsdp_state_dict_type_query = fsdp_state_dict_type_query[:-2] + ")? [0]: "
fsdp_config["fsdp_state_dict_type"] = _ask_field(
fsdp_state_dict_type_query,
lambda x: FSDP_STATE_DICT_TYPE[int(x)],
default="FULL_STATE_DICT",
)

if distributed_type == DistributedType.TPU:
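For orientation, a hedged sketch (as a plain Python dict, not the exact serialized config format) of the `fsdp_config` entries the questionnaire above populates when the offered defaults are accepted; the transformer layer class is an example value the user would type in:

```python
# Illustrative fsdp_config assembled from the prompts above with their defaults.
fsdp_config = {
    "fsdp_sharding_strategy": 1,                        # option [1], the first listed strategy
    "fsdp_offload_params": False,
    "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
    "fsdp_transformer_layer_cls_to_wrap": "BertLayer",  # example class, asked only for TRANSFORMER_BASED_WRAP
    "fsdp_backward_prefetch_policy": "BACKWARD_PRE",
    "fsdp_state_dict_type": "FULL_STATE_DICT",
}
```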
93 changes: 82 additions & 11 deletions src/accelerate/commands/launch.py
@@ -148,19 +148,19 @@ def launch_command_parser(subparsers=None):
help="Whether to use fsdp.",
)
parser.add_argument(
"--offload_params",
"--fsdp_offload_params",
default="false",
type=str,
help="Decides Whether (true|false) to offload parameters and gradients to CPU. (useful only when `use_fsdp` flag is passed).",
)
parser.add_argument(
"--min_num_params",
"--fsdp_min_num_params",
type=int,
default=1e8,
help="FSDP's minimum number of parameters for Default Auto Wrapping. (useful only when `use_fsdp` flag is passed).",
)
parser.add_argument(
"--sharding_strategy",
"--fsdp_sharding_strategy",
type=int,
default=1,
help="FSDP's Sharding Strategy. (useful only when `use_fsdp` flag is passed).",
@@ -172,7 +172,7 @@
help="FSDP's auto wrap policy. (useful only when `use_fsdp` flag is passed).",
)
parser.add_argument(
"--transformer_layer_cls_to_wrap",
"--fsdp_transformer_layer_cls_to_wrap",
default=None,
type=str,
help="Transformer layer class name (case-sensitive) to wrap ,e.g, `BertLayer`, `GPTJBlock`, `T5Block` .... "
@@ -184,6 +184,36 @@
type=str,
help="FSDP's backward prefetch policy. (useful only when `use_fsdp` flag is passed).",
)
parser.add_argument(
"--fsdp_state_dict_type",
default=None,
type=str,
help="FSDP's state dict type. (useful only when `use_fsdp` flag is passed).",
)
parser.add_argument(
"--offload_params",
default=None,
type=str,
help="This argument is deprecated. Use `fsdp_offload_params` instead.",
)
parser.add_argument(
"--min_num_params",
type=int,
default=None,
help="This argument is deprecated. Use `fsdp_min_num_params` instead.",
)
parser.add_argument(
"--sharding_strategy",
type=int,
default=None,
help="This argument is deprecated. Use `fsdp_sharding_strategy` instead.",
)
parser.add_argument(
"--transformer_layer_cls_to_wrap",
default=None,
type=str,
help="This argument is deprecated. Use `fsdp_transformer_layer_cls_to_wrap` instead.",
)
parser.add_argument(
"--tpu", default=False, action="store_true", help="Whether or not this should launch a TPU training."
)
@@ -360,13 +390,51 @@ def multi_gpu_launcher(args):

current_env["MIXED_PRECISION"] = str(mixed_precision)
if args.use_fsdp:
if args.sharding_strategy is not None:
warnings.warn(
"`sharding_strategy` is deprecated and will be removed in version 0.13.0 of 🤗 Accelerate. Use"
" `fsdp_sharding_strategy` instead",
FutureWarning,
)
args.fsdp_sharding_strategy = args.sharding_strategy

if args.offload_params is not None:
warnings.warn(
"`offload_params` is deprecated and will be removed in version 0.13.0 of 🤗 Accelerate. Use"
" `fsdp_offload_params` instead",
FutureWarning,
)
args.fsdp_offload_params = args.offload_params

if args.min_num_params is not None:
warnings.warn(
"`min_num_params` is deprecated and will be removed in version 0.13.0 of 🤗 Accelerate. Use"
" `fsdp_min_num_params` instead",
FutureWarning,
)
args.fsdp_min_num_params = args.min_num_params

if args.transformer_layer_cls_to_wrap is not None:
warnings.warn(
"`transformer_layer_cls_to_wrap` is deprecated and will be removed in version 0.13.0 of 🤗 Accelerate. Use"
" `fsdp_transformer_layer_cls_to_wrap` instead",
FutureWarning,
)
args.fsdp_transformer_layer_cls_to_wrap = args.transformer_layer_cls_to_wrap

current_env["USE_FSDP"] = "true"
current_env["FSDP_AUTO_WRAP_POLICY"] = str(args.fsdp_auto_wrap_policy)
current_env["FSDP_TRANSFORMER_CLS_TO_WRAP"] = str(args.transformer_layer_cls_to_wrap)
current_env["FSDP_OFFLOAD_PARAMS"] = str(args.offload_params).lower()
current_env["FSDP_MIN_NUM_PARAMS"] = str(args.min_num_params)
current_env["FSDP_SHARDING_STRATEGY"] = str(args.sharding_strategy)
current_env["FSDP_BACKWARD_PREFETCH"] = str(args.fsdp_backward_prefetch_policy)
current_env["FSDP_SHARDING_STRATEGY"] = str(args.fsdp_sharding_strategy)
current_env["FSDP_OFFLOAD_PARAMS"] = str(args.fsdp_offload_params).lower()
current_env["FSDP_MIN_NUM_PARAMS"] = str(args.fsdp_min_num_params)
if args.fsdp_auto_wrap_policy is not None:
current_env["FSDP_AUTO_WRAP_POLICY"] = str(args.fsdp_auto_wrap_policy)
if args.fsdp_transformer_layer_cls_to_wrap is not None:
current_env["FSDP_TRANSFORMER_CLS_TO_WRAP"] = str(args.fsdp_transformer_layer_cls_to_wrap)
if args.fsdp_backward_prefetch_policy is not None:
current_env["FSDP_BACKWARD_PREFETCH"] = str(args.fsdp_backward_prefetch_policy)
if args.fsdp_state_dict_type is not None:
current_env["FSDP_STATE_DICT_TYPE"] = str(args.fsdp_state_dict_type)

current_env["OMP_NUM_THREADS"] = str(args.num_cpu_threads_per_process)
process = subprocess.Popen(cmd, env=current_env)
process.wait()
@@ -682,7 +750,10 @@ def launch_command(args):
if getattr(args, k) is None:
setattr(args, k, defaults.deepspeed_config[k])
for k in defaults.fsdp_config:
setattr(args, k, defaults.fsdp_config[k])
arg_to_set = k
if "fsdp" not in arg_to_set:
arg_to_set = "fsdp_" + arg_to_set
setattr(args, arg_to_set, defaults.fsdp_config[k])
continue

# Those args are handled separately
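And a hedged sketch of the FSDP-related environment that `multi_gpu_launcher` exports after the flag renaming above; the values shown are illustrative defaults, and the optional entries are only set when their flags are actually passed:

```python
# Illustrative FSDP environment built by the launcher when `--use_fsdp` is used.
current_env = {
    "USE_FSDP": "true",
    "FSDP_SHARDING_STRATEGY": "1",       # from --fsdp_sharding_strategy
    "FSDP_OFFLOAD_PARAMS": "false",      # from --fsdp_offload_params
    "FSDP_MIN_NUM_PARAMS": "100000000",  # from the 1e8 default of --fsdp_min_num_params
    # Exported only when the corresponding argument is not None:
    "FSDP_AUTO_WRAP_POLICY": "TRANSFORMER_BASED_WRAP",
    "FSDP_TRANSFORMER_CLS_TO_WRAP": "BertLayer",
    "FSDP_BACKWARD_PREFETCH": "BACKWARD_PRE",
    "FSDP_STATE_DICT_TYPE": "FULL_STATE_DICT",
}
```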
(Diffs for the remaining 10 changed files are not expanded here.)
