
Adapt a few torch.utils.checkpoint functions for PyTorch/XLA. #6178

Merged
merged 4 commits into from
Jan 10, 2024

Conversation

ysiraichi
Copy link
Collaborator

Fix: #6086

This PR re-implements the get_device_states and set_device_states functions, used in CheckpointFunction, so that they work with PyTorch/XLA. Previously, they weren't a problem, since PyTorch was usually compiled without CUDA support.
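For context, these two functions implement the RNG save/restore pattern that activation checkpointing relies on: the forward pass records the RNG state of every device touched by the inputs, and the recompute pass restores those states so that random ops (e.g. dropout) replay identically. The sketch below illustrates that pattern using only the stdlib random module as a stand-in for per-device CUDA/XLA RNG state; the function names mirror torch.utils.checkpoint, but this is a simplified illustration, not the actual implementation.

```python
import random

def get_device_states():
    # Stand-in for torch.utils.checkpoint.get_device_states: snapshot the
    # RNG state so the recompute pass can replay the same random numbers.
    # In the real code this collects one state per CUDA (or XLA) device
    # touched by the checkpointed inputs.
    return random.getstate()

def set_device_states(state):
    # Stand-in for set_device_states: restore the snapshot before
    # re-running the forward pass inside backward.
    random.setstate(state)

# First (real) forward pass: save state, then run random ops.
state = get_device_states()
first = [random.random() for _ in range(3)]

# Recompute pass: restore state, re-run, and get identical values.
set_device_states(state)
replayed = [random.random() for _ in range(3)]
assert first == replayed
```

The PyTorch/XLA adaptation follows the same shape, but has to query XLA device RNG state instead of calling into torch.cuda, which fails (or silently does the wrong thing) when PyTorch is compiled with CUDA support but the tensors live on XLA devices.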

cc @JackCaoG @miladm

@JackCaoG
Copy link
Collaborator

lol, I copy-pasted this file mostly from upstream, hoping we can merge it back one day. Do you know the correct way of extending the checkpoint module?

I copied this file because we need optimization_barrier_

@ysiraichi
Copy link
Collaborator Author

I have no idea. I just fixed the parts where device_module was needed, so that XLA would run successfully.

@JackCaoG
Copy link
Collaborator

LGTM, though I am not sure how it is different from the upstream one. Can you leave a comment for these two functions?

@ysiraichi ysiraichi force-pushed the ysiraichi/fix-checkpoint branch from 85ceb01 to 3d0d771 Compare January 8, 2024 19:36
@ysiraichi ysiraichi merged commit ebb200b into master Jan 10, 2024
20 checks passed
Successfully merging this pull request may close these issues.

test_grad_checkpoint.py fails if PyTorch is compiled with CUDA support.