[DO NOT REVIEW] gaps to enable FSDP2 cpu offloading
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
weifengpy committed Oct 16, 2024
1 parent 3858dc9 commit 5861e28
Showing 3 changed files with 56 additions and 8 deletions.
7 changes: 6 additions & 1 deletion torchtitan/models/llama/model.py
@@ -391,7 +391,12 @@ def init_weights(self):
``init_weights``. We only call it in the constructor of this
``Transformer`` root module to avoid reinitializing tensors.
"""
with torch.device(self.freqs_cis.device):
# freqs_cis is not a module parameter, just a plain Python class attribute,
# so it is not returned by model.parameters() and therefore FSDP2 does not
# manage it (shard, reshard, cpu offload, move to gpu).
# Until there is a long-term solution, hard-code the "cuda" device.
with torch.device("cuda"):
self.freqs_cis = self._precompute_freqs_cis()
if self.tok_embeddings is not None:
nn.init.normal_(self.tok_embeddings.weight)
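
For context (not part of the commit): one possible long-term fix flagged in the comment above is to register freqs_cis as a non-persistent buffer, so that model.to_empty() and FSDP2 handle its device placement instead of a hard-coded "cuda". A minimal sketch with a hypothetical stand-in module:

import torch
import torch.nn as nn

class TinyTransformer(nn.Module):  # hypothetical stand-in for the llama Transformer
    def __init__(self, dim: int = 16):
        super().__init__()
        # as a registered buffer, freqs_cis follows the module across devices;
        # persistent=False keeps it out of the state_dict, like a plain attribute
        self.register_buffer(
            "freqs_cis", self._precompute_freqs_cis(dim), persistent=False
        )

    @staticmethod
    def _precompute_freqs_cis(dim: int) -> torch.Tensor:
        freqs = 1.0 / (10000.0 ** (torch.arange(0, dim, 2).float() / dim))
        return torch.polar(torch.ones_like(freqs), freqs)  # complex64 rotary frequencies

model = TinyTransformer()
model.to_empty(device="cpu")  # buffers are moved too; a plain attribute would be left behind
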
13 changes: 11 additions & 2 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -12,7 +12,11 @@
import torch
import torch.nn as nn
from torch.distributed import DeviceMesh
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed._composable.fsdp import (
CPUOffloadPolicy,
fully_shard,
MixedPrecisionPolicy,
)
from torch.distributed._composable.replicate import replicate
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
@@ -304,7 +308,12 @@ def apply_fsdp(
Apply data parallelism to the model. FSDP2 is used here.
"""
mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
# TODO: add config_manager for offload_policy
fsdp_config = {
"mesh": dp_mesh,
"mp_policy": mp_policy,
"offload_policy": CPUOffloadPolicy(),
}

# TODO: remove this check once PyTorch 2.5 is released. We can safely assume
# that users won't use a nightly build which is older than 20240809 by then.
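
For context (not part of the commit): a minimal sketch of how the fsdp_config above reaches fully_shard, assuming torchtitan's per-block sharding pattern; the function name, dtypes, and model.layers layout are assumptions, not part of this diff.

import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import (
    CPUOffloadPolicy,
    fully_shard,
    MixedPrecisionPolicy,
)
from torch.distributed.device_mesh import DeviceMesh

def apply_fsdp_with_offload(model: nn.Module, dp_mesh: DeviceMesh) -> None:
    mp_policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16, reduce_dtype=torch.float32
    )
    fsdp_config = {
        "mesh": dp_mesh,
        "mp_policy": mp_policy,
        # sharded params, grads, and optimizer state stay on CPU; parameters
        # are copied to GPU just-in-time for each module's forward/backward
        "offload_policy": CPUOffloadPolicy(),
    }
    for block in model.layers.values():  # assumed ModuleDict layout, as in torchtitan
        fully_shard(block, **fsdp_config)
    fully_shard(model, **fsdp_config)
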
44 changes: 39 additions & 5 deletions train.py
@@ -170,7 +170,9 @@ def loss_fn(pred, labels):
models_parallelize_fns[model_name](model, world_mesh, parallel_dims, job_config)

# move sharded model to CPU/GPU and initialize weights via DTensor
init_device = "cpu" if job_config.checkpoint.create_seed_checkpoint else "cuda"
# init_device = "cpu" if job_config.checkpoint.create_seed_checkpoint else "cuda"
# TODO: check job_config.cpu_offload = True/False
init_device = "cpu"
model.to_empty(device=init_device)
model.init_weights()
model.train()
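
For context (not part of the commit): the TODO above could read the offload setting from the job config rather than hard-coding "cpu". A sketch, where enable_cpu_offload is a hypothetical field name not present in torchtitan's config at this commit:

# hypothetical flag; the job config has no cpu offload entry in this commit
cpu_offload = getattr(job_config.training, "enable_cpu_offload", False)
if job_config.checkpoint.create_seed_checkpoint or cpu_offload:
    init_device = "cpu"  # offloaded (and seed-checkpoint) runs initialize on CPU
else:
    init_device = "cuda"
model.to_empty(device=init_device)
model.init_weights()
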
@@ -305,10 +307,42 @@ def loss_fn(pred, labels):
loss.backward()

# clip gradients
for m in model_parts:
torch.nn.utils.clip_grad_norm_(
m.parameters(), job_config.training.max_norm, foreach=True
)
# need to resolve the following error to enable cpu offload
# File "/data/users/weif/torchtitan/train.py", line 312, in main
# torch.nn.utils.clip_grad_norm_(
# File "/data/users/weif/pytorch/torch/nn/utils/clip_grad.py", line 30, in _no_grad_wrapper
# return func(*args, **kwargs)
# File "/data/users/weif/pytorch/torch/nn/utils/clip_grad.py", line 105, in clip_grad_norm_
# clip_coef = max_norm / (total_norm + 1e-6)
# File "/data/users/weif/pytorch/torch/_tensor.py", line 39, in wrapped
# return f(*args, **kwargs)
# File "/data/users/weif/pytorch/torch/_tensor.py", line 1064, in __rdiv__
# return self.reciprocal() * other
# File "/data/users/weif/pytorch/torch/_compile.py", line 32, in inner
# return disable_fn(*args, **kwargs)
# File "/data/users/weif/pytorch/torch/_dynamo/eval_frame.py", line 629, in _fn
# return fn(*args, **kwargs)
# File "/data/users/weif/pytorch/torch/distributed/tensor/_api.py", line 340, in __torch_dispatch__
# return DTensor._op_dispatcher.dispatch(
# File "/data/users/weif/pytorch/torch/distributed/tensor/_dispatch.py", line 181, in dispatch
# self.redistribute_local_args(
# File "/data/users/weif/pytorch/torch/distributed/tensor/_dispatch.py", line 317, in redistribute_local_args
# resharded_local_tensor = redistribute_local_tensor(
# File "/data/users/weif/pytorch/torch/distributed/tensor/_redistribute.py", line 208, in redistribute_local_tensor
# new_local_tensor = partial_spec._reduce_value(
# File "/data/users/weif/pytorch/torch/distributed/tensor/_ops/_math_ops.py", line 126, in _reduce_value
# reduced_tensor = super()._reduce_value(tensor, mesh, mesh_dim)
# File "/data/users/weif/pytorch/torch/distributed/tensor/placement_types.py", line 599, in _reduce_value
# return funcol.all_reduce(
# File "/data/users/weif/pytorch/torch/distributed/_functional_collectives.py", line 175, in all_reduce
# tensor = torch.ops._c10d_functional.all_reduce(self, reduceOp.lower(), group_name)
# File "/data/users/weif/pytorch/torch/_ops.py", line 1123, in __call__
# return self._op(*args, **(kwargs or {}))
# RuntimeError: No backend type associated with device type cpu
# for m in model_parts:
# torch.nn.utils.clip_grad_norm_(
# m.parameters(), job_config.training.max_norm, foreach=True
# )

# sync float8 amaxes and scales
float8_handler.sync_float8_amax_and_scale_history(model_parts)
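
For context (not part of the commit): the RuntimeError in the traceback above ("No backend type associated with device type cpu") arises because the default process group only registers a NCCL (GPU) backend, so the DTensor all_reduce that clip_grad_norm_ issues over CPU-resident gradient norms has no backend to dispatch to. One possible remedy is to initialize the process group with a backend per device type:

import torch.distributed as dist

# register NCCL for CUDA tensors and Gloo for CPU tensors in one group,
# so collectives on CPU-offloaded DTensors can run
dist.init_process_group(backend="cuda:nccl,cpu:gloo")
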
