
update to gemnet-oc hydra force head to work with amp #825

Merged · 4 commits merged into main from gemnet_force_hydra on Sep 5, 2024

Conversation

@wood-b (Collaborator) commented Sep 2, 2024

Previously, if you used --amp with gemnet-oc hydra, it would fail because of a mismatch in tensor dtypes (float32 vs. float16) in the force head at this line.

In the original implementation of gemnet-oc, the x_F embeddings are float16 when amp is used, and they get cast to float32 at this line. In the hydra implementation, the heads are wrapped in amp=False when they are initialized here, so in this line: x_F = self.out_mlp_F(torch.cat(emb["xs_F"], dim=-1)) self.out_mlp_F expects float32 while the xs_F embeddings are float16. To fix this, I added a use_amp property and set the default to True (to be consistent with the original implementation); with that change, self.out_mlp_F expects float16 when amp is used and float32 when it is not.

The problem with this change is that it isn't very transparent what happens under amp (out_mlp_F runs in float16 and is later cast to float32). A possibly better fix would be to keep out_mlp_F in float32 always, i.e., disable amp completely in the prediction head. If we go that route, it would be good to test how it impacts speed, memory, and performance relative to the original implementation. The impact may be negligible, in which case fully disabling amp is the cleaner option, imo.
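For context, here is a minimal, self-contained sketch of the mismatch in plain PyTorch (not the fairchem code; the backbone/head modules are stand-ins, and it uses CPU autocast with bfloat16 to mirror the float16 behavior on GPU):

```python
import torch
import torch.nn as nn

backbone = nn.Linear(8, 8)  # stands in for the gemnet-oc backbone
head = nn.Linear(8, 1)      # stands in for out_mlp_F in the force head

x = torch.randn(4, 8)
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    emb = backbone(x)  # emb is low-precision inside the autocast region

# Head forward with autocast disabled, as the hydra head wrapper does:
with torch.autocast(device_type="cpu", enabled=False):
    try:
        head(emb)  # float32 weights vs. low-precision input -> RuntimeError
    except RuntimeError as e:
        print(e)
    out = head(emb.float())  # casting the embeddings back to float32 fixes it
```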

@wood-b wood-b requested a review from misko September 2, 2024 20:57
@wood-b wood-b added bug Something isn't working patch Patch version release labels Sep 2, 2024

codecov bot commented Sep 2, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

| Files with missing lines | Coverage Δ |
| --- | --- |
| src/fairchem/core/models/gemnet_oc/gemnet_oc.py | 89.91% <100.00%> (ø) |

@misko (Collaborator) previously approved these changes Sep 3, 2024, leaving a comment:

Looks good to me! I added a task to generalize this to other models as well. I can't think of a better way to do this at the moment, might want to revisit in BE week.

@wood-b wood-b requested a review from rayg1234 September 4, 2024 16:40
@misko (Collaborator) left a comment:

LGTM!

@wood-b (Collaborator, Author) commented Sep 4, 2024

Had a conversation offline with @misko. We decided it was cleaner to just disable amp in the whole head, to be consistent with the other hydra heads; the behavior of use_amp was also not very clear.
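A minimal sketch of that pattern (illustrative names, not the actual fairchem classes): the head disables autocast for its entire forward pass and casts incoming embeddings to float32, so out_mlp_F always runs in full precision regardless of whether --amp is set.

```python
import torch
import torch.nn as nn

class ForceHeadSketch(nn.Module):
    """Illustrative force head that always runs in float32, even under amp."""

    def __init__(self, emb_dim: int, num_blocks: int):
        super().__init__()
        self.out_mlp_F = nn.Linear(emb_dim * num_blocks, 1)

    def forward(self, emb: dict) -> torch.Tensor:
        device_type = emb["xs_F"][0].device.type
        # Disable autocast for the whole head: cast the (possibly float16)
        # backbone embeddings to float32 and keep everything downstream fp32.
        with torch.autocast(device_type=device_type, enabled=False):
            x_F = torch.cat([x.float() for x in emb["xs_F"]], dim=-1)
            return self.out_mlp_F(x_F)
```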

@misko misko self-requested a review September 4, 2024 23:40
@misko misko enabled auto-merge September 5, 2024 00:11
@misko misko added this pull request to the merge queue Sep 5, 2024
Merged via the queue into main with commit c8e87c1 Sep 5, 2024
8 checks passed
@misko misko deleted the gemnet_force_hydra branch September 5, 2024 01:41
misko added a commit that referenced this pull request Jan 17, 2025
* updated gemnet hydra force head to work with amp

* remove use_amp

---------

Co-authored-by: Misko <misko@meta.com>
Former-commit-id: d972aed74bd68ddbba1ff6f2fa8847f5526e2411
beomseok-kang pushed a commit to beomseok-kang/fairchem that referenced this pull request Jan 27, 2025
* updated gemnet hydra force head to work with amp

* remove use_amp

---------

Co-authored-by: Misko <misko@meta.com>
Former-commit-id: 041e96fed9db099db18fdd0bc61f54e91c6a5050