Handle Manual Wrapping in FSDP. Minor fix of fsdp example. (#342)
* Handle manual wrapping in FSDP. Fix fsdp example.
pacman100 authored May 5, 2022
1 parent 603a53f commit be0f7ce
Showing 2 changed files with 28 additions and 23 deletions.
26 changes: 14 additions & 12 deletions examples/by_feature/fsdp_with_peak_mem_tracking.py
@@ -251,12 +251,13 @@ def collate_fn(examples):
             )
         )
         # Logging the peak memory usage of the GPU to the tracker
-        accelerator.log(
-            {
-                "train_total_peak_memory": tracemalloc.peaked + b2mb(tracemalloc.begin),
-            },
-            step=epoch,
-        )
+        if args.with_tracking:
+            accelerator.log(
+                {
+                    "train_total_peak_memory": tracemalloc.peaked + b2mb(tracemalloc.begin),
+                },
+                step=epoch,
+            )

         # New Code #
         # context manager to track the peak memory usage during the evaluation
@@ -302,12 +303,13 @@ def collate_fn(examples):
"Total Peak Memory consumed during the eval (max): {}".format(tracemalloc.peaked + b2mb(tracemalloc.begin))
)
# Logging the peak memory usage of the GPU to the tracker
accelerator.log(
{
"eval_total_peak_memory": tracemalloc.peaked + b2mb(tracemalloc.begin),
},
step=epoch,
)
if args.with_tracking:
accelerator.log(
{
"eval_total_peak_memory": tracemalloc.peaked + b2mb(tracemalloc.begin),
},
step=epoch,
)

if args.with_tracking:
accelerator.end_training()
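For context, the guards added in the two hunks above mirror how the rest of the example treats tracking: trackers are only set up when `--with_tracking` is passed, so an unguarded `accelerator.log` call would target a tracker that was never initialized. Below is a minimal sketch of that pattern, not code from this commit; the `"wandb"` backend, the project name, and the placeholder memory value are illustrative assumptions.

import argparse

from accelerate import Accelerator

parser = argparse.ArgumentParser()
parser.add_argument("--with_tracking", action="store_true", help="Enable experiment tracking.")
args = parser.parse_args()

# Only request a tracker backend when tracking was asked for.
accelerator = Accelerator(log_with="wandb" if args.with_tracking else None)
if args.with_tracking:
    accelerator.init_trackers("fsdp_peak_mem_tracking")

for epoch in range(3):
    peak_memory_mb = 0  # stand-in for the value the example computes via TorchTracemalloc/b2mb
    # Every logging call is guarded, mirroring the diff above.
    if args.with_tracking:
        accelerator.log({"train_total_peak_memory": peak_memory_mb}, step=epoch)

if args.with_tracking:
    accelerator.end_training()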
25 changes: 14 additions & 11 deletions src/accelerate/accelerator.py
@@ -465,17 +465,20 @@ def prepare_model(self, model):
         elif self.distributed_type == DistributedType.FSDP:
             from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP

-            fsdp_plugin = self.state.fsdp_plugin
-            model = FSDP(
-                model,
-                sharding_strategy=fsdp_plugin.sharding_strategy,
-                cpu_offload=fsdp_plugin.cpu_offload,
-                auto_wrap_policy=fsdp_plugin.auto_wrap_policy,
-                backward_prefetch=fsdp_plugin.backward_prefetch,
-                ignored_modules=fsdp_plugin.ignored_modules,
-            )
-            if not fsdp_plugin.cpu_offload.offload_params:
-                model.to(self.device)
+            # Check if the model is already a FSDP model due to `Manual Wrapping` and if so,
+            # don't wrap it again
+            if type(model) != FSDP:
+                fsdp_plugin = self.state.fsdp_plugin
+                model = FSDP(
+                    model,
+                    sharding_strategy=fsdp_plugin.sharding_strategy,
+                    cpu_offload=fsdp_plugin.cpu_offload,
+                    auto_wrap_policy=fsdp_plugin.auto_wrap_policy,
+                    backward_prefetch=fsdp_plugin.backward_prefetch,
+                    ignored_modules=fsdp_plugin.ignored_modules,
+                )
+                if not fsdp_plugin.cpu_offload.offload_params:
+                    model.to(self.device)
         elif self.distributed_type == DistributedType.MULTI_CPU:
             kwargs = self.ddp_handler.to_kwargs() if self.ddp_handler is not None else {}
             model = torch.nn.parallel.DistributedDataParallel(model, **kwargs)
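The new `type(model) != FSDP` check is what allows "manual wrapping": a user can construct the FSDP module themselves before handing it to `accelerator.prepare`, and Accelerate will no longer wrap it a second time. Below is a minimal sketch of that usage, not code from this commit, assuming the script is started through a distributed launcher (e.g. `accelerate launch`) with the FSDP plugin enabled; the toy `Sequential` model is illustrative only.

import torch
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP

from accelerate import Accelerator

accelerator = Accelerator()  # expects a distributed FSDP launch, e.g. via `accelerate launch`

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU(), torch.nn.Linear(16, 2))

# "Manual wrapping": the user builds the FSDP module themselves, e.g. to control
# exactly which submodules are sharded, instead of relying on the auto-wrap policy.
model = FSDP(model)

# With this commit, `prepare_model` sees that the model is already an FSDP instance
# and passes it through instead of wrapping it a second time.
model = accelerator.prepare(model)

Note that, as the hunk above shows, the `cpu_offload`/`model.to(self.device)` handling now also sits inside the `type(model) != FSDP` branch, so a manually wrapped model keeps whatever device placement the user gave it.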
