Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update amp custom_fwd, custom_bwd usage for torch 2.4.0 compatibility #54

Merged

Conversation

mirceamironenco
Copy link
Contributor

In torch 2.4.0 the autocast APIs were unified under torch.amp. Usage of torch.{device}.amp will produce a deprecation warning: FutureWarning: torch.cuda.amp.custom_fwd(args...) is deprecated. Please use torch.amp.custom_fwd(args..., device_type='cuda') instead.

This PR gates custom_fwd/custom_bwd by version, so that < 2.4 torch is still usable. Since these ops are used for a lot of primitives I've moved them to fla.utils and renamed them for clarity. This should still allow future support for device types other than cuda.

@yzhangcs yzhangcs merged commit fd29964 into fla-org:main Aug 25, 2024
1 check failed
@yzhangcs
Copy link
Member

Thank you for the contributions.

@mirceamironenco mirceamironenco deleted the fix-torch2.4.0-amp-futurewarning branch September 20, 2024 07:27
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants