[FSDP2] full finetune: move state dict to cpu when cpu offloading #1495
Changes from all commits: 8adb728, 4582779, afec67c, ee04e93, 640b96f, fedaa32
@@ -8,7 +8,7 @@
 import logging
 import os
 from itertools import chain
-from typing import Any, Callable, cast, Dict, List, Set, Tuple, Type
+from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, Type

 import torch
 import torch.distributed as dist
@@ -285,6 +285,7 @@ def load_from_full_model_state_dict(
     device: torch.device,
     is_rank_zero: bool,
     strict: bool = False,
+    cpu_offload: bool = False,
 ):
     """
     Converting full state dict into a sharded state dict
@@ -338,6 +339,8 @@ def load_from_full_model_state_dict(
             sharded_meta_param.device_mesh,
             sharded_meta_param.placements,
         )
+        if cpu_offload:
+            sharded_tensor = sharded_tensor.cpu()
         sharded_sd[param_name] = nn.Parameter(sharded_tensor)
     # choose `assign=True` since we cannot call `copy_` on meta tensor
     return model.load_state_dict(sharded_sd, strict=strict, assign=True)

Review comment: This change makes sense to me. I am not sure if we should support the user trying to load a GPU state dict into an FSDP module that has CPU offloading enabled. We can provide a better error message, though.

Reply: Got you. This reminds me of my overdue BE item to check for the CPU device in lazy_init when CPU offloading is enabled.
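For reference, a hedged sketch of how a recipe could call the updated loader. The import path, the helper name, and the checkpoint-reading details are assumptions for illustration; only load_from_full_model_state_dict and its new cpu_offload flag come from this diff.

import torch

from torchtune.training import load_from_full_model_state_dict  # assumed import path


def load_full_checkpoint(model, checkpoint_path, device, is_rank_zero, fsdp_cpu_offload):
    # Read the full (unsharded) state dict on CPU, then shard it per rank;
    # with cpu_offload=True the sharded parameters stay on CPU.
    full_sd = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    return load_from_full_model_state_dict(
        model,
        full_sd,
        device,
        is_rank_zero,
        strict=True,
        cpu_offload=fsdp_cpu_offload,
    )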
@@ -346,6 +349,7 @@ def load_from_full_model_state_dict(
 def get_full_model_state_dict(
     model: "FSDPModule",  # noqa
     is_rank_zero: bool,
+    device: Optional[torch.device] = None,
 ) -> Dict[str, Any]:
     """
     Converting sharded state dict into a full state dict on cpu
@@ -393,6 +397,13 @@ def get_full_model_state_dict(
             module.reshard()
     else:
         for param_name, sharded_param in sharded_sd.items():
+            if sharded_param.is_cpu:
+                assert device is not None and device.type == "cuda", (
+                    f"Expect cuda but got device={device}. "
+                    "Please call get_full_model_state_dict(..., device=self._device),"
+                    " so DTensor can communicate over NCCL."
+                )
+                sharded_param = sharded_param.to(device)
             full_param = sharded_param.full_tensor()
             if is_rank_zero:
                 cpu_state_dict[param_name] = full_param.cpu()
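A hedged sketch of the save path this assert anticipates: the caller forwards its training device so CPU-offloaded DTensor shards can be moved back onto CUDA before the all-gather. The import path and the surrounding helper are assumptions; only get_full_model_state_dict and its device argument come from this diff.

import torch

from torchtune.training import get_full_model_state_dict  # assumed import path


def save_full_model(model, device, is_rank_zero, output_path):
    # Every rank participates in the all-gather; only rank zero writes the file.
    cpu_state_dict = get_full_model_state_dict(model, is_rank_zero, device=device)
    if is_rank_zero:
        torch.save(cpu_state_dict, output_path)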
@@ -404,6 +415,7 @@ def get_full_model_state_dict(
 def get_full_optimizer_state_dict(
     opt: Optimizer,
     is_rank_zero: bool,
+    device: Optional[torch.device] = None,
 ) -> Dict[str, Any]:
     """
     Converting optimizer state from sharded to full
@@ -417,8 +429,15 @@ def get_full_optimizer_state_dict(
     for group_id, sharded_group in sharded_state.items():
         group_state = {}
         for attr, sharded_tensor in sharded_group.items():
-            # "exp_avg" in AdamW is `DTensor`
             if isinstance(sharded_tensor, DTensor):
+                # "exp_avg" in AdamW is `DTensor`
+                if sharded_tensor.is_cpu:
+                    assert device is not None and device.type == "cuda", (
+                        f"Expect cuda but got device={device}. "
+                        "Please call get_full_optimizer_state_dict(..., device=self._device),"
+                        " so DTensor can communicate over NCCL."
+                    )
+                    sharded_tensor = sharded_tensor.to(device)
                 full_tensor = sharded_tensor.full_tensor()
             else:
                 # "step" in AdamW is plain tensor
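The optimizer path follows the same pattern; a hedged sketch, with the import path and the surrounding helper assumed:

import torch

from torchtune.training import get_full_optimizer_state_dict  # assumed import path


def save_full_optimizer(opt, device, is_rank_zero, output_path):
    # With CPU offloading, states like "exp_avg" are CPU-resident DTensors, so
    # the training device is forwarded to move them to CUDA before full_tensor().
    opt_state_dict = get_full_optimizer_state_dict(opt, is_rank_zero, device=device)
    if is_rank_zero:
        torch.save(opt_state_dict, output_path)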
@@ -584,4 +603,4 @@ def shard_model(
         fully_shard(m, **fsdp_kwargs)

     # Finally shard the entire model to account for any stragglers
-    fully_shard(model)
+    fully_shard(model, **fsdp_kwargs)

Review comment: This seems to be dropping `**fsdp_kwargs` from the final `fully_shard` call.

Reply: Yeah, I had meant to fix this, thanks for adding it here!
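A hedged sketch of why forwarding **fsdp_kwargs to the root call matters: with CPU offloading enabled, the offload policy has to reach every fully_shard call, including the final one on the root module. It assumes an initialized process group; the model and the per-layer condition are stand-ins, and the FSDP2 import path may differ across PyTorch versions.

import torch.nn as nn

from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard

model = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=256, nhead=4), num_layers=2)
fsdp_kwargs = {"offload_policy": CPUOffloadPolicy(), "reshard_after_forward": True}

# Shard each layer first (shard_model uses its own per-module conditions;
# nn.TransformerEncoderLayer is only a stand-in here).
for m in reversed(list(model.modules())):
    if isinstance(m, nn.TransformerEncoderLayer):
        fully_shard(m, **fsdp_kwargs)

# Finally shard the entire model to account for any stragglers; without
# **fsdp_kwargs this root call would silently drop the offload policy.
fully_shard(model, **fsdp_kwargs)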
Review comment: `sharded_tensor` has `device=cuda` because `distribute_tensor` / DTensor requires NCCL. For CPU offloading, we can move the DTensor to `device=cpu` afterwards to avoid increasing peak memory.
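A minimal sketch of that pattern, assuming a multi-GPU torchrun launch; the shapes and mesh layout are illustrative only, and in newer releases the DTensor imports live under torch.distributed.tensor instead.

import torch
import torch.distributed as dist

from torch.distributed._tensor import Shard, distribute_tensor
from torch.distributed.device_mesh import init_device_mesh

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
mesh = init_device_mesh("cuda", (dist.get_world_size(),))

# distribute_tensor communicates over NCCL, so the resulting DTensor lives on CUDA ...
full_tensor = torch.randn(1024, 1024, device="cuda")
sharded = distribute_tensor(full_tensor, mesh, [Shard(0)])

# ... and with cpu_offload the shard is moved to CPU right away, so the GPU
# holds at most the full tensor plus one shard of it at a time.
sharded = sharded.cpu()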