diff --git a/torch_geometric/nn/models/dimenet.py b/torch_geometric/nn/models/dimenet.py index 5be16d3f55b0..dd594050dd5d 100644 --- a/torch_geometric/nn/models/dimenet.py +++ b/torch_geometric/nn/models/dimenet.py @@ -45,7 +45,8 @@ def forward(self, x): x_pow_p0 = x.pow(p - 1) x_pow_p1 = x_pow_p0 * x x_pow_p2 = x_pow_p1 * x - return 1. / x + a * x_pow_p0 + b * x_pow_p1 + c * x_pow_p2 + return (1. / x + a * x_pow_p0 + b * x_pow_p1 + + c * x_pow_p2) * (x < 1.0).to(x.dtype) class BesselBasisLayer(torch.nn.Module): @@ -64,7 +65,7 @@ def reset_parameters(self): self.freq.requires_grad_() def forward(self, dist): - dist = (dist.unsqueeze(-1) / self.cutoff).clamp(max=1.0) + dist = (dist.unsqueeze(-1) / self.cutoff) return self.envelope(dist) * (self.freq * dist).sin()