Skip to content

Commit

Permalink
DimeNet: Added swish activation to activation_resolver (#4700)
Browse files Browse the repository at this point in the history
* add dimenet++ readme

* changelog

* swish act

* update
  • Loading branch information
rusty1s authored May 24, 2022
1 parent f3ce4f2 commit c7062dc
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 19 deletions.
2 changes: 0 additions & 2 deletions torch_geometric/nn/acts.py

This file was deleted.

36 changes: 20 additions & 16 deletions torch_geometric/nn/models/dimenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os.path as osp
from math import pi as PI
from math import sqrt
from typing import Callable
from typing import Callable, Union

import numpy as np
import torch
Expand All @@ -13,9 +13,8 @@
from torch_geometric.data import Dataset, download_url
from torch_geometric.data.makedirs import makedirs
from torch_geometric.nn import radius_graph

from ..acts import swish
from ..inits import glorot_orthogonal
from torch_geometric.nn.inits import glorot_orthogonal
from torch_geometric.nn.resolver import activation_resolver

qm9_target_dict = {
0: 'mu',
Expand Down Expand Up @@ -117,7 +116,7 @@ def forward(self, dist, angle, idx_kj):


class EmbeddingBlock(torch.nn.Module):
def __init__(self, num_radial, hidden_channels, act=swish):
def __init__(self, num_radial, hidden_channels, act):
super().__init__()
self.act = act

Expand All @@ -139,7 +138,7 @@ def forward(self, x, rbf, i, j):


class ResidualLayer(torch.nn.Module):
def __init__(self, hidden_channels, act=swish):
def __init__(self, hidden_channels, act):
super().__init__()
self.act = act
self.lin1 = Linear(hidden_channels, hidden_channels)
Expand All @@ -159,7 +158,7 @@ def forward(self, x):

class InteractionBlock(torch.nn.Module):
def __init__(self, hidden_channels, num_bilinear, num_spherical,
num_radial, num_before_skip, num_after_skip, act=swish):
num_radial, num_before_skip, num_after_skip, act):
super().__init__()
self.act = act

Expand Down Expand Up @@ -222,7 +221,7 @@ def forward(self, x, rbf, sbf, idx_kj, idx_ji):
class InteractionPPBlock(torch.nn.Module):
def __init__(self, hidden_channels, int_emb_size, basis_emb_size,
num_spherical, num_radial, num_before_skip, num_after_skip,
act=swish):
act):
super().__init__()
self.act = act

Expand Down Expand Up @@ -308,7 +307,7 @@ def forward(self, x, rbf, sbf, idx_kj, idx_ji):

class OutputBlock(torch.nn.Module):
def __init__(self, num_radial, hidden_channels, out_channels, num_layers,
act=swish):
act):
super().__init__()
self.act = act

Expand Down Expand Up @@ -337,7 +336,7 @@ def forward(self, x, rbf, i, num_nodes=None):

class OutputPPBlock(torch.nn.Module):
def __init__(self, num_radial, hidden_channels, out_emb_channels,
out_channels, num_layers, act=swish):
out_channels, num_layers, act):
super().__init__()
self.act = act

Expand Down Expand Up @@ -403,8 +402,8 @@ class DimeNet(torch.nn.Module):
interaction blocks after the skip connection. (default: :obj:`2`)
num_output_layers (int, optional): Number of linear layers for the
output blocks. (default: :obj:`3`)
act (Callable, optional): The activation function.
(default: :obj:`swish`)
act (str or Callable, optional): The activation function.
(default: :obj:`"swish"`)
"""

url = ('https://github.com/klicperajo/dimenet/raw/master/pretrained/'
Expand All @@ -415,12 +414,14 @@ def __init__(self, hidden_channels: int, out_channels: int,
num_radial, cutoff: float = 5.0, max_num_neighbors: int = 32,
envelope_exponent: int = 5, num_before_skip: int = 1,
num_after_skip: int = 2, num_output_layers: int = 3,
act: Callable = swish):
act: Union[str, Callable] = 'swish'):
super().__init__()

if num_spherical < 2:
raise ValueError("num_spherical should be greater than 1")

act = activation_resolver(act)

self.cutoff = cutoff
self.max_num_neighbors = max_num_neighbors
self.num_blocks = num_blocks
Expand Down Expand Up @@ -633,8 +634,8 @@ class DimeNetPlusPlus(DimeNet):
interaction blocks after the skip connection. (default: :obj:`2`)
num_output_layers: (int, optional): Number of linear layers for the
output blocks. (default: :obj:`3`)
act: (Callable, optional): The activation funtion.
(default: :obj:`swish`)
act: (str or Callable, optional): The activation funtion.
(default: :obj:`"swish"`)
"""

url = ('https://raw.githubusercontent.com/gasteigerjo/dimenet/'
Expand All @@ -646,7 +647,10 @@ def __init__(self, hidden_channels: int, out_channels: int,
cutoff: float = 5.0, max_num_neighbors: int = 32,
envelope_exponent: int = 5, num_before_skip: int = 1,
num_after_skip: int = 2, num_output_layers: int = 3,
act: Callable = swish):
act: Union[str, Callable] = 'swish'):

act = activation_resolver(act)

super().__init__(
hidden_channels=hidden_channels,
out_channels=out_channels,
Expand Down
17 changes: 16 additions & 1 deletion torch_geometric/nn/resolver.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import inspect
from typing import Any, List, Union

import torch
from torch import Tensor


def normalize_string(s: str) -> str:
Expand All @@ -14,16 +16,29 @@ def resolver(classes: List[Any], query: Union[Any, str], *args, **kwargs):
query = normalize_string(query)
for cls in classes:
if query == normalize_string(cls.__name__):
return cls(*args, **kwargs)
if inspect.isclass(cls):
return cls(*args, **kwargs)
else:
return cls

return ValueError(
f"Could not resolve '{query}' among the choices "
f"{set(normalize_string(cls.__name__) for cls in classes)}")


# Activation Resolver #########################################################


def swish(x: Tensor) -> Tensor:
return x * x.sigmoid()


def activation_resolver(query: Union[Any, str] = 'relu', *args, **kwargs):
acts = [
act for act in vars(torch.nn.modules.activation).values()
if isinstance(act, type) and issubclass(act, torch.nn.Module)
]
acts += [
swish,
]
return resolver(acts, query, *args, **kwargs)

0 comments on commit c7062dc

Please sign in to comment.