diff --git a/CHANGELOG.md b/CHANGELOG.md index e03d3bd0a2be..15090708aa02 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ## [2.0.5] - 2022-MM-DD ### Added +- Added node-wise normalization mode in `LayerNorm` ([#4944](https://github.com/pyg-team/pytorch_geometric/pull/4944)) - Added support for `normalization_resolver` ([#4926](https://github.com/pyg-team/pytorch_geometric/pull/4926)) - Added notebook tutorial for `torch_geometric.nn.aggr` package to documentation ([#4927](https://github.com/pyg-team/pytorch_geometric/pull/4927)) - Added support for `follow_batch` for lists or dictionaries of tensors ([#4837](https://github.com/pyg-team/pytorch_geometric/pull/4837)) diff --git a/test/nn/norm/test_layer_norm.py b/test/nn/norm/test_layer_norm.py index e75e2889fe8b..79a94498c8d1 100644 --- a/test/nn/norm/test_layer_norm.py +++ b/test/nn/norm/test_layer_norm.py @@ -6,12 +6,13 @@ @pytest.mark.parametrize('affine', [True, False]) -def test_layer_norm(affine): +@pytest.mark.parametrize('mode', ['graph', 'node']) +def test_layer_norm(affine, mode): x = torch.randn(100, 16) batch = torch.zeros(100, dtype=torch.long) - norm = LayerNorm(16, affine=affine) - assert norm.__repr__() == 'LayerNorm(16)' + norm = LayerNorm(16, affine=affine, mode=mode) + assert norm.__repr__() == f'LayerNorm(16, mode={mode})' if is_full_test(): torch.jit.script(norm) diff --git a/torch_geometric/nn/norm/layer_norm.py b/torch_geometric/nn/norm/layer_norm.py index 81c607e031f3..61c6677d9e1d 100644 --- a/torch_geometric/nn/norm/layer_norm.py +++ b/torch_geometric/nn/norm/layer_norm.py @@ -1,4 +1,5 @@ import torch +import torch.nn.functional as F from torch import Tensor from torch.nn import Parameter from torch_scatter import scatter @@ -29,12 +30,19 @@ class LayerNorm(torch.nn.Module): affine (bool, optional): If set to :obj:`True`, this module has learnable affine parameters :math:`\gamma` and :math:`\beta`. (default: :obj:`True`) + mode (str, optinal): The normalization mode to use for layer + normalization. (:obj:`"graph"` or :obj:`"node"`). If :obj:`"graph"` + is used, each graph will be considered as an element to be + normalized. If `"node"` is used, each node will be considered as + an element to be normalized. (default: :obj:`"graph"`) """ - def __init__(self, in_channels, eps=1e-5, affine=True): + def __init__(self, in_channels: int, eps: float = 1e-5, + affine: bool = True, mode: str = 'graph'): super().__init__() self.in_channels = in_channels self.eps = eps + self.mode = mode if affine: self.weight = Parameter(torch.Tensor(in_channels)) @@ -51,31 +59,39 @@ def reset_parameters(self): def forward(self, x: Tensor, batch: OptTensor = None) -> Tensor: """""" - if batch is None: - x = x - x.mean() - out = x / (x.std(unbiased=False) + self.eps) + if self.mode == 'graph': + if batch is None: + x = x - x.mean() + out = x / (x.std(unbiased=False) + self.eps) - else: - batch_size = int(batch.max()) + 1 + else: + batch_size = int(batch.max()) + 1 + + norm = degree(batch, batch_size, dtype=x.dtype).clamp_(min=1) + norm = norm.mul_(x.size(-1)).view(-1, 1) + + mean = scatter(x, batch, dim=0, dim_size=batch_size, + reduce='add').sum(dim=-1, keepdim=True) / norm - norm = degree(batch, batch_size, dtype=x.dtype).clamp_(min=1) - norm = norm.mul_(x.size(-1)).view(-1, 1) + x = x - mean.index_select(0, batch) - mean = scatter(x, batch, dim=0, dim_size=batch_size, - reduce='add').sum(dim=-1, keepdim=True) / norm + var = scatter(x * x, batch, dim=0, dim_size=batch_size, + reduce='add').sum(dim=-1, keepdim=True) + var = var / norm - x = x - mean.index_select(0, batch) + out = x / (var + self.eps).sqrt().index_select(0, batch) - var = scatter(x * x, batch, dim=0, dim_size=batch_size, - reduce='add').sum(dim=-1, keepdim=True) - var = var / norm + if self.weight is not None and self.bias is not None: + out = out * self.weight + self.bias - out = x / (var + self.eps).sqrt().index_select(0, batch) + return out - if self.weight is not None and self.bias is not None: - out = out * self.weight + self.bias + if self.mode == 'node': + return F.layer_norm(x, (self.in_channels, ), self.weight, + self.bias, self.eps) - return out + raise ValueError(f"Unknow normalization mode: {self.mode}") def __repr__(self): - return f'{self.__class__.__name__}({self.in_channels})' + return (f'{self.__class__.__name__}({self.in_channels}, ' + f'mode={self.mode})')