diff --git a/CHANGELOG.md b/CHANGELOG.md
index 26238f529e92..06ce456eb1e4 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,7 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added the `bias` vector to the `GCN` model definition in the "Create Message Passing Networks" tutorial ([#4755](https://github.com/pyg-team/pytorch_geometric/pull/4755))
 - Added `transforms.RootedSubgraph` interface with two implementations: `RootedEgoNets` and `RootedRWSubgraph` ([#3926](https://github.com/pyg-team/pytorch_geometric/pull/3926))
 - Added `ptr` vectors for `follow_batch` attributes within `Batch.from_data_list` ([#4723](https://github.com/pyg-team/pytorch_geometric/pull/4723))
-- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721), [#4731](https://github.com/pyg-team/pytorch_geometric/pull/4731), [#4762](https://github.com/pyg-team/pytorch_geometric/pull/4762))
+- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721), [#4731](https://github.com/pyg-team/pytorch_geometric/pull/4731), [#4762](https://github.com/pyg-team/pytorch_geometric/pull/4762), [#4749](https://github.com/pyg-team/pytorch_geometric/pull/4749))
 - Added the `DimeNet++` model ([#4432](https://github.com/pyg-team/pytorch_geometric/pull/4432), [#4699](https://github.com/pyg-team/pytorch_geometric/pull/4699), [#4700](https://github.com/pyg-team/pytorch_geometric/pull/4700))
 - Added an example of using PyG with PyTorch Ignite ([#4487](https://github.com/pyg-team/pytorch_geometric/pull/4487))
 - Added `GroupAddRev` module with support for reducing training GPU memory ([#4671](https://github.com/pyg-team/pytorch_geometric/pull/4671), [#4701](https://github.com/pyg-team/pytorch_geometric/pull/4701), [#4715](https://github.com/pyg-team/pytorch_geometric/pull/4715), [#4730](https://github.com/pyg-team/pytorch_geometric/pull/4730))
diff --git a/test/nn/aggr/test_basic.py b/test/nn/aggr/test_basic.py
index 24227ed25488..f3dc0c51964f 100644
--- a/test/nn/aggr/test_basic.py
+++ b/test/nn/aggr/test_basic.py
@@ -5,6 +5,7 @@
     MaxAggregation,
     MeanAggregation,
     MinAggregation,
+    MulAggregation,
     PowerMeanAggregation,
     SoftmaxAggregation,
     StdAggregation,
@@ -29,7 +30,7 @@ def test_validate():
 
 @pytest.mark.parametrize('Aggregation', [
     MeanAggregation, SumAggregation, MaxAggregation, MinAggregation,
-    VarAggregation, StdAggregation
+    MulAggregation, VarAggregation, StdAggregation
 ])
 def test_basic_aggregation(Aggregation):
     x = torch.randn(6, 16)
@@ -41,7 +42,12 @@ def test_basic_aggregation(Aggregation):
 
     out = aggr(x, index)
     assert out.size() == (3, x.size(1))
-    assert torch.allclose(out, aggr(x, ptr=ptr))
+
+    if isinstance(aggr, MulAggregation):
+        with pytest.raises(NotImplementedError, match="requires 'index'"):
+            aggr(x, ptr=ptr)
+    else:
+        assert torch.allclose(out, aggr(x, ptr=ptr))
 
 
 @pytest.mark.parametrize('Aggregation',
@@ -53,7 +59,7 @@ def test_gen_aggregation(Aggregation, learn):
     ptr = torch.tensor([0, 2, 5, 6])
 
     aggr = Aggregation(learn=learn)
-    assert str(aggr) == f'{Aggregation.__name__}()'
+    assert str(aggr) == f'{Aggregation.__name__}(learn={learn})'
 
     out = aggr(x, index)
     assert out.size() == (3, x.size(1))
diff --git a/test/nn/aggr/test_multi.py b/test/nn/aggr/test_multi.py
new file mode 100644
index 000000000000..255ca3de1e09
--- /dev/null
+++ b/test/nn/aggr/test_multi.py
@@ -0,0 +1,21 @@
+import torch
+
+from torch_geometric.nn import MultiAggregation
+
+
+def test_multi_aggr():
+    x = torch.randn(6, 16)
+    index = torch.tensor([0, 0, 1, 1, 1, 2])
+    ptr = torch.tensor([0, 2, 5, 6])
+
+    aggrs = ['mean', 'sum', 'max']
+    aggr = MultiAggregation(aggrs)
+    assert str(aggr) == ('MultiAggregation([\n'
+                         ' MeanAggregation(),\n'
+                         ' SumAggregation(),\n'
+                         ' MaxAggregation()\n'
+                         '])')
+
+    out = aggr(x, index)
+    assert out.size() == (3, len(aggrs) * x.size(1))
+    assert torch.allclose(out, aggr(x, ptr=ptr))
diff --git a/test/nn/test_resolver.py b/test/nn/test_resolver.py
index 5c67f95749f9..218381a013ec 100644
--- a/test/nn/test_resolver.py
+++ b/test/nn/test_resolver.py
@@ -1,6 +1,11 @@
+import pytest
 import torch
 
-from torch_geometric.nn.resolver import activation_resolver
+import torch_geometric
+from torch_geometric.nn.resolver import (
+    activation_resolver,
+    aggregation_resolver,
+)
 
 
 def test_activation_resolver():
@@ -11,3 +16,20 @@ def test_activation_resolver():
     assert isinstance(activation_resolver('elu'), torch.nn.ELU)
     assert isinstance(activation_resolver('relu'), torch.nn.ReLU)
     assert isinstance(activation_resolver('prelu'), torch.nn.PReLU)
+
+
+@pytest.mark.parametrize('aggr_tuple', [
+    (torch_geometric.nn.aggr.MeanAggregation, 'mean'),
+    (torch_geometric.nn.aggr.SumAggregation, 'sum'),
+    (torch_geometric.nn.aggr.MaxAggregation, 'max'),
+    (torch_geometric.nn.aggr.MinAggregation, 'min'),
+    (torch_geometric.nn.aggr.MulAggregation, 'mul'),
+    (torch_geometric.nn.aggr.VarAggregation, 'var'),
+    (torch_geometric.nn.aggr.StdAggregation, 'std'),
+    (torch_geometric.nn.aggr.SoftmaxAggregation, 'softmax'),
+    (torch_geometric.nn.aggr.PowerMeanAggregation, 'powermean'),
+])
+def test_aggregation_resolver(aggr_tuple):
+    aggr_module, aggr_repr = aggr_tuple
+    assert isinstance(aggregation_resolver(aggr_module()), aggr_module)
+    assert isinstance(aggregation_resolver(aggr_repr), aggr_module)
diff --git a/torch_geometric/nn/aggr/__init__.py b/torch_geometric/nn/aggr/__init__.py
index adc1e7b67e85..11934d179090 100644
--- a/torch_geometric/nn/aggr/__init__.py
+++ b/torch_geometric/nn/aggr/__init__.py
@@ -1,9 +1,12 @@
 from .base import Aggregation
+from .multi import MultiAggregation
 from .basic import (
     MeanAggregation,
     SumAggregation,
+    AddAggregation,
     MaxAggregation,
     MinAggregation,
+    MulAggregation,
     VarAggregation,
     StdAggregation,
     SoftmaxAggregation,
@@ -14,10 +17,13 @@
 
 __all__ = classes = [
     'Aggregation',
+    'MultiAggregation',
     'MeanAggregation',
     'SumAggregation',
+    'AddAggregation',
     'MaxAggregation',
     'MinAggregation',
+    'MulAggregation',
     'VarAggregation',
     'StdAggregation',
     'SoftmaxAggregation',
diff --git a/torch_geometric/nn/aggr/basic.py b/torch_geometric/nn/aggr/basic.py
index 3b52fc225fad..1e5adf80fa36 100644
--- a/torch_geometric/nn/aggr/basic.py
+++ b/torch_geometric/nn/aggr/basic.py
@@ -22,6 +22,9 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
         return self.reduce(x, index, ptr, dim_size, dim, reduce='sum')
 
 
+AddAggregation = SumAggregation  # Alias
+
+
 class MaxAggregation(Aggregation):
     def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
                 ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
@@ -36,6 +39,15 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
         return self.reduce(x, index, ptr, dim_size, dim, reduce='min')
 
 
+class MulAggregation(Aggregation):
+    def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
+                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
+                dim: int = -2) -> Tensor:
+        # TODO Currently, `mul` reduction can only operate on `index`:
+        self.assert_index_present(index)
+        return self.reduce(x, index, None, dim_size, dim, reduce='mul')
+
+
 class VarAggregation(Aggregation):
     def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
                 ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
@@ -61,6 +73,7 @@ def __init__(self, t: float = 1.0, learn: bool = False):
         super().__init__()
         self._init_t = t
         self.t = Parameter(torch.Tensor(1)) if learn else t
+        self.learn = learn
         self.reset_parameters()
 
     def reset_parameters(self):
@@ -77,6 +90,9 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
         alpha = softmax(alpha, index, ptr, dim_size, dim)
         return self.reduce(x * alpha, index, ptr, dim_size, dim, reduce='sum')
 
+    def __repr__(self) -> str:
+        return (f'{self.__class__.__name__}(learn={self.learn})')
+
 
 class PowerMeanAggregation(Aggregation):
     def __init__(self, p: float = 1.0, learn: bool = False):
@@ -84,8 +100,9 @@ def __init__(self, p: float = 1.0, learn: bool = False):
         super().__init__()
         self._init_p = p
         self.p = Parameter(torch.Tensor(1)) if learn else p
+        self.learn = learn
         self.reset_parameters()
 
     def reset_parameters(self):
         if isinstance(self.p, Tensor):
             self.p.data.fill_(self._init_p)
@@ -97,3 +114,6 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
         if isinstance(self.p, (int, float)) and self.p == 1:
             return out
         return out.clamp_(min=0, max=100).pow(1. / self.p)
+
+    def __repr__(self) -> str:
+        return (f'{self.__class__.__name__}(learn={self.learn})')
diff --git a/torch_geometric/nn/aggr/multi.py b/torch_geometric/nn/aggr/multi.py
new file mode 100644
index 000000000000..97b2c713ba12
--- /dev/null
+++ b/torch_geometric/nn/aggr/multi.py
@@ -0,0 +1,37 @@
+from typing import List, Optional, Union
+
+import torch
+from torch import Tensor
+
+from torch_geometric.nn.aggr import Aggregation
+from torch_geometric.nn.resolver import aggregation_resolver
+
+
+class MultiAggregation(Aggregation):
+    def __init__(self, aggrs: List[Union[Aggregation, str]]):
+        super().__init__()
+
+        if not isinstance(aggrs, (list, tuple)):
+            raise ValueError(f"'aggrs' of '{self.__class__.__name__}' should "
+                             f"be a list or tuple (got '{type(aggrs)}')")
+
+        if len(aggrs) == 0:
+            raise ValueError(f"'aggrs' of '{self.__class__.__name__}' should "
+                             f"not be empty")
+
+        # Use a `ModuleList` (not a plain list) so that parameters of learnable
+        # sub-aggregations are registered with this module:
+        self.aggrs = torch.nn.ModuleList(
+            [aggregation_resolver(aggr) for aggr in aggrs])
+
+    def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
+                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
+                dim: int = -2) -> Tensor:
+        outs = []
+        for aggr in self.aggrs:
+            outs.append(aggr(x, index, ptr=ptr, dim_size=dim_size, dim=dim))
+        return torch.cat(outs, dim=-1) if len(outs) > 1 else outs[0]
+
+    def __repr__(self) -> str:
+        args = [f' {aggr}' for aggr in self.aggrs]
+        return '{}([\n{}\n])'.format(self.__class__.__name__, ',\n'.join(args))
diff --git a/torch_geometric/nn/resolver.py b/torch_geometric/nn/resolver.py
index 13ea9119eaec..8d5e16ebccfc 100644
--- a/torch_geometric/nn/resolver.py
+++ b/torch_geometric/nn/resolver.py
@@ -1,7 +1,6 @@
 import inspect
-from typing import Any, List, Union
+from typing import Any, List, Optional, Union
 
-import torch
 from torch import Tensor
 
 
@@ -9,21 +8,24 @@ def normalize_string(s: str) -> str:
     return s.lower().replace('-', '').replace('_', '').replace(' ', '')
 
 
-def resolver(classes: List[Any], query: Union[Any, str], *args, **kwargs):
+def resolver(classes: List[Any], query: Union[Any, str],
+             base_cls: Optional[Any], *args, **kwargs):
+
     if query is None or not isinstance(query, str):
         return query
 
-    query = normalize_string(query)
+    query_repr = normalize_string(query)
+    base_cls_repr = normalize_string(base_cls.__name__) if base_cls else ''
     for cls in classes:
-        if query == normalize_string(cls.__name__):
+        cls_repr = normalize_string(cls.__name__)
+        if query_repr in [cls_repr, cls_repr.replace(base_cls_repr, '')]:
             if inspect.isclass(cls):
                 return cls(*args, **kwargs)
             else:
                 return cls
 
-    return ValueError(
-        f"Could not resolve '{query}' among the choices "
-        f"{set(normalize_string(cls.__name__) for cls in classes)}")
+    raise ValueError(f"Could not resolve '{query}' among the choices "
+                     f"{set(cls.__name__ for cls in classes)}")
 
 
 # Activation Resolver #########################################################
@@ -34,11 +36,28 @@ def swish(x: Tensor) -> Tensor:
 
 
 def activation_resolver(query: Union[Any, str] = 'relu', *args, **kwargs):
+    import torch
+    base_cls = torch.nn.Module
+
     acts = [
         act for act in vars(torch.nn.modules.activation).values()
-        if isinstance(act, type) and issubclass(act, torch.nn.Module)
+        if isinstance(act, type) and issubclass(act, base_cls)
     ]
     acts += [
         swish,
     ]
-    return resolver(acts, query, *args, **kwargs)
+    return resolver(acts, query, base_cls, *args, **kwargs)
+
+
+# Aggregation Resolver ########################################################
+
+
+def aggregation_resolver(query: Union[Any, str], *args, **kwargs):
+    import torch_geometric.nn.aggr as aggrs
+    base_cls = aggrs.Aggregation
+
+    aggrs = [
+        aggr for aggr in vars(aggrs).values()
+        if isinstance(aggr, type) and issubclass(aggr, base_cls)
+    ]
+    return resolver(aggrs, query, base_cls, *args, **kwargs)