[FSDP2] full finetune: move state dict to cpu when cpu offloading #1495

Merged · 6 commits · Sep 5, 2024
4 changes: 3 additions & 1 deletion recipes/dev/lora_finetune_fsdp2.py
@@ -211,7 +211,7 @@ def setup(self, cfg: DictConfig) -> None:
self._metric_logger.log_config(cfg)

checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)
- self._model_compile = cfg.compile
+ self._model_compile = cfg.get("compile", False)

self._model = self._setup_model(
cfg_model=cfg.model,
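A quick note on the `cfg.get` change above: torchtune recipe configs are OmegaConf `DictConfig`s (see the `cfg: DictConfig` signature in the hunk header), and `.get` returns the supplied default when the key is absent, so configs without a `compile` entry keep working. A minimal sketch; the config contents below are made up:

```python
from omegaconf import OmegaConf

# Hypothetical config with no "compile" entry.
cfg = OmegaConf.create({"model": {"_component_": "torchtune.models.llama3.llama3_8b"}})

# .get falls back to the default instead of failing on the missing key
# (direct attribute access can raise, depending on the config's struct mode).
model_compile = cfg.get("compile", False)
print(model_compile)  # False
```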
@@ -521,12 +521,14 @@ def save_checkpoint(
cpu_state_dict = training.get_full_model_state_dict(
self._model,
self._is_rank_zero,
+ device=self._device,
)

if intermediate_checkpoint:
opt_state_dict = training.get_full_optimizer_state_dict(
self._optimizer,
self._is_rank_zero,
+ device=self._device,
)
else:
opt_state_dict = None
9 changes: 8 additions & 1 deletion recipes/full_finetune_distributed.py
@@ -431,7 +431,12 @@ def _is_layer_fqn(s: str) -> bool:
# This method will convert the full model state dict into a sharded state
# dict and load into the model
training.load_from_full_model_state_dict(
- model, model_state_dict, self._device, self._is_rank_zero, strict=True
+ model,
+ model_state_dict,
+ self._device,
+ self._is_rank_zero,
+ strict=True,
+ cpu_offload=fsdp_cpu_offload,
)

# Ensure no params and buffers are on meta device
@@ -532,12 +537,14 @@ def save_checkpoint(
cpu_state_dict = training.get_full_model_state_dict(
self._model,
self._is_rank_zero,
+ device=self._device,
)

if intermediate_checkpoint:
opt_state_dict = training.get_full_optimizer_state_dict(
self._optimizer,
self._is_rank_zero,
+ device=self._device,
)
else:
opt_state_dict = None
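To see how the load and save halves of this file fit together, here is a hedged, self-contained sketch of the round trip with CPU offloading enabled. It uses a toy `nn.Linear`, assumes two CUDA GPUs launched via torchrun, and approximates the recipe's meta-device + `fully_shard` setup (the `offload_policy=CPUOffloadPolicy()` kwarg and the script name are assumptions); only the two torchtune helper calls mirror this PR exactly.

```python
# Run with: torchrun --nproc_per_node=2 roundtrip_sketch.py  (assumes 2 CUDA GPUs)
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard
from torchtune import training

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)
device = torch.device("cuda")
is_rank_zero = rank == 0

# Full, unsharded weights on CPU, standing in for a loaded checkpoint.
full_sd = {"weight": torch.randn(32, 32)}

# Meta-init, then shard with CPU offloading enabled.
with torch.device("meta"):
    model = nn.Linear(32, 32, bias=False)
fully_shard(model, offload_policy=CPUOffloadPolicy())

# Load: shard on CUDA (NCCL), then keep the resulting DTensor shards on CPU.
training.load_from_full_model_state_dict(
    model, full_sd, device, is_rank_zero, strict=True, cpu_offload=True
)
assert next(model.parameters()).is_cpu

# Save: device= lets the helper move CPU shards back to CUDA for the
# all-gather; only rank zero ends up holding the full CPU state dict.
cpu_sd = training.get_full_model_state_dict(model, is_rank_zero, device=device)
if is_rank_zero:
    assert cpu_sd["weight"].is_cpu

dist.destroy_process_group()
```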
25 changes: 22 additions & 3 deletions torchtune/training/_distributed.py
@@ -8,7 +8,7 @@
import logging
import os
from itertools import chain
- from typing import Any, Callable, cast, Dict, List, Set, Tuple, Type
+ from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, Type

import torch
import torch.distributed as dist
@@ -285,6 +285,7 @@ def load_from_full_model_state_dict(
device: torch.device,
is_rank_zero: bool,
strict: bool = False,
+ cpu_offload: bool = False,
):
"""
Converting full state dict into a sharded state dict
@@ -338,6 +339,8 @@ def load_from_full_model_state_dict(
sharded_meta_param.device_mesh,
sharded_meta_param.placements,
)
+ if cpu_offload:
+     sharded_tensor = sharded_tensor.cpu()
Contributor Author:

sharded_tensor has device=cuda because distribute_tensor/DTensor requires NCCL. For CPU offloading, we can move the DTensor to device=cpu afterwards to avoid a memory peak.

Reviewer:

This change makes sense to me. I am not sure if we should support the user trying to load a GPU state dict into an FSDP module that has CPU offloading enabled. We can provide a better error message though.

Contributor Author:

Got it. This reminds me of my overdue BE item to check for a CPU device in lazy_init when CPU offloading is enabled.
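A minimal sketch of the offload path discussed in this thread, runnable under torchrun with two CUDA GPUs (an assumption; the tensor shape and script name are illustrative): distribute_tensor produces CUDA-resident shards because DTensor collectives run over NCCL, and the shard is only moved to CPU afterwards.

```python
# torchrun --nproc_per_node=2 offload_sketch.py
import torch
import torch.distributed as dist
from torch.distributed._tensor import Shard, distribute_tensor
from torch.distributed.device_mesh import init_device_mesh

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())
mesh = init_device_mesh("cuda", (dist.get_world_size(),))

torch.manual_seed(0)
full = torch.randn(1024, 1024, device="cuda")  # full tensor, moved to GPU first

# Sharding communicates over NCCL, so the resulting DTensor lives on CUDA...
sharded = distribute_tensor(full, mesh, [Shard(0)])
assert not sharded.is_cpu

# ...and is only moved to CPU afterwards when cpu_offload=True.
sharded = sharded.cpu()
assert sharded.is_cpu

dist.destroy_process_group()
```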

sharded_sd[param_name] = nn.Parameter(sharded_tensor)
# choose `assign=True` since we cannot call `copy_` on meta tensor
return model.load_state_dict(sharded_sd, strict=strict, assign=True)
@@ -346,6 +349,7 @@ def load_from_full_model_state_dict(
def get_full_model_state_dict(
model: "FSDPModule", # noqa
is_rank_zero: bool,
+ device: Optional[torch.device] = None,
) -> Dict[str, Any]:
"""
Converting sharded state dict into a full state dict on cpu
@@ -393,6 +397,13 @@ def get_full_model_state_dict(
module.reshard()
else:
for param_name, sharded_param in sharded_sd.items():
+ if sharded_param.is_cpu:
+     assert device is not None and device.type == "cuda", (
+         f"Expect cuda but got device={device}. "
+         "Please call get_full_model_state_dict(..., device=self._device),"
+         " so DTensor can communicate over NCCL."
+     )
+     sharded_param = sharded_param.to(device)
full_param = sharded_param.full_tensor()
if is_rank_zero:
cpu_state_dict[param_name] = full_param.cpu()
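The assert above guards the reverse direction of the offload path. A hedged sketch (again assuming two CUDA GPUs under torchrun; names and shapes are illustrative): a CPU-resident shard is moved back onto the CUDA mesh so full_tensor() can all-gather over NCCL, and only rank zero keeps a CPU copy of the full tensor.

```python
# torchrun --nproc_per_node=2 gather_sketch.py
import torch
import torch.distributed as dist
from torch.distributed._tensor import Shard, distribute_tensor
from torch.distributed.device_mesh import init_device_mesh

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)
device = torch.device("cuda")
mesh = init_device_mesh("cuda", (dist.get_world_size(),))

torch.manual_seed(0)
# A CPU-offloaded shard, as left behind by load_from_full_model_state_dict.
sharded = distribute_tensor(torch.randn(8, 8, device=device), mesh, [Shard(0)]).cpu()

if sharded.is_cpu:
    sharded = sharded.to(device)   # back onto the NCCL mesh, as in the diff
full = sharded.full_tensor()       # all-gather the unsharded tensor on CUDA
if rank == 0:
    print(full.cpu().shape)        # rank zero materializes the CPU copy

dist.destroy_process_group()
```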
@@ -404,6 +415,7 @@
def get_full_optimizer_state_dict(
opt: Optimizer,
is_rank_zero: bool,
+ device: Optional[torch.device] = None,
) -> Dict[str, Any]:
"""
Converting optimizer state from sharded to full
@@ -417,8 +429,15 @@
for group_id, sharded_group in sharded_state.items():
group_state = {}
for attr, sharded_tensor in sharded_group.items():
# "exp_avg" in AdamW is `DTensor`
if isinstance(sharded_tensor, DTensor):
# "exp_avg" in AdamW is `DTensor`
if sharded_tensor.is_cpu:
assert device is not None and device.type == "cuda", (
f"Expect cuda but got device={device}. "
"Please call get_full_optimizer_state_dict(..., device=self._device),"
" so DTensor can communicate over NCCL."
)
sharded_tensor = sharded_tensor.to(device)
full_tensor = sharded_tensor.full_tensor()
else:
# "step" in AdamW is plain tensor
@@ -584,4 +603,4 @@ def shard_model(
fully_shard(m, **fsdp_kwargs)

# Finally shard the entire model to account for any stragglers
- fully_shard(model)
Contributor Author:

This seems to drop **fsdp_kwargs by accident? It prevents CPU offloading for the root model.

Contributor:

Yeah, I had meant to fix this, thanks for adding it here!

+ fully_shard(model, **fsdp_kwargs)
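A hedged sketch of the sharding pattern this hunk fixes (the function name, the layer condition, and the exact fsdp_kwargs contents are illustrative stand-ins, not torchtune's actual shard_model): the same kwargs, including the CPU offload policy, must reach both the per-layer fully_shard calls and the final root call, otherwise root-owned parameters are never offloaded. Assumes an already-initialized NCCL process group.

```python
import torch.nn as nn
from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard


def shard_with_offload(
    model: nn.Module, cpu_offload: bool, reshard_after_forward: bool = True
) -> None:
    """Illustrative stand-in for torchtune's shard_model."""
    fsdp_kwargs = {"reshard_after_forward": reshard_after_forward}
    if cpu_offload:
        fsdp_kwargs["offload_policy"] = CPUOffloadPolicy()

    # Shard the transformer layers first (shard_conditions in torchtune)...
    for module in model.modules():
        if isinstance(module, nn.TransformerEncoderLayer):  # stand-in condition
            fully_shard(module, **fsdp_kwargs)

    # ...then the root with the SAME kwargs -- the line restored in this PR.
    # Calling fully_shard(model) here would silently skip CPU offloading for
    # any parameters owned directly by the root module.
    fully_shard(model, **fsdp_kwargs)
```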