diff --git a/src/fairchem/core/models/gemnet_oc/gemnet_oc.py b/src/fairchem/core/models/gemnet_oc/gemnet_oc.py index c5e6efb005..c982b7d43a 100644 --- a/src/fairchem/core/models/gemnet_oc/gemnet_oc.py +++ b/src/fairchem/core/models/gemnet_oc/gemnet_oc.py @@ -1489,7 +1489,10 @@ def forward( @registry.register_model("gemnet_oc_force_head") class GemNetOCForceHead(nn.Module, HeadInterface): def __init__( - self, backbone, num_global_out_layers: int, output_init: str = "HeOrthogonal" + self, + backbone, + num_global_out_layers: int, + output_init: str = "HeOrthogonal", ): super().__init__() @@ -1527,9 +1530,9 @@ def forward( self, data: Batch, emb: dict[str, torch.Tensor] ) -> dict[str, torch.Tensor]: if self.direct_forces: - x_F = self.out_mlp_F(torch.cat(emb["xs_F"], dim=-1)) with torch.cuda.amp.autocast(False): - F_st = self.out_forces(x_F.float()) + x_F = self.out_mlp_F(torch.cat(emb["xs_F"], dim=-1).float()) + F_st = self.out_forces(x_F) if self.forces_coupled: # enforce F_st = F_ts nEdges = emb["edge_idx"].shape[0]