
Full finetune recipe not working with FSDP2 CPU offload #1412

Open
ebsmothers opened this issue Aug 25, 2024 · 10 comments · Fixed by #1495

Comments

@ebsmothers
Contributor

As in the title. I spent a bit of time debugging it but haven't figured out the cause yet. E.g. running

tune run --nproc_per_node 2 full_finetune_distributed --config llama2/7B_full fsdp_cpu_offload=True

gives the error

RuntimeError: attempting to assign a gradient with device type 'cpu' to a tensor with device type 'cuda'. Please ensure that the gradient and the tensor are on the same device

We aren't passing the CPU offload policy to the top-level sharding of the model here, which also needs to be fixed. But even after this change, the error persists. I'm probably missing something obvious here.
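
For illustration, a rough sketch of passing the offload policy to both the per-layer and the top-level `fully_shard` calls with the FSDP2 composable API (the helper name and the `.layers` attribute are assumptions for this sketch, not the recipe's actual code):

```python
import torch.nn as nn
from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard

def shard_with_cpu_offload(model: nn.Module) -> nn.Module:
    fsdp_kwargs = {"offload_policy": CPUOffloadPolicy()}
    # Shard each transformer layer with the CPU offload policy...
    for layer in model.layers:
        fully_shard(layer, **fsdp_kwargs)
    # ...and pass the same policy to the top-level (root) call; otherwise the
    # root-managed parameters keep CUDA-resident shards.
    fully_shard(model, **fsdp_kwargs)
    return model
```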

Once we fix, we need to add a test for CPU offload as well to catch this for next time.

@ebsmothers
Contributor Author

cc @weifengpy: is there extra initialization logic we need to do for proper CPU offloading support? I know in the LoRA recipe you explicitly call to_empty on CPU for the LoRA params, but I thought that was only necessary if we aren't loading those params from the state dict.

@weifengpy
Contributor

weifengpy commented Aug 26, 2024

is there extra initialization logic we need to do for proper CPU offloading support?

I do not expect extra init logic. If we set fully_shard(cpu_offload=), FSDP2 will move parameters to CPU during state dict loading; users do not need to call to_empty or .cpu() or anything: https://github.com/pytorch/pytorch/blob/ed86ac2f25e7ae5d4e432e1c85837e83f4bd17a6/torch/distributed/_composable/fsdp/_fsdp_param.py#L720

Btw, I tested your command and it works on my machine: `tune run --nproc_per_node 2 full_finetune_distributed --config llama2/7B_full fsdp_cpu_offload=True`. I verified that model.parameters() are on CPU at optimizer.step.

If you want to debug locally, a good starting point is figuring out the FQN of the problematic param. For example, print `[(name, param.device) for name, param in model.named_parameters()]` at forward, backward, and optimizer.step.
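
A tiny helper along those lines (hypothetical, not part of the recipe) could look like this:

```python
import torch.nn as nn

def dump_param_devices(model: nn.Module, stage: str) -> None:
    """Summarize which devices the parameters live on and show a few FQNs per
    device, so a misplaced param is easy to spot."""
    by_device: dict[str, list[str]] = {}
    for name, param in model.named_parameters():
        by_device.setdefault(param.device.type, []).append(name)
    counts = {device: len(names) for device, names in by_device.items()}
    print(f"[{stage}] param device counts: {counts}")
    for device, names in by_device.items():
        print(f"[{stage}] first params on {device}: {names[:3]}")

# Call it at the suspect points in the training loop, e.g.:
#   dump_param_devices(model, "after forward")
#   dump_param_devices(model, "after backward")
#   dump_param_devices(model, "before optimizer.step")
```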

@ebsmothers
Contributor Author

@weifengpy thanks for taking a look. After forward, all params are on CUDA, not CPU. The error then occurs during backward (specifically during FSDP's post_backward). Btw, when you tested, were you on PyTorch nightlies?

@Delaunay

Delaunay commented Aug 30, 2024

I have the same issue. The official released version on PyPI works, but main does not.

The recipe and the config are almost verbatim copies of torchtune's own, with very light modifications for instrumentation, so I doubt the issue is coming from those modifications (and it used to work on the released version).


```
tune run --nnodes=1 --rdzv-backend=c10d --rdzv-endpoint=cn-n001.server.mila.quebec:29400 --nproc-per-node=8 --
$RECIPES/recipes/full_finetune_distributed.py --config $CONFIG/configs/llama3_70B_full.yaml epochs=1
output_dir=$OUTPUT/output tokenizer.path=$DATA/llama3_70B/original/tokenizer.model
checkpointer.checkpoint_dir=$DATA/llama3_70B checkpointer.output_dir=$DATA/llama3_70B/
metric_logger.log_dir=$DATA/metrics 'repo_id="meta-llama/Meta-Llama-3.1-70B"' safetensors=true batch_size=2 gradient_accumulation_steps=1 [at 2024-08-30
14:33:57.737993]
    * 1 x E0830 14:37:56.528000 140676060272448 torch/distributed/elastic/multiprocessing/api.py:833] failed (exitcode: 1) local_rank: 0 (pid: 2740964) of binary: /tmp/workspace/cuda/results/venv/torch/bin/python
        | [rank7]: Traceback (most recent call last):
        | [rank7]:   File "/home/mila/d/delaunap/milabench/benchmarks/llm/recipes/full_finetune_distributed.py", line 781, in <module>
        | [rank7]:     sys.exit(recipe_main())
        | [rank7]:   File "/home/mila/d/delaunap/milabench/benchmarks/llm/tune/torchtune/config/_parse.py", line 99, in wrapper
        | [rank7]:     sys.exit(recipe_main(conf))
        | [rank7]:   File "/home/mila/d/delaunap/milabench/benchmarks/llm/recipes/full_finetune_distributed.py", line 774, in recipe_main
        | [rank7]:     recipe.train()
        | [rank7]:   File "/home/mila/d/delaunap/milabench/benchmarks/llm/recipes/full_finetune_distributed.py", line 650, in train
        | [rank7]:     loss.backward()
        | [rank7]:   File "/tmp/workspace/cuda/results/venv/torch/lib/python3.10/site-packages/torch/_tensor.py", line 521, in backward
        | [rank7]:     torch.autograd.backward(
        | [rank7]:   File "/tmp/workspace/cuda/results/venv/torch/lib/python3.10/site-packages/torch/autograd/__init__.py", line 289, in backward
        | [rank7]:     _engine_run_backward(
        | [rank7]:   File "/tmp/workspace/cuda/results/venv/torch/lib/python3.10/site-packages/torch/autograd/graph.py", line 768, in _engine_run_backward
        | [rank7]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
        | [rank7]:   File "/tmp/workspace/cuda/results/venv/torch/lib/python3.10/site-packages/torch/autograd/function.py", line 306, in apply
        | [rank7]:     return user_fn(self, *args)
        | [rank7]:   File "/tmp/workspace/cuda/results/venv/torch/lib/python3.10/site-packages/torch/distributed/_composable/fsdp/_fsdp_param_group.py", line 543, in backward
        | [rank7]:     ctx.param_group.post_backward()
        | [rank7]:   File "/tmp/workspace/cuda/results/venv/torch/lib/python3.10/site-packages/torch/distributed/_composable/fsdp/_fsdp_param_group.py", line 338, in post_backward
        | [rank7]:     self._post_reduce_event, self._partial_reduce_output = foreach_reduce(
        | [rank7]:   File "/tmp/workspace/cuda/results/venv/torch/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
        | [rank7]:     return func(*args, **kwargs)
        | [rank7]:   File "/tmp/workspace/cuda/results/venv/torch/lib/python3.10/site-packages/torch/distributed/_composable/fsdp/_fsdp_collectives.py", line 245, in foreach_reduce
        | [rank7]:     fsdp_param.sharded_param.grad = new_sharded_dtensor_grad
        | [rank7]: RuntimeError: attempting to assign a gradient with device type 'cpu' to a tensor with device type 'cuda'. Please ensure that the gradient and the tensor are on the same device
  0%|          | 0/3251 [00:55<?, ?it/s]
        | W0830 14:37:46.674000 140676060272448 torch/distributed/elastic/multiprocessing/api.py:858] Sending process 2740966 closing signal SIGTERM
        | W0830 14:37:46.674000 140676060272448 torch/distributed/elastic/multiprocessing/api.py:858] Sending process 2740969 closing signal SIGTERM
        | E0830 14:37:56.528000 140676060272448 torch/distributed/elastic/multiprocessing/api.py:833] failed (exitcode: 1) local_rank: 0 (pid: 2740964) of binary: /tmp/workspace/cuda/results/venv/torch/bin/python
```

@weifengpy
Contributor

Btw when you tested were you on PyTorch nightlies?

I was testing on a one-week-old PyTorch build.

@ebsmothers
Contributor Author

@Delaunay thanks, let me try to figure out what's happening here. @weifengpy can you share a pip list from your run? I will try to run on your env to see if the error goes away

@weifengpy
Contributor

@weifengpy can you share a pip list from your run? I will try to run on your env to see if the error goes away

I just updated my PyTorch to the latest build. Will try again today and share my env if it works.

@weifengpy
Contributor

I can repro on the latest PyTorch. Will debug it today.

@weifengpy
Contributor

Landing a torchtune-side fix that moves the state dict to CPU (and saves memory): #1495

Optionally, I need to work with Andrew on a PyTorch-side fix so we can use the state dict on GPU: pytorch/pytorch#135179
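
For reference, a minimal sketch of the torchtune-side idea from the first point (stage the checkpoint tensors on CPU before loading; the helper name is assumed here, see #1495 for the actual change):

```python
import torch

def state_dict_to_cpu(state_dict: dict) -> dict:
    """Move every tensor in a checkpoint state dict to CPU before loading it
    into the FSDP2-sharded model, so full tensors are never staged on GPU
    (which also lowers peak GPU memory)."""
    return {
        key: value.cpu() if isinstance(value, torch.Tensor) else value
        for key, value in state_dict.items()
    }
```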

@weifengpy
Contributor

Reopening since this was automatically closed by the torchtune PR.

weifengpy added a commit to pytorch/pytorch that referenced this issue Sep 5, 2024
…ading"



`pytest -s distributed/_composable/fsdp/test_fully_shard_training.py -k test_to_float64_after_init`

resolve cpu offload error in TorchTune: pytorch/torchtune#1412

cc XilunWu H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o

[ghstack-poisoned]
weifengpy added a commit to pytorch/pytorch that referenced this issue Sep 10, 2024
…cal tensors"


resolve cpu offload error in TorchTune: pytorch/torchtune#1412

this PR constructs DTensor from cpu offloaded local tensor

`pytest -s test/distributed/_composable/fsdp/test_fully_shard_state_dict.py -k test_dp_state_dict_cpu_offload`


cc XilunWu H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o

[ghstack-poisoned]