[DO NOT REVIEW] gaps to enable FSDP2 cpu offloading #622

Open
wants to merge 1 commit into base: main
7 changes: 6 additions & 1 deletion torchtitan/models/llama/model.py
@@ -391,7 +391,12 @@ def init_weights(self):
``init_weights``. We only call it in the constructor of this
``Transformer`` root module to avoid reinitializing tensors.
"""
with torch.device(self.freqs_cis.device):
# freqs_cis is not a module parameter, just a plain Python attribute,
# so it is not returned by model.parameters() and FSDP2 does not manage it
# (shard, reshard, CPU offload, or move to GPU)
# need to come up with a long-term solution
# hard-code the "cuda" device for now
with torch.device("cuda"):
self.freqs_cis = self._precompute_freqs_cis()
if self.tok_embeddings is not None:
nn.init.normal_(self.tok_embeddings.weight)
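One possible direction for the long-term fix mentioned in the comment above (an assumption on my part, not something this PR implements) is to register `freqs_cis` as a non-persistent buffer, so that the standard `nn.Module` device-movement utilities (`module.to()`, `to_empty()`) track it instead of it living as a plain attribute; whether FSDP2's offload policy would then also manage it is exactly the open question the comment raises.

```python
# Sketch only: register freqs_cis as a non-persistent buffer instead of a plain
# attribute. Buffers are tracked by nn.Module (named_buffers, module.to, to_empty),
# unlike plain Python attributes; they are not returned by model.parameters(),
# so FSDP2 parameter sharding is unaffected. persistent=False keeps the buffer
# out of the state dict, matching the current behavior of recomputing it.
self.register_buffer(
    "freqs_cis", self._precompute_freqs_cis(), persistent=False
)
```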
13 changes: 11 additions & 2 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -12,7 +12,11 @@
import torch
import torch.nn as nn
from torch.distributed import DeviceMesh
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed._composable.fsdp import (
CPUOffloadPolicy,
fully_shard,
MixedPrecisionPolicy,
)
from torch.distributed._composable.replicate import replicate
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
@@ -304,7 +308,12 @@ def apply_fsdp(
Apply data parallelism to the model. FSDP2 is used here.
"""
mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
# TODO: add config_manager for offload_policy
fsdp_config = {
"mesh": dp_mesh,
"mp_policy": mp_policy,
"offload_policy": CPUOffloadPolicy(),
}

# TODO: remove this check once PyTorch 2.5 is released. We can safely assume
# that users won't use a nightly build which is older than 20240809 by then.
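A minimal sketch of the TODO above, assuming the offload choice is threaded into `apply_fsdp` (for example as a boolean argument populated by the config manager; the `cpu_offload` name is an assumption, not an existing torchtitan option):

```python
# Sketch only: gate CPU offloading behind a config-driven flag rather than
# enabling it unconditionally. dp_mesh and mp_policy are the same objects
# already in scope in apply_fsdp; cpu_offload is an assumed new argument.
fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
if cpu_offload:
    fsdp_config["offload_policy"] = CPUOffloadPolicy()
```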
44 changes: 39 additions & 5 deletions train.py
@@ -170,7 +170,9 @@ def loss_fn(pred, labels):
models_parallelize_fns[model_name](model, world_mesh, parallel_dims, job_config)

# move sharded model to CPU/GPU and initialize weights via DTensor
init_device = "cpu" if job_config.checkpoint.create_seed_checkpoint else "cuda"
# init_device = "cpu" if job_config.checkpoint.create_seed_checkpoint else "cuda"
# TODO: check job_config.cpu_offload = True/False
init_device = "cpu"
model.to_empty(device=init_device)
model.init_weights()
model.train()
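A hedged sketch of the TODO a few lines above, restoring the original device selection and adding the CPU-offload case (the `cpu_offload` flag is assumed to be plumbed in from the job config; it is not an existing field):

```python
# Sketch only: initialize on CPU when creating a seed checkpoint or when FSDP2
# CPU offloading is enabled (offloaded parameters live on CPU between steps);
# otherwise keep the original "cuda" behavior. The cpu_offload flag is assumed.
if job_config.checkpoint.create_seed_checkpoint or cpu_offload:
    init_device = "cpu"
else:
    init_device = "cuda"
model.to_empty(device=init_device)
```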
@@ -305,10 +307,42 @@ def loss_fn(pred, labels):
loss.backward()

# clip gradients
for m in model_parts:
torch.nn.utils.clip_grad_norm_(
m.parameters(), job_config.training.max_norm, foreach=True
)
# need to resolve the following error before gradient clipping works with cpu offload
# File "/data/users/weif/torchtitan/train.py", line 312, in main
# torch.nn.utils.clip_grad_norm_(
# File "/data/users/weif/pytorch/torch/nn/utils/clip_grad.py", line 30, in _no_grad_wrapper
# return func(*args, **kwargs)
# File "/data/users/weif/pytorch/torch/nn/utils/clip_grad.py", line 105, in clip_grad_norm_
# clip_coef = max_norm / (total_norm + 1e-6)
# File "/data/users/weif/pytorch/torch/_tensor.py", line 39, in wrapped
# return f(*args, **kwargs)
# File "/data/users/weif/pytorch/torch/_tensor.py", line 1064, in __rdiv__
# return self.reciprocal() * other
# File "/data/users/weif/pytorch/torch/_compile.py", line 32, in inner
# return disable_fn(*args, **kwargs)
# File "/data/users/weif/pytorch/torch/_dynamo/eval_frame.py", line 629, in _fn
# return fn(*args, **kwargs)
# File "/data/users/weif/pytorch/torch/distributed/tensor/_api.py", line 340, in __torch_dispatch__
# return DTensor._op_dispatcher.dispatch(
# File "/data/users/weif/pytorch/torch/distributed/tensor/_dispatch.py", line 181, in dispatch
# self.redistribute_local_args(
# File "/data/users/weif/pytorch/torch/distributed/tensor/_dispatch.py", line 317, in redistribute_local_args
# resharded_local_tensor = redistribute_local_tensor(
# File "/data/users/weif/pytorch/torch/distributed/tensor/_redistribute.py", line 208, in redistribute_local_tensor
# new_local_tensor = partial_spec._reduce_value(
# File "/data/users/weif/pytorch/torch/distributed/tensor/_ops/_math_ops.py", line 126, in _reduce_value
# reduced_tensor = super()._reduce_value(tensor, mesh, mesh_dim)
# File "/data/users/weif/pytorch/torch/distributed/tensor/placement_types.py", line 599, in _reduce_value
# return funcol.all_reduce(
# File "/data/users/weif/pytorch/torch/distributed/_functional_collectives.py", line 175, in all_reduce
# tensor = torch.ops._c10d_functional.all_reduce(self, reduceOp.lower(), group_name)
# File "/data/users/weif/pytorch/torch/_ops.py", line 1123, in __call__
# return self._op(*args, **(kwargs or {}))
# RuntimeError: No backend type associated with device type cpu
# for m in model_parts:
# torch.nn.utils.clip_grad_norm_(
# m.parameters(), job_config.training.max_norm, foreach=True
# )

# sync float8 amaxes and scales
float8_handler.sync_float8_amax_and_scale_history(model_parts)
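The traceback above fails at the DTensor all-reduce because the default process group only has a CUDA (NCCL) backend, so there is nothing to dispatch to once the offloaded gradients live on CPU. One possible workaround (an assumption, not part of this PR) is to initialize the process group with both a CUDA and a CPU backend:

```python
# Sketch only: give the default process group an NCCL backend for CUDA tensors
# and a Gloo backend for CPU tensors, so the all-reduce issued by
# clip_grad_norm_ on CPU-offloaded DTensor gradients has a backend to use.
import torch.distributed as dist

dist.init_process_group(backend="cuda:nccl,cpu:gloo")
```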