update to gemnet-oc hydra force head to work with amp #825
Conversation
Codecov Report: All modified and coverable lines are covered by tests ✅
Looks good to me! I added a task to generalize this to other models as well. I can't think of a better way to do this at the moment, might want to revisit in BE week.
LGTM!
Had a conversation offline with @misko. We decided it was cleaner to just disable amp in the whole head, to be consistent with the other hydra heads, and also the behavior of
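For context, a minimal sketch (not the actual fairchem code; the `ForceHead` module, its dimensions, and names are hypothetical) of what "disabling amp in the whole head" can look like in PyTorch: the head wraps its forward in `torch.autocast(..., enabled=False)` and casts its inputs up to float32.

```python
import torch
import torch.nn as nn


class ForceHead(nn.Module):
    """Hypothetical stand-in for a hydra prediction head that opts out of amp entirely."""

    def __init__(self, emb_dim: int = 128):
        super().__init__()
        self.out_mlp_F = nn.Linear(emb_dim, 1)  # float32 parameters

    def forward(self, x_F: torch.Tensor) -> torch.Tensor:
        # Disable autocast for the whole head: cast the (possibly float16/bfloat16)
        # embeddings up to float32 and run out_mlp_F in full precision.
        with torch.autocast(device_type=x_F.device.type, enabled=False):
            return self.out_mlp_F(x_F.float())


# The backbone may run under amp, but the head's output is always float32.
head = ForceHead()
x = torch.randn(4, 128)
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out = head(x)
assert out.dtype == torch.float32
```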
* updated gemnet hydra force head to work with amp
* remove use_amp
---------
Co-authored-by: Misko <misko@meta.com>
Former-commit-id: d972aed74bd68ddbba1ff6f2fa8847f5526e2411
Former-commit-id: 041e96fed9db099db18fdd0bc61f54e91c6a5050
Previously, if you used --amp with gemnet-oc hydra it would fail because of a mismatch in tensor dtypes in the force head (float32 vs float16) at this line. In the original implementation of gemnet-oc the x_F embeddings are float16 when amp is used, and are then cast to float32 at this line. What was happening in the hydra implementation is that when the heads get initialized they are wrapped in amp=False here, so in this line:

x_F = self.out_mlp_F(torch.cat(emb["xs_F"], dim=-1))

self.out_mlp_F expects float32 whereas the xs_F embeddings are float16. To fix this I added the use_amp property and set its default to true (to be consistent with the original implementation); with that change self.out_mlp_F now expects float16 when amp is used and float32 when it is not.

The problem with this change is that it isn't very transparent what happens under amp (out_mlp_F is float16 and is only later cast to float32). A possibly better fix would be for out_mlp_F to always be float32, i.e. amp is completely disabled in the prediction head. If we go this route it would be good to test how it impacts speed, memory, and performance relative to the original implementation. The difference may be negligible, in which case fully disabling amp is the cleaner option imo.
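As a minimal sketch of the mismatch and of the two options described above (a plain `nn.Linear` stands in for out_mlp_F; none of this is the actual fairchem code):

```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
amp_dtype = torch.float16 if device == "cuda" else torch.bfloat16

mlp = nn.Linear(64, 1).to(device)  # stand-in for out_mlp_F: float32 weights
# Under amp the backbone produces reduced-precision embeddings, as described above.
emb = {"xs_F": [torch.randn(8, 64, device=device, dtype=amp_dtype)]}

# Failure mode: the head is wrapped in autocast(enabled=False), so the float32
# linear layer receives float16/bfloat16 inputs and the matmul raises a dtype error.
try:
    with torch.autocast(device_type=device, enabled=False):
        mlp(torch.cat(emb["xs_F"], dim=-1))
except RuntimeError as err:
    print(f"dtype mismatch: {err}")

# use_amp=True behaviour: keep the head under autocast so the matmul runs in the
# reduced-precision dtype, then cast the result to float32 afterwards.
with torch.autocast(device_type=device, dtype=amp_dtype):
    x_F = mlp(torch.cat(emb["xs_F"], dim=-1))
x_F = x_F.float()

# Alternative: disable amp for the whole head and cast the embeddings up, so
# out_mlp_F always computes in float32 (see the earlier ForceHead sketch).
with torch.autocast(device_type=device, enabled=False):
    x_F_fp32 = mlp(torch.cat(emb["xs_F"], dim=-1).float())

print(x_F.dtype, x_F_fp32.dtype)  # both torch.float32
```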