Skip to content

Commit

Permalink
[Code Coverage] models/dimenet.py (#6781)
Browse files Browse the repository at this point in the history
Part of #6528, improves typing and code coverage for DimeNet.

---------

Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
  • Loading branch information
SauravMaheshkar and rusty1s authored Feb 24, 2023
1 parent e564ba5 commit d336d13
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 51 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Properly reset the `data_list` cache of an `InMemoryDataset` when accessing `dataset.data` ([#6685](https://github.com/pyg-team/pytorch_geometric/pull/6685))
- Fixed a bug in `Data.subgraph()` and `HeteroData.subgraph()` ([#6613](https://github.com/pyg-team/pytorch_geometric/pull/6613))
- Fixed a bug in `PNAConv` and `DegreeScalerAggregation` to correctly incorporate degree statistics of isolated nodes ([#6609](https://github.com/pyg-team/pytorch_geometric/pull/6609))
- Improved code coverage ([#6523](https://github.com/pyg-team/pytorch_geometric/pull/6523), [#6538](https://github.com/pyg-team/pytorch_geometric/pull/6538), [#6555](https://github.com/pyg-team/pytorch_geometric/pull/6555), [#6558](https://github.com/pyg-team/pytorch_geometric/pull/6558), [#6568](https://github.com/pyg-team/pytorch_geometric/pull/6568), [#6573](https://github.com/pyg-team/pytorch_geometric/pull/6573), [#6578](https://github.com/pyg-team/pytorch_geometric/pull/6578), [#6597](https://github.com/pyg-team/pytorch_geometric/pull/6597), [#6600](https://github.com/pyg-team/pytorch_geometric/pull/6600), [#6618](https://github.com/pyg-team/pytorch_geometric/pull/6618), [#6619](https://github.com/pyg-team/pytorch_geometric/pull/6619), [#6621](https://github.com/pyg-team/pytorch_geometric/pull/6621), [#6623](https://github.com/pyg-team/pytorch_geometric/pull/6623), [#6637](https://github.com/pyg-team/pytorch_geometric/pull/6637), [#6638](https://github.com/pyg-team/pytorch_geometric/pull/6638), [#6640](https://github.com/pyg-team/pytorch_geometric/pull/6640), [#6645](https://github.com/pyg-team/pytorch_geometric/pull/6645), [#6648](https://github.com/pyg-team/pytorch_geometric/pull/6648), [#6647](https://github.com/pyg-team/pytorch_geometric/pull/6647), [#6653](https://github.com/pyg-team/pytorch_geometric/pull/6653), [#6657](https://github.com/pyg-team/pytorch_geometric/pull/6657), [#6662](https://github.com/pyg-team/pytorch_geometric/pull/6662), [#6664](https://github.com/pyg-team/pytorch_geometric/pull/6664), [#6667](https://github.com/pyg-team/pytorch_geometric/pull/6667), [#6668](https://github.com/pyg-team/pytorch_geometric/pull/6668), [#6669](https://github.com/pyg-team/pytorch_geometric/pull/6669), [#6670](https://github.com/pyg-team/pytorch_geometric/pull/6670), [#6671](https://github.com/pyg-team/pytorch_geometric/pull/6671), [#6673](https://github.com/pyg-team/pytorch_geometric/pull/6673), [#6675](https://github.com/pyg-team/pytorch_geometric/pull/6675), [#6676](https://github.com/pyg-team/pytorch_geometric/pull/6676), [#6677](https://github.com/pyg-team/pytorch_geometric/pull/6677), [#6678](https://github.com/pyg-team/pytorch_geometric/pull/6678), [#6681](https://github.com/pyg-team/pytorch_geometric/pull/6681), [#6683](https://github.com/pyg-team/pytorch_geometric/pull/6683), [#6703](https://github.com/pyg-team/pytorch_geometric/pull/6703), [#6720](https://github.com/pyg-team/pytorch_geometric/pull/6720), [#6735](https://github.com/pyg-team/pytorch_geometric/pull/6735), [#6736](https://github.com/pyg-team/pytorch_geometric/pull/6736), [#6763](https://github.com/pyg-team/pytorch_geometric/pull/6763))
- Improved code coverage ([#6523](https://github.com/pyg-team/pytorch_geometric/pull/6523), [#6538](https://github.com/pyg-team/pytorch_geometric/pull/6538), [#6555](https://github.com/pyg-team/pytorch_geometric/pull/6555), [#6558](https://github.com/pyg-team/pytorch_geometric/pull/6558), [#6568](https://github.com/pyg-team/pytorch_geometric/pull/6568), [#6573](https://github.com/pyg-team/pytorch_geometric/pull/6573), [#6578](https://github.com/pyg-team/pytorch_geometric/pull/6578), [#6597](https://github.com/pyg-team/pytorch_geometric/pull/6597), [#6600](https://github.com/pyg-team/pytorch_geometric/pull/6600), [#6618](https://github.com/pyg-team/pytorch_geometric/pull/6618), [#6619](https://github.com/pyg-team/pytorch_geometric/pull/6619), [#6621](https://github.com/pyg-team/pytorch_geometric/pull/6621), [#6623](https://github.com/pyg-team/pytorch_geometric/pull/6623), [#6637](https://github.com/pyg-team/pytorch_geometric/pull/6637), [#6638](https://github.com/pyg-team/pytorch_geometric/pull/6638), [#6640](https://github.com/pyg-team/pytorch_geometric/pull/6640), [#6645](https://github.com/pyg-team/pytorch_geometric/pull/6645), [#6648](https://github.com/pyg-team/pytorch_geometric/pull/6648), [#6647](https://github.com/pyg-team/pytorch_geometric/pull/6647), [#6653](https://github.com/pyg-team/pytorch_geometric/pull/6653), [#6657](https://github.com/pyg-team/pytorch_geometric/pull/6657), [#6662](https://github.com/pyg-team/pytorch_geometric/pull/6662), [#6664](https://github.com/pyg-team/pytorch_geometric/pull/6664), [#6667](https://github.com/pyg-team/pytorch_geometric/pull/6667), [#6668](https://github.com/pyg-team/pytorch_geometric/pull/6668), [#6669](https://github.com/pyg-team/pytorch_geometric/pull/6669), [#6670](https://github.com/pyg-team/pytorch_geometric/pull/6670), [#6671](https://github.com/pyg-team/pytorch_geometric/pull/6671), [#6673](https://github.com/pyg-team/pytorch_geometric/pull/6673), [#6675](https://github.com/pyg-team/pytorch_geometric/pull/6675), [#6676](https://github.com/pyg-team/pytorch_geometric/pull/6676), [#6677](https://github.com/pyg-team/pytorch_geometric/pull/6677), [#6678](https://github.com/pyg-team/pytorch_geometric/pull/6678), [#6681](https://github.com/pyg-team/pytorch_geometric/pull/6681), [#6683](https://github.com/pyg-team/pytorch_geometric/pull/6683), [#6703](https://github.com/pyg-team/pytorch_geometric/pull/6703), [#6720](https://github.com/pyg-team/pytorch_geometric/pull/6720), [#6735](https://github.com/pyg-team/pytorch_geometric/pull/6735), [#6736](https://github.com/pyg-team/pytorch_geometric/pull/6736), [#6763](https://github.com/pyg-team/pytorch_geometric/pull/6763), [#6781](https://github.com/pyg-team/pytorch_geometric/pull/6781))
- Fixed a bug in which `data.to_heterogeneous()` filtered attributs in the wrong dimension ([#6522](https://github.com/pyg-team/pytorch_geometric/pull/6522))
- Breaking Change: Temporal sampling will now also sample nodes with an equal timestamp to the seed time (requires `pyg-lib>0.1.0`) ([#6517](https://github.com/pyg-team/pytorch_geometric/pull/6517))
- Changed `DataLoader` workers with affinity to start at `cpu0` ([#6512](https://github.com/pyg-team/pytorch_geometric/pull/6512))
Expand Down
62 changes: 42 additions & 20 deletions test/nn/models/test_dimenet.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,47 @@
import pytest
import torch
import torch.nn.functional as F

from torch_geometric.nn import DimeNetPlusPlus
from torch_geometric.testing import onlyFullTest
from torch_geometric.nn import DimeNet, DimeNetPlusPlus
from torch_geometric.nn.models.dimenet import (
BesselBasisLayer,
Envelope,
ResidualLayer,
)
from torch_geometric.testing import is_full_test


@onlyFullTest
def test_dimenet_plus_plus():
def test_dimenet_modules():
env = Envelope(exponent=5)
x = torch.randn(10, 3)
assert env(x).size() == (10, 3) # Isotonic layer.

bbl = BesselBasisLayer(5)
x = torch.randn(10, 3)
assert bbl(x).size() == (10, 3, 5) # Non-isotonic layer.

rl = ResidualLayer(128, torch.nn.functional.relu)
x = torch.randn(128, 128)
assert rl(x).size() == (128, 128) # Isotonic layer.


@pytest.mark.parametrize('Model', [DimeNet, DimeNetPlusPlus])
def test_dimenet(Model):
z = torch.randint(1, 10, (20, ))
pos = torch.randn(20, 3)

model = DimeNetPlusPlus(
if Model == DimeNet:
kwargs = dict(num_bilinear=3)
else:
kwargs = dict(out_emb_channels=3, int_emb_size=5, basis_emb_size=5)

model = Model(
hidden_channels=5,
out_channels=1,
num_blocks=5,
out_emb_channels=3,
int_emb_size=5,
basis_emb_size=5,
num_spherical=5,
num_radial=5,
num_before_skip=2,
num_after_skip=2,
**kwargs,
)
model.reset_parameters()

Expand All @@ -31,14 +52,15 @@ def test_dimenet_plus_plus():
jit = torch.jit.export(model)
assert torch.allclose(jit(z, pos), out)

optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
if is_full_test():
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

min_loss = float('inf')
for i in range(100):
optimizer.zero_grad()
out = model(z, pos)
loss = F.l1_loss(out, torch.tensor([1.]))
loss.backward()
optimizer.step()
min_loss = min(float(loss), min_loss)
assert min_loss < 2
min_loss = float('inf')
for i in range(100):
optimizer.zero_grad()
out = model(z, pos)
loss = F.l1_loss(out, torch.tensor([1.0]))
loss.backward()
optimizer.step()
min_loss = min(float(loss), min_loss)
assert min_loss < 2
111 changes: 81 additions & 30 deletions torch_geometric/nn/models/dimenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os.path as osp
from math import pi as PI
from math import sqrt
from typing import Callable, Optional, Tuple, Union
from typing import Callable, Dict, Optional, Tuple, Union

import numpy as np
import torch
Expand All @@ -17,7 +17,7 @@
from torch_geometric.typing import OptTensor, SparseTensor
from torch_geometric.utils import scatter

qm9_target_dict = {
qm9_target_dict: Dict[int, str] = {
0: 'mu',
1: 'alpha',
2: 'homo',
Expand Down Expand Up @@ -45,7 +45,7 @@ def forward(self, x: Tensor) -> Tensor:
x_pow_p0 = x.pow(p - 1)
x_pow_p1 = x_pow_p0 * x
x_pow_p2 = x_pow_p1 * x
return (1. / x + a * x_pow_p0 + b * x_pow_p1 +
return (1.0 / x + a * x_pow_p0 + b * x_pow_p1 +
c * x_pow_p2) * (x < 1.0).to(x.dtype)


Expand All @@ -66,13 +66,18 @@ def reset_parameters(self):
self.freq.requires_grad_()

def forward(self, dist: Tensor) -> Tensor:
dist = (dist.unsqueeze(-1) / self.cutoff)
dist = dist.unsqueeze(-1) / self.cutoff
return self.envelope(dist) * (self.freq * dist).sin()


class SphericalBasisLayer(torch.nn.Module):
def __init__(self, num_spherical: int, num_radial: int,
cutoff: float = 5.0, envelope_exponent: int = 5):
def __init__(
self,
num_spherical: int,
num_radial: int,
cutoff: float = 5.0,
envelope_exponent: int = 5,
):
super().__init__()
import sympy as sym

Expand Down Expand Up @@ -159,9 +164,16 @@ def forward(self, x: Tensor) -> Tensor:


class InteractionBlock(torch.nn.Module):
def __init__(self, hidden_channels: int, num_bilinear: int,
num_spherical: int, num_radial: int, num_before_skip: int,
num_after_skip: int, act: Callable):
def __init__(
self,
hidden_channels: int,
num_bilinear: int,
num_spherical: int,
num_radial: int,
num_before_skip: int,
num_after_skip: int,
act: Callable,
):
super().__init__()
self.act = act

Expand Down Expand Up @@ -223,9 +235,17 @@ def forward(self, x: Tensor, rbf: Tensor, sbf: Tensor, idx_kj: Tensor,


class InteractionPPBlock(torch.nn.Module):
def __init__(self, hidden_channels: int, int_emb_size: int,
basis_emb_size: int, num_spherical: int, num_radial: int,
num_before_skip: int, num_after_skip: int, act: Callable):
def __init__(
self,
hidden_channels: int,
int_emb_size: int,
basis_emb_size: int,
num_spherical: int,
num_radial: int,
num_before_skip: int,
num_after_skip: int,
act: Callable,
):
super().__init__()
self.act = act

Expand Down Expand Up @@ -311,8 +331,14 @@ def forward(self, x: Tensor, rbf: Tensor, sbf: Tensor, idx_kj: Tensor,


class OutputBlock(torch.nn.Module):
def __init__(self, num_radial: int, hidden_channels: int,
out_channels: int, num_layers: int, act: Callable):
def __init__(
self,
num_radial: int,
hidden_channels: int,
out_channels: int,
num_layers: int,
act: Callable,
):
super().__init__()
self.act = act

Expand Down Expand Up @@ -341,9 +367,15 @@ def forward(self, x: Tensor, rbf: Tensor, i: Tensor,


class OutputPPBlock(torch.nn.Module):
def __init__(self, num_radial: int, hidden_channels: int,
out_emb_channels: int, out_channels: int, num_layers: int,
act: Callable):
def __init__(
self,
num_radial: int,
hidden_channels: int,
out_emb_channels: int,
out_channels: int,
num_layers: int,
act: Callable,
):
super().__init__()
self.act = act

Expand Down Expand Up @@ -450,7 +482,7 @@ def __init__(
num_blocks: int,
num_bilinear: int,
num_spherical: int,
num_radial,
num_radial: int,
cutoff: float = 5.0,
max_num_neighbors: int = 32,
envelope_exponent: int = 5,
Expand All @@ -462,7 +494,7 @@ def __init__(
super().__init__()

if num_spherical < 2:
raise ValueError("num_spherical should be greater than 1")
raise ValueError("'num_spherical' should be greater than 1")

act = activation_resolver(act)

Expand All @@ -482,9 +514,15 @@ def __init__(
])

self.interaction_blocks = torch.nn.ModuleList([
InteractionBlock(hidden_channels, num_bilinear, num_spherical,
num_radial, num_before_skip, num_after_skip, act)
for _ in range(num_blocks)
InteractionBlock(
hidden_channels,
num_bilinear,
num_spherical,
num_radial,
num_before_skip,
num_after_skip,
act,
) for _ in range(num_blocks)
])

def reset_parameters(self):
Expand All @@ -502,7 +540,7 @@ def from_qm9_pretrained(
root: str,
dataset: Dataset,
target: int,
) -> Tuple['DimeNet', Dataset, Dataset, Dataset]:
) -> Tuple['DimeNet', Dataset, Dataset, Dataset]: # pragma: no cover
r"""Returns a pre-trained :class:`DimeNet` model on the
:class:`~torch_geometric.datasets.QM9` dataset, trained on the
specified target :obj:`target`."""
Expand Down Expand Up @@ -729,15 +767,27 @@ def __init__(
# variable `num_bilinear` does not have any purpose as it is used
# solely in the `OutputBlock` of DimeNet:
self.output_blocks = torch.nn.ModuleList([
OutputPPBlock(num_radial, hidden_channels, out_emb_channels,
out_channels, num_output_layers, act)
for _ in range(num_blocks + 1)
OutputPPBlock(
num_radial,
hidden_channels,
out_emb_channels,
out_channels,
num_output_layers,
act,
) for _ in range(num_blocks + 1)
])

self.interaction_blocks = torch.nn.ModuleList([
InteractionPPBlock(hidden_channels, int_emb_size, basis_emb_size,
num_spherical, num_radial, num_before_skip,
num_after_skip, act) for _ in range(num_blocks)
InteractionPPBlock(
hidden_channels,
int_emb_size,
basis_emb_size,
num_spherical,
num_radial,
num_before_skip,
num_after_skip,
act,
) for _ in range(num_blocks)
])

self.reset_parameters()
Expand All @@ -748,7 +798,8 @@ def from_qm9_pretrained(
root: str,
dataset: Dataset,
target: int,
) -> Tuple['DimeNetPlusPlus', Dataset, Dataset, Dataset]:
) -> Tuple['DimeNetPlusPlus', Dataset, Dataset,
Dataset]: # pragma: no cover
r"""Returns a pre-trained :class:`DimeNetPlusPlus` model on the
:class:`~torch_geometric.datasets.QM9` dataset, trained on the
specified target :obj:`target`."""
Expand Down

0 comments on commit d336d13

Please sign in to comment.