Add node-wise normalization mode in LayerNorm (#4944)
* Add node-wise normalization in LayerNorm

* changelog

* update

Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>

* update

Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>

* update

Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update

Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>

* update

Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>

* update

Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>

* update

Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update

Co-authored-by: Guohao Li <lighaime@gmail.com>
Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
4 people authored Jul 8, 2022
1 parent db5e6d9 commit d220afe
Showing 3 changed files with 40 additions and 22 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ## [2.0.5] - 2022-MM-DD
 ### Added
+- Added node-wise normalization mode in `LayerNorm` ([#4944](https://github.com/pyg-team/pytorch_geometric/pull/4944))
 - Added support for `normalization_resolver` ([#4926](https://github.com/pyg-team/pytorch_geometric/pull/4926))
 - Added notebook tutorial for `torch_geometric.nn.aggr` package to documentation ([#4927](https://github.com/pyg-team/pytorch_geometric/pull/4927))
 - Added support for `follow_batch` for lists or dictionaries of tensors ([#4837](https://github.com/pyg-team/pytorch_geometric/pull/4837))
7 changes: 4 additions & 3 deletions test/nn/norm/test_layer_norm.py
@@ -6,12 +6,13 @@
 
 
 @pytest.mark.parametrize('affine', [True, False])
-def test_layer_norm(affine):
+@pytest.mark.parametrize('mode', ['graph', 'node'])
+def test_layer_norm(affine, mode):
     x = torch.randn(100, 16)
     batch = torch.zeros(100, dtype=torch.long)
 
-    norm = LayerNorm(16, affine=affine)
-    assert norm.__repr__() == 'LayerNorm(16)'
+    norm = LayerNorm(16, affine=affine, mode=mode)
+    assert norm.__repr__() == f'LayerNorm(16, mode={mode})'
 
     if is_full_test():
         torch.jit.script(norm)
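As a quick usage reference (not part of the diff; the tensor shapes mirror the test above), constructing the layer with the new `mode` keyword looks like this:

import torch
from torch_geometric.nn import LayerNorm

x = torch.randn(100, 16)                    # 100 nodes, 16 features
batch = torch.zeros(100, dtype=torch.long)  # all nodes belong to one graph

norm = LayerNorm(16, mode='node')           # normalize each node independently
out = norm(x, batch)
assert out.size() == (100, 16)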
54 changes: 35 additions & 19 deletions torch_geometric/nn/norm/layer_norm.py
@@ -1,4 +1,5 @@
 import torch
+import torch.nn.functional as F
 from torch import Tensor
 from torch.nn import Parameter
 from torch_scatter import scatter
@@ -29,12 +30,19 @@ class LayerNorm(torch.nn.Module):
         affine (bool, optional): If set to :obj:`True`, this module has
             learnable affine parameters :math:`\gamma` and :math:`\beta`.
             (default: :obj:`True`)
+        mode (str, optional): The normalization mode to use for layer
+            normalization (:obj:`"graph"` or :obj:`"node"`). If :obj:`"graph"`
+            is used, each graph will be considered as an element to be
+            normalized. If :obj:`"node"` is used, each node will be considered
+            as an element to be normalized. (default: :obj:`"graph"`)
     """
-    def __init__(self, in_channels, eps=1e-5, affine=True):
+    def __init__(self, in_channels: int, eps: float = 1e-5,
+                 affine: bool = True, mode: str = 'graph'):
         super().__init__()
 
         self.in_channels = in_channels
         self.eps = eps
+        self.mode = mode
 
         if affine:
             self.weight = Parameter(torch.Tensor(in_channels))
@@ -51,31 +59,39 @@ def reset_parameters(self):
 
     def forward(self, x: Tensor, batch: OptTensor = None) -> Tensor:
         """"""
-        if batch is None:
-            x = x - x.mean()
-            out = x / (x.std(unbiased=False) + self.eps)
+        if self.mode == 'graph':
+            if batch is None:
+                x = x - x.mean()
+                out = x / (x.std(unbiased=False) + self.eps)
 
-        else:
-            batch_size = int(batch.max()) + 1
+            else:
+                batch_size = int(batch.max()) + 1
 
-            norm = degree(batch, batch_size, dtype=x.dtype).clamp_(min=1)
-            norm = norm.mul_(x.size(-1)).view(-1, 1)
+                norm = degree(batch, batch_size, dtype=x.dtype).clamp_(min=1)
+                norm = norm.mul_(x.size(-1)).view(-1, 1)
 
-            mean = scatter(x, batch, dim=0, dim_size=batch_size,
-                           reduce='add').sum(dim=-1, keepdim=True) / norm
+                mean = scatter(x, batch, dim=0, dim_size=batch_size,
+                               reduce='add').sum(dim=-1, keepdim=True) / norm
 
-            x = x - mean.index_select(0, batch)
+                x = x - mean.index_select(0, batch)
 
-            var = scatter(x * x, batch, dim=0, dim_size=batch_size,
-                          reduce='add').sum(dim=-1, keepdim=True)
-            var = var / norm
+                var = scatter(x * x, batch, dim=0, dim_size=batch_size,
+                              reduce='add').sum(dim=-1, keepdim=True)
+                var = var / norm
 
-            out = x / (var + self.eps).sqrt().index_select(0, batch)
+                out = x / (var + self.eps).sqrt().index_select(0, batch)
 
-        if self.weight is not None and self.bias is not None:
-            out = out * self.weight + self.bias
+            if self.weight is not None and self.bias is not None:
+                out = out * self.weight + self.bias
 
-        return out
+            return out
+
+        if self.mode == 'node':
+            return F.layer_norm(x, (self.in_channels, ), self.weight,
+                                self.bias, self.eps)
+
+        raise ValueError(f"Unknown normalization mode: {self.mode}")
 
     def __repr__(self):
-        return f'{self.__class__.__name__}({self.in_channels})'
+        return (f'{self.__class__.__name__}({self.in_channels}, '
+                f'mode={self.mode})')
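To make the two modes concrete, here is a small sanity-check sketch (not part of the commit; tensor sizes and tolerances are illustrative assumptions). In `"graph"` mode the statistics are taken over every entry of each graph's `[num_nodes, num_features]` block, while `"node"` mode standardizes each row on its own and, with `affine=False`, reduces to plain `F.layer_norm`:

import torch
import torch.nn.functional as F
from torch_geometric.nn import LayerNorm

x = torch.randn(6, 4)
batch = torch.tensor([0, 0, 0, 1, 1, 1])  # two graphs, three nodes each

# Graph mode: each graph's block is standardized jointly, so the block
# as a whole ends up with ~zero mean and ~unit standard deviation.
out_graph = LayerNorm(4, affine=False, mode='graph')(x, batch)
block = out_graph[batch == 0]
assert block.mean().abs() < 1e-5
assert (block.std(unbiased=False) - 1).abs() < 1e-3

# Node mode: each row is standardized independently of `batch`;
# with affine=False this matches F.layer_norm applied row-wise.
out_node = LayerNorm(4, affine=False, mode='node')(x, batch)
assert torch.allclose(out_node, F.layer_norm(x, (4, )), atol=1e-6)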
