diff --git a/torchtitan/models/llama/model.py b/torchtitan/models/llama/model.py
index 7f102a80..20056fbd 100644
--- a/torchtitan/models/llama/model.py
+++ b/torchtitan/models/llama/model.py
@@ -391,7 +391,12 @@ def init_weights(self):
         ``init_weights``. We only call it in the constructor of this
         ``Transformer`` root module to avoid reinitializing tensors.
         """
-        with torch.device(self.freqs_cis.device):
+        # freqs_cis is not a module parameter, just a plain Python attribute,
+        # so it is not returned by model.parameters() and FSDP2 does not
+        # manage it (shard, reshard, CPU offload, move to GPU).
+        # TODO: find a long-term solution; hard-code the "cuda" device
+        # for now.
+        with torch.device("cuda"):
             self.freqs_cis = self._precompute_freqs_cis()
         if self.tok_embeddings is not None:
             nn.init.normal_(self.tok_embeddings.weight)
diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index fc26703d..0db99018 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -12,7 +12,11 @@
 import torch
 import torch.nn as nn
 from torch.distributed import DeviceMesh
-from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
+from torch.distributed._composable.fsdp import (
+    CPUOffloadPolicy,
+    fully_shard,
+    MixedPrecisionPolicy,
+)
 from torch.distributed._composable.replicate import replicate
 from torch.distributed._tensor import Replicate, Shard
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
@@ -304,7 +308,12 @@ def apply_fsdp(
     Apply data parallelism to the model. FSDP2 is used here.
     """
     mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
-    fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
+    # TODO: add a config_manager option for offload_policy
+    fsdp_config = {
+        "mesh": dp_mesh,
+        "mp_policy": mp_policy,
+        "offload_policy": CPUOffloadPolicy(),
+    }
 
     # TODO: remove this check once PyTorch 2.5 is released. We can safely assume
     # that users won't use a nightly build which is older than 20240809 by then.
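For the `freqs_cis` hack above, one possible long-term direction (a sketch only, not part of this diff) is to thread the target device through `init_weights` instead of hard-coding `"cuda"`; the `buffer_device` parameter below is a hypothetical addition:

```python
# Sketch under the assumption that the caller (e.g. train.py) knows the
# right device; `buffer_device` is a hypothetical parameter, not in this PR.
def init_weights(self, buffer_device: torch.device | None = None):
    buffer_device = buffer_device or torch.device("cuda")
    with buffer_device:  # torch.device instances work as context managers
        self.freqs_cis = self._precompute_freqs_cis()
    if self.tok_embeddings is not None:
        nn.init.normal_(self.tok_embeddings.weight)
```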
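For the `offload_policy` TODO, a minimal sketch of the config wiring; `job_config.training.enable_cpu_offload` is a hypothetical flag that would still need to be registered in config_manager:

```python
# Sketch: enable CPU offload only when the (hypothetical) flag is set,
# instead of unconditionally constructing CPUOffloadPolicy.
fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
if job_config.training.enable_cpu_offload:  # hypothetical flag
    fsdp_config["offload_policy"] = CPUOffloadPolicy()
```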
diff --git a/train.py b/train.py
index 3e8994a3..ed3598a6 100644
--- a/train.py
+++ b/train.py
@@ -170,7 +170,9 @@ def loss_fn(pred, labels):
     models_parallelize_fns[model_name](model, world_mesh, parallel_dims, job_config)
 
     # move sharded model to CPU/GPU and initialize weights via DTensor
-    init_device = "cpu" if job_config.checkpoint.create_seed_checkpoint else "cuda"
+    # init_device = "cpu" if job_config.checkpoint.create_seed_checkpoint else "cuda"
+    # TODO: gate this on a job_config.cpu_offload = True/False flag
+    init_device = "cpu"
     model.to_empty(device=init_device)
     model.init_weights()
     model.train()
@@ -305,10 +307,42 @@ def loss_fn(pred, labels):
         loss.backward()
 
         # clip gradients
-        for m in model_parts:
-            torch.nn.utils.clip_grad_norm_(
-                m.parameters(), job_config.training.max_norm, foreach=True
-            )
+        # need to resolve the following error before re-enabling this for CPU offload:
+        #   File "/data/users/weif/torchtitan/train.py", line 312, in main
+        #     torch.nn.utils.clip_grad_norm_(
+        #   File "/data/users/weif/pytorch/torch/nn/utils/clip_grad.py", line 30, in _no_grad_wrapper
+        #     return func(*args, **kwargs)
+        #   File "/data/users/weif/pytorch/torch/nn/utils/clip_grad.py", line 105, in clip_grad_norm_
+        #     clip_coef = max_norm / (total_norm + 1e-6)
+        #   File "/data/users/weif/pytorch/torch/_tensor.py", line 39, in wrapped
+        #     return f(*args, **kwargs)
+        #   File "/data/users/weif/pytorch/torch/_tensor.py", line 1064, in __rdiv__
+        #     return self.reciprocal() * other
+        #   File "/data/users/weif/pytorch/torch/_compile.py", line 32, in inner
+        #     return disable_fn(*args, **kwargs)
+        #   File "/data/users/weif/pytorch/torch/_dynamo/eval_frame.py", line 629, in _fn
+        #     return fn(*args, **kwargs)
+        #   File "/data/users/weif/pytorch/torch/distributed/tensor/_api.py", line 340, in __torch_dispatch__
+        #     return DTensor._op_dispatcher.dispatch(
+        #   File "/data/users/weif/pytorch/torch/distributed/tensor/_dispatch.py", line 181, in dispatch
+        #     self.redistribute_local_args(
+        #   File "/data/users/weif/pytorch/torch/distributed/tensor/_dispatch.py", line 317, in redistribute_local_args
+        #     resharded_local_tensor = redistribute_local_tensor(
+        #   File "/data/users/weif/pytorch/torch/distributed/tensor/_redistribute.py", line 208, in redistribute_local_tensor
+        #     new_local_tensor = partial_spec._reduce_value(
+        #   File "/data/users/weif/pytorch/torch/distributed/tensor/_ops/_math_ops.py", line 126, in _reduce_value
+        #     reduced_tensor = super()._reduce_value(tensor, mesh, mesh_dim)
+        #   File "/data/users/weif/pytorch/torch/distributed/tensor/placement_types.py", line 599, in _reduce_value
+        #     return funcol.all_reduce(
+        #   File "/data/users/weif/pytorch/torch/distributed/_functional_collectives.py", line 175, in all_reduce
+        #     tensor = torch.ops._c10d_functional.all_reduce(self, reduceOp.lower(), group_name)
+        #   File "/data/users/weif/pytorch/torch/_ops.py", line 1123, in __call__
+        #     return self._op(*args, **(kwargs or {}))
+        # RuntimeError: No backend type associated with device type cpu
+        # for m in model_parts:
+        #     torch.nn.utils.clip_grad_norm_(
+        #         m.parameters(), job_config.training.max_norm, foreach=True
+        #     )
 
         # sync float8 amaxes and scales
         float8_handler.sync_float8_amax_and_scale_history(model_parts)
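For the `init_device` TODO, the offload case could be folded into the existing seed-checkpoint condition rather than forcing `"cpu"` unconditionally; this sketch assumes the same hypothetical flag as above:

```python
# Sketch: init on CPU for seed checkpoints or when offloading, else on GPU.
init_device = (
    "cpu"
    if job_config.checkpoint.create_seed_checkpoint
    or job_config.training.enable_cpu_offload  # hypothetical flag
    else "cuda"
)
```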
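The `clip_grad_norm_` failure happens because DTensor issues an all-reduce on CPU-resident gradients while only NCCL (a CUDA-only backend) is registered. One possible fix, not verified in this PR, is to initialize the process group with a per-device-type backend string so CPU collectives go through Gloo:

```python
import torch.distributed as dist

# Register Gloo for CPU tensors alongside NCCL for CUDA tensors, so the
# all_reduce issued by clip_grad_norm_ on offloaded gradients has a backend.
dist.init_process_group(backend="cpu:gloo,cuda:nccl")
```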