Skip to content

Commit

Permalink
Improve distance cutoff in DimeNet (#4562)
Browse files Browse the repository at this point in the history
* Integrate clamping into Envelope

This is cleaner than (dist / cutoff).clamp_max(1.0) because it's the envelope that should become zero and then effectively mask distances larger than the cutoff. This is also how it's implemented in [`DimeNet`](https://github.com/gasteigerjo/dimenet/blob/09123a0e16e728d0a0e53e6686b04f859802aa81/dimenet/model/layers/envelope.py#L23).

* Remove extra parentheses

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update torch_geometric/nn/models/dimenet.py

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
  • Loading branch information
3 people authored Apr 29, 2022
1 parent e220a2c commit 85fc32e
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions torch_geometric/nn/models/dimenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ def forward(self, x):
x_pow_p0 = x.pow(p - 1)
x_pow_p1 = x_pow_p0 * x
x_pow_p2 = x_pow_p1 * x
return 1. / x + a * x_pow_p0 + b * x_pow_p1 + c * x_pow_p2
return (1. / x + a * x_pow_p0 + b * x_pow_p1 +
c * x_pow_p2) * (x < 1.0).to(x.dtype)


class BesselBasisLayer(torch.nn.Module):
Expand All @@ -64,7 +65,7 @@ def reset_parameters(self):
self.freq.requires_grad_()

def forward(self, dist):
dist = (dist.unsqueeze(-1) / self.cutoff).clamp(max=1.0)
dist = (dist.unsqueeze(-1) / self.cutoff)
return self.envelope(dist) * (self.freq * dist).sin()


Expand Down

0 comments on commit 85fc32e

Please sign in to comment.