Full finetune recipe not working with FSDP2 CPU offload #1412
Comments
cc @weifengpy is there extra initialization logic we need to do for proper CPU offloading support? I know in the LoRA recipe you explicitly call
I do not expect extra init logic if we set

btw: I tested your command and it works on my machine. If you want to debug locally, a good starting point is figuring out the fqn for the problematic param. For example, print
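For reference, a minimal sketch of that kind of device check (this assumes `model` is the FSDP2-wrapped module in the recipe; it is not the exact snippet referenced above):

```python
# Hedged debugging sketch: print each parameter's fqn, type, and device
# after forward (or right before backward) to spot which param is not
# where FSDP2's CPU offload expects it to be.
for fqn, param in model.named_parameters():
    print(fqn, type(param).__name__, param.device)
```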
@weifengpy thanks for taking a look. After forward, all params are on CUDA, not CPU. The error then occurs during backward, though (specifically during FSDP's post_backward). Btw, when you tested, were you on PyTorch nightlies?
I have the same issue. The official released version on PyPI works, but main does not. The recipe and the config are almost verbatim copies of torchtune's own recipe, with very light modifications for instrumentation, so I doubt the issue comes from those modifications (and it used to work on the released version).
I was testing on a 1-week-old PyTorch build.
@Delaunay thanks, let me try to figure out what's happening here. @weifengpy can you share a `pip list` from your run? I will try to run on your env to see if the error goes away.
I just updated my PyTorch to the latest. Will try again today and share my env if it works.
I can repro on latest PyTorch. Will debug it today.
Landing a torchtune-side fix by moving the state dict to CPU (which also saves memory): #1495. Optionally, I need to work with Andrew on a PyTorch-side fix so we can use the state dict on GPU: pytorch/pytorch#135179
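To illustrate the state-dict-to-CPU direction, here is a rough sketch of the general idea only (an assumption, not the actual #1495 change; the import path may differ across PyTorch versions):

```python
# Rough sketch (assumption, not the #1495 diff): gather a full state dict
# and materialize it on CPU so checkpoint saving does not hold extra
# copies of the weights in GPU memory.
import torch
from torch.distributed._tensor import DTensor

def full_state_dict_on_cpu(model: torch.nn.Module) -> dict:
    cpu_sd = {}
    for fqn, value in model.state_dict().items():
        if isinstance(value, DTensor):
            # All-gather the shards into a plain tensor before moving it off GPU.
            value = value.full_tensor()
        cpu_sd[fqn] = value.cpu()
    return cpu_sd
```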
Reopening since it was automatically closed by the torchtune PR.
…ading" `pytest -s distributed/_composable/fsdp/test_fully_shard_training.py -k test_to_float64_after_init` resolve cpu offload error in TorchTune: pytorch/torchtune#1412 cc XilunWu H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o [ghstack-poisoned]
…ading" `pytest -s distributed/_composable/fsdp/test_fully_shard_training.py -k test_to_float64_after_init` resolve cpu offload error in TorchTune: pytorch/torchtune#1412 cc XilunWu H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o [ghstack-poisoned]
…cal tensors" `pytest -s distributed/_composable/fsdp/test_fully_shard_training.py -k test_to_float64_after_init` resolve cpu offload error in TorchTune: pytorch/torchtune#1412 cc XilunWu H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o [ghstack-poisoned]
…cal tensors" resolve cpu offload error in TorchTune: pytorch/torchtune#1412 this PR constructs DTensor from cpu offloaded local tensor `pytest -s test/distributed/_composable/fsdp/test_fully_shard_state_dict.py -k test_dp_state_dict_cpu_offload` cc XilunWu H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o [ghstack-poisoned]
…cal tensors" resolve cpu offload error in TorchTune: pytorch/torchtune#1412 this PR constructs DTensor from cpu offloaded local tensor `pytest -s test/distributed/_composable/fsdp/test_fully_shard_state_dict.py -k test_dp_state_dict_cpu_offload` cc XilunWu H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o [ghstack-poisoned]
…cal tensors" resolve cpu offload error in TorchTune: pytorch/torchtune#1412 this PR constructs DTensor from cpu offloaded local tensor `pytest -s test/distributed/_composable/fsdp/test_fully_shard_state_dict.py -k test_dp_state_dict_cpu_offload` cc XilunWu H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o [ghstack-poisoned]
…cal tensors" resolve cpu offload error in TorchTune: pytorch/torchtune#1412 this PR constructs DTensor from cpu offloaded local tensor `pytest -s test/distributed/_composable/fsdp/test_fully_shard_state_dict.py -k test_dp_state_dict_cpu_offload` cc XilunWu H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o [ghstack-poisoned]
…cal tensors" resolve cpu offload error in TorchTune: pytorch/torchtune#1412 this PR constructs DTensor from cpu offloaded local tensor `pytest -s test/distributed/_composable/fsdp/test_fully_shard_state_dict.py -k test_dp_state_dict_cpu_offload` cc XilunWu H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o [ghstack-poisoned]
As in the title.. I spent a bit of time debugging it but haven't figured out the cause yet. E.g. running
gives the error
We aren't passing the CPU offload policy to the top-level sharding of the model here, which also needs to be fixed. But even after this change, the error persists. Probably missing something obvious here..
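For what it's worth, a hedged sketch of what passing the offload policy to the top-level call could look like (the attribute and variable names here, e.g. `model.layers` and `mesh`, are illustrative assumptions about the recipe, and the import path may vary by PyTorch version):

```python
# Hedged sketch: use the same CPUOffloadPolicy for both the per-layer and
# the top-level fully_shard calls so the root module's params and grads
# are also kept on CPU between uses.
from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard

offload_policy = CPUOffloadPolicy()  # pin_memory=True by default
for layer in model.layers:           # "layers" is an assumed attribute name
    fully_shard(layer, mesh=mesh, offload_policy=offload_policy)
fully_shard(model, mesh=mesh, offload_policy=offload_policy)
```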
Once we fix this, we need to add a test for CPU offload as well to catch it next time.