From 85fc32ed5a5a4f23551f09fbed1d25645ccf073f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Z=C3=BCgner?= Date: Fri, 29 Apr 2022 11:13:55 +0200 Subject: [PATCH] Improve distance cutoff in `DimeNet` (#4562) * Integrate clamping into Envelope This is cleaner than (dist / cutoff).clamp_max(1.0) because it's the envelope that should become zero and then effectively mask distances larger than the cutoff. This is also how it's implemented in [`DimeNet`](https://github.com/gasteigerjo/dimenet/blob/09123a0e16e728d0a0e53e6686b04f859802aa81/dimenet/model/layers/envelope.py#L23). * Remove extra parentheses * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update torch_geometric/nn/models/dimenet.py Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Matthias Fey --- torch_geometric/nn/models/dimenet.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torch_geometric/nn/models/dimenet.py b/torch_geometric/nn/models/dimenet.py index 5be16d3f55b0..dd594050dd5d 100644 --- a/torch_geometric/nn/models/dimenet.py +++ b/torch_geometric/nn/models/dimenet.py @@ -45,7 +45,8 @@ def forward(self, x): x_pow_p0 = x.pow(p - 1) x_pow_p1 = x_pow_p0 * x x_pow_p2 = x_pow_p1 * x - return 1. / x + a * x_pow_p0 + b * x_pow_p1 + c * x_pow_p2 + return (1. / x + a * x_pow_p0 + b * x_pow_p1 + + c * x_pow_p2) * (x < 1.0).to(x.dtype) class BesselBasisLayer(torch.nn.Module): @@ -64,7 +65,7 @@ def reset_parameters(self): self.freq.requires_grad_() def forward(self, dist): - dist = (dist.unsqueeze(-1) / self.cutoff).clamp(max=1.0) + dist = (dist.unsqueeze(-1) / self.cutoff) return self.envelope(dist) * (self.freq * dist).sin()