From c7062dc95183cc9f13b517dc64b5376114f9b724 Mon Sep 17 00:00:00 2001 From: Matthias Fey Date: Tue, 24 May 2022 08:12:39 +0200 Subject: [PATCH] `DimeNet`: Added `swish` activation to `activation_resolver` (#4700) * add dimenet++ readme * changelog * swish act * update --- torch_geometric/nn/acts.py | 2 -- torch_geometric/nn/models/dimenet.py | 36 +++++++++++++++------------- torch_geometric/nn/resolver.py | 17 ++++++++++++- 3 files changed, 36 insertions(+), 19 deletions(-) delete mode 100644 torch_geometric/nn/acts.py diff --git a/torch_geometric/nn/acts.py b/torch_geometric/nn/acts.py deleted file mode 100644 index 234ccefa35c0..000000000000 --- a/torch_geometric/nn/acts.py +++ /dev/null @@ -1,2 +0,0 @@ -def swish(x): - return x * x.sigmoid() diff --git a/torch_geometric/nn/models/dimenet.py b/torch_geometric/nn/models/dimenet.py index 327ec2fd19b7..4e3f44169b04 100644 --- a/torch_geometric/nn/models/dimenet.py +++ b/torch_geometric/nn/models/dimenet.py @@ -2,7 +2,7 @@ import os.path as osp from math import pi as PI from math import sqrt -from typing import Callable +from typing import Callable, Union import numpy as np import torch @@ -13,9 +13,8 @@ from torch_geometric.data import Dataset, download_url from torch_geometric.data.makedirs import makedirs from torch_geometric.nn import radius_graph - -from ..acts import swish -from ..inits import glorot_orthogonal +from torch_geometric.nn.inits import glorot_orthogonal +from torch_geometric.nn.resolver import activation_resolver qm9_target_dict = { 0: 'mu', @@ -117,7 +116,7 @@ def forward(self, dist, angle, idx_kj): class EmbeddingBlock(torch.nn.Module): - def __init__(self, num_radial, hidden_channels, act=swish): + def __init__(self, num_radial, hidden_channels, act): super().__init__() self.act = act @@ -139,7 +138,7 @@ def forward(self, x, rbf, i, j): class ResidualLayer(torch.nn.Module): - def __init__(self, hidden_channels, act=swish): + def __init__(self, hidden_channels, act): super().__init__() self.act 
= act self.lin1 = Linear(hidden_channels, hidden_channels) @@ -159,7 +158,7 @@ def forward(self, x): class InteractionBlock(torch.nn.Module): def __init__(self, hidden_channels, num_bilinear, num_spherical, - num_radial, num_before_skip, num_after_skip, act=swish): + num_radial, num_before_skip, num_after_skip, act): super().__init__() self.act = act @@ -222,7 +221,7 @@ def forward(self, x, rbf, sbf, idx_kj, idx_ji): class InteractionPPBlock(torch.nn.Module): def __init__(self, hidden_channels, int_emb_size, basis_emb_size, num_spherical, num_radial, num_before_skip, num_after_skip, - act=swish): + act): super().__init__() self.act = act @@ -308,7 +307,7 @@ def forward(self, x, rbf, sbf, idx_kj, idx_ji): class OutputBlock(torch.nn.Module): def __init__(self, num_radial, hidden_channels, out_channels, num_layers, - act=swish): + act): super().__init__() self.act = act @@ -337,7 +336,7 @@ def forward(self, x, rbf, i, num_nodes=None): class OutputPPBlock(torch.nn.Module): def __init__(self, num_radial, hidden_channels, out_emb_channels, - out_channels, num_layers, act=swish): + out_channels, num_layers, act): super().__init__() self.act = act @@ -403,8 +402,8 @@ class DimeNet(torch.nn.Module): interaction blocks after the skip connection. (default: :obj:`2`) num_output_layers (int, optional): Number of linear layers for the output blocks. (default: :obj:`3`) - act (Callable, optional): The activation function. - (default: :obj:`swish`) + act (str or Callable, optional): The activation function. 
+ (default: :obj:`"swish"`) """ url = ('https://github.com/klicperajo/dimenet/raw/master/pretrained/' @@ -415,12 +414,14 @@ def __init__(self, hidden_channels: int, out_channels: int, num_radial, cutoff: float = 5.0, max_num_neighbors: int = 32, envelope_exponent: int = 5, num_before_skip: int = 1, num_after_skip: int = 2, num_output_layers: int = 3, - act: Callable = swish): + act: Union[str, Callable] = 'swish'): super().__init__() if num_spherical < 2: raise ValueError("num_spherical should be greater than 1") + act = activation_resolver(act) + self.cutoff = cutoff self.max_num_neighbors = max_num_neighbors self.num_blocks = num_blocks @@ -633,8 +634,8 @@ class DimeNetPlusPlus(DimeNet): interaction blocks after the skip connection. (default: :obj:`2`) num_output_layers: (int, optional): Number of linear layers for the output blocks. (default: :obj:`3`) - act: (Callable, optional): The activation funtion. - (default: :obj:`swish`) + act: (str or Callable, optional): The activation function. 
+ (default: :obj:`"swish"`) """ url = ('https://raw.githubusercontent.com/gasteigerjo/dimenet/' @@ -646,7 +647,10 @@ def __init__(self, hidden_channels: int, out_channels: int, cutoff: float = 5.0, max_num_neighbors: int = 32, envelope_exponent: int = 5, num_before_skip: int = 1, num_after_skip: int = 2, num_output_layers: int = 3, - act: Callable = swish): + act: Union[str, Callable] = 'swish'): + + act = activation_resolver(act) + super().__init__( hidden_channels=hidden_channels, out_channels=out_channels, diff --git a/torch_geometric/nn/resolver.py b/torch_geometric/nn/resolver.py index 04c2374f9114..13ea9119eaec 100644 --- a/torch_geometric/nn/resolver.py +++ b/torch_geometric/nn/resolver.py @@ -1,6 +1,8 @@ +import inspect from typing import Any, List, Union import torch +from torch import Tensor def normalize_string(s: str) -> str: @@ -14,16 +16,29 @@ def resolver(classes: List[Any], query: Union[Any, str], *args, **kwargs): query = normalize_string(query) for cls in classes: if query == normalize_string(cls.__name__): - return cls(*args, **kwargs) + if inspect.isclass(cls): + return cls(*args, **kwargs) + else: + return cls return ValueError( f"Could not resolve '{query}' among the choices " f"{set(normalize_string(cls.__name__) for cls in classes)}") +# Activation Resolver ######################################################### + + +def swish(x: Tensor) -> Tensor: + return x * x.sigmoid() + + def activation_resolver(query: Union[Any, str] = 'relu', *args, **kwargs): acts = [ act for act in vars(torch.nn.modules.activation).values() if isinstance(act, type) and issubclass(act, torch.nn.Module) ] + acts += [ + swish, + ] return resolver(acts, query, *args, **kwargs)