Add node-wise normalization mode in LayerNorm #4944

Merged · 14 commits · Jul 8, 2022
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ## [2.0.5] - 2022-MM-DD
 ### Added
+- Added node-wise normalization mode in `LayerNorm` ([#4944](https://github.com/pyg-team/pytorch_geometric/pull/4944))
 - Added support for `normalization_resolver` ([#4926](https://github.com/pyg-team/pytorch_geometric/pull/4926))
 - Added notebook tutorial for `torch_geometric.nn.aggr` package to documentation ([#4927](https://github.com/pyg-team/pytorch_geometric/pull/4927))
 - Added support for `follow_batch` for lists or dictionaries of tensors ([#4837](https://github.com/pyg-team/pytorch_geometric/pull/4837))
7 changes: 4 additions & 3 deletions test/nn/norm/test_layer_norm.py
@@ -6,12 +6,13 @@


 @pytest.mark.parametrize('affine', [True, False])
-def test_layer_norm(affine):
+@pytest.mark.parametrize('mode', ['graph', 'node'])
+def test_layer_norm(affine, mode):
     x = torch.randn(100, 16)
     batch = torch.zeros(100, dtype=torch.long)
 
-    norm = LayerNorm(16, affine=affine)
-    assert norm.__repr__() == 'LayerNorm(16)'
+    norm = LayerNorm(16, affine=affine, mode=mode)
+    assert norm.__repr__() == f'LayerNorm(16, mode={mode})'
 
     if is_full_test():
         torch.jit.script(norm)
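For context, a minimal usage sketch of the behavior the parametrized test above exercises (assuming a torch_geometric build that includes this PR):

import torch
from torch_geometric.nn import LayerNorm

x = torch.randn(100, 16)                    # 100 nodes, 16 features each
batch = torch.zeros(100, dtype=torch.long)  # all nodes belong to graph 0

# Graph-wise mode (the existing default): statistics are computed per graph.
norm_graph = LayerNorm(16, mode='graph')
out_graph = norm_graph(x, batch)

# Node-wise mode (added in this PR): each node's feature vector is
# normalized independently.
norm_node = LayerNorm(16, mode='node')
out_node = norm_node(x, batch)

print(norm_node)  # LayerNorm(16, mode=node)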
52 changes: 34 additions & 18 deletions torch_geometric/nn/norm/layer_norm.py
@@ -1,4 +1,5 @@
 import torch
+import torch.nn.functional as F
 from torch import Tensor
 from torch.nn import Parameter
 from torch_scatter import scatter
@@ -29,12 +30,18 @@ class LayerNorm(torch.nn.Module):
         affine (bool, optional): If set to :obj:`True`, this module has
             learnable affine parameters :math:`\gamma` and :math:`\beta`.
             (default: :obj:`True`)
+        mode (str, optional): The normalization mode to use for layer
+            normalization (:obj:`"graph"` or :obj:`"node"`). If
+            :obj:`"graph"` is used, each graph will be considered as an
+            element to be normalized. If :obj:`"node"` is used, each node
+            will be considered as an element to be normalized.
+            (default: :obj:`"graph"`)
     """
-    def __init__(self, in_channels, eps=1e-5, affine=True):
+    def __init__(self, in_channels, eps=1e-5, affine=True, mode='graph'):
         super().__init__()
 
         self.in_channels = in_channels
         self.eps = eps
+        self.mode = mode
 
         if affine:
             self.weight = Parameter(torch.Tensor(in_channels))
@@ -51,31 +58,40 @@ def reset_parameters(self):

     def forward(self, x: Tensor, batch: OptTensor = None) -> Tensor:
         """"""
-        if batch is None:
-            x = x - x.mean()
-            out = x / (x.std(unbiased=False) + self.eps)
+        if self.mode == 'graph':
+            if batch is None:
+                x = x - x.mean()
+                out = x / (x.std(unbiased=False) + self.eps)
 
-        else:
-            batch_size = int(batch.max()) + 1
+            else:
+                batch_size = int(batch.max()) + 1
 
-            norm = degree(batch, batch_size, dtype=x.dtype).clamp_(min=1)
-            norm = norm.mul_(x.size(-1)).view(-1, 1)
+                norm = degree(batch, batch_size, dtype=x.dtype).clamp_(min=1)
+                norm = norm.mul_(x.size(-1)).view(-1, 1)
 
-            mean = scatter(x, batch, dim=0, dim_size=batch_size,
-                           reduce='add').sum(dim=-1, keepdim=True) / norm
+                mean = scatter(x, batch, dim=0, dim_size=batch_size,
+                               reduce='add').sum(dim=-1, keepdim=True) / norm
 
-            x = x - mean.index_select(0, batch)
+                x = x - mean.index_select(0, batch)
 
-            var = scatter(x * x, batch, dim=0, dim_size=batch_size,
-                          reduce='add').sum(dim=-1, keepdim=True)
-            var = var / norm
+                var = scatter(x * x, batch, dim=0, dim_size=batch_size,
+                              reduce='add').sum(dim=-1, keepdim=True)
+                var = var / norm
 
-            out = x / (var + self.eps).sqrt().index_select(0, batch)
+                out = x / (var + self.eps).sqrt().index_select(0, batch)
 
-        if self.weight is not None and self.bias is not None:
-            out = out * self.weight + self.bias
+            if self.weight is not None and self.bias is not None:
+                out = out * self.weight + self.bias
+
+        elif self.mode == 'node':
+            out = F.layer_norm(x, (self.in_channels, ), self.weight,
+                               self.bias, self.eps)
+
+        else:
+            raise ValueError(f"Unknown normalization mode: {self.mode}")
 
         return out

     def __repr__(self):
-        return f'{self.__class__.__name__}({self.in_channels})'
+        return (f'{self.__class__.__name__}({self.in_channels}, '
+                f'mode={self.mode})')
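
To illustrate what the two modes compute, here is a small numerical sketch using plain torch only (it mirrors the diff above rather than calling the library, and uses the batch=None branch for the graph case):

import torch
import torch.nn.functional as F

x = torch.randn(4, 8)  # 4 nodes, 8 features

# mode='node' reduces to a standard layer norm over each node's features:
node_out = F.layer_norm(x, (8, ))
assert node_out.mean(dim=-1).abs().max() < 1e-5               # per-node mean ~ 0
assert (node_out.std(dim=-1, unbiased=False) - 1).abs().max() < 1e-3

# mode='graph' with batch=None standardizes over *all* entries of the graph,
# so a single mean/std is shared by every node and feature:
graph_out = (x - x.mean()) / (x.std(unbiased=False) + 1e-5)
assert graph_out.mean().abs() < 1e-5                          # per-graph mean ~ 0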