Add dim_size arg for norm layers
Passing dim_size can speed up runtime because:
1) We no longer need to scan the batch vector for its maximum value.
2) We do not have to read a tensor value that lives on the device, which would force a device-to-host synchronization.

Besides, dim_size can be used when a user works with fixed-size datasets, where the number of examples per batch is known in advance.
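A rough usage sketch (not part of this commit; the layer size, tensor shapes and values are made up for illustration): when the number of graphs per mini-batch is known in advance, it can be passed as dim_size so the layer does not have to derive it from the batch vector.

import torch
from torch_geometric.nn import GraphNorm

norm = GraphNorm(in_channels=64)
x = torch.randn(6, 64)                    # 6 nodes with 64 features
batch = torch.tensor([0, 0, 0, 1, 1, 1])  # 2 graphs in the batch

# Without dim_size, the layer computes int(batch.max()) + 1 internally,
# which reads a device tensor value when batch lives on the GPU.
out = norm(x, batch)

# With dim_size, the number of graphs is supplied directly.
out = norm(x, batch, dim_size=2)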
Kamil Andrzejewski committed Apr 7, 2023
1 parent 87744e2 commit 906719a
Showing 7 changed files with 50 additions and 22 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Added optional argument `dim_size` for `LayerNorm`, `GraphNorm`, `InstanceNorm`, `GraphSizeNorm` and `PairNorm` ([#7135](https://github.com/pyg-team/pytorch_geometric/pull/7135))
- Added support for `Data.num_edges` for native `torch.sparse.Tensor` adjacency matrices ([#7104](https://github.com/pyg-team/pytorch_geometric/pull/7104))
- Fixed crash of heterogeneous data loaders if node or edge types are missing ([#7060](https://github.com/pyg-team/pytorch_geometric/pull/7060), [#7087](https://github.com/pyg-team/pytorch_geometric/pull/7087))
- Accelerated attention-based `MultiAggregation` ([#7077](https://github.com/pyg-team/pytorch_geometric/pull/7077))
13 changes: 9 additions & 4 deletions torch_geometric/nn/norm/graph_norm.py
@@ -3,6 +3,7 @@
import torch
from torch import Tensor

from torch_geometric.nn.inits import ones, zeros
from torch_geometric.typing import OptTensor
from torch_geometric.utils import scatter

@@ -44,22 +45,26 @@ def reset_parameters(self):
zeros(self.bias)
ones(self.mean_scale)

def forward(self, x: Tensor, batch: Optional[Tensor] = None) -> Tensor:
def forward(self, x: Tensor, batch: OptTensor = None,
dim_size: Optional[int] = None) -> Tensor:
r"""
Args:
x (torch.Tensor): The source tensor.
batch (torch.Tensor, optional): The batch vector
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
each element to a specific example. (default: :obj:`None`)
dim_size (int, optional): The number of examples :math:`B` in case
:obj:`batch` is given. (default: :obj:`None`)
"""
if batch is None:
batch = x.new_zeros(x.size(0), dtype=torch.long)

batch_size = int(batch.max()) + 1
if dim_size is None:
dim_size = int(batch.max()) + 1

mean = scatter(x, batch, 0, batch_size, reduce='mean')
mean = scatter(x, batch, 0, dim_size, reduce='mean')
out = x - mean.index_select(0, batch) * self.mean_scale
var = scatter(out.pow(2), batch, 0, batch_size, reduce='mean')
var = scatter(out.pow(2), batch, 0, dim_size, reduce='mean')
std = (var + self.eps).sqrt().index_select(0, batch)
return self.weight * out / std + self.bias

9 changes: 7 additions & 2 deletions torch_geometric/nn/norm/graph_size_norm.py
@@ -1,3 +1,5 @@
from typing import Optional

import torch
import torch.nn as nn
from torch import Tensor
@@ -18,18 +20,21 @@ class GraphSizeNorm(nn.Module):
def __init__(self):
super().__init__()

def forward(self, x: Tensor, batch: OptTensor = None) -> Tensor:
def forward(self, x: Tensor, batch: OptTensor = None,
dim_size: Optional[int] = None) -> Tensor:
r"""
Args:
x (torch.Tensor): The source tensor.
batch (torch.Tensor, optional): The batch vector
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
each element to a specific example. (default: :obj:`None`)
dim_size (int, optional): The number of examples :math:`B` in case
:obj:`batch` is given. (default: :obj:`None`)
"""
if batch is None:
batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)

inv_sqrt_deg = degree(batch, dtype=x.dtype).pow(-0.5)
inv_sqrt_deg = degree(batch, num_nodes=dim_size, dtype=x.dtype).pow(-0.5)
return x * inv_sqrt_deg.index_select(0, batch).view(-1, 1)

def __repr__(self) -> str:
20 changes: 13 additions & 7 deletions torch_geometric/nn/norm/instance_norm.py
@@ -1,3 +1,5 @@
from typing import Optional

import torch.nn.functional as F
from torch import Tensor
from torch.nn.modules.instancenorm import _InstanceNorm
@@ -50,13 +52,16 @@ def reset_parameters(self):
r"""Resets all learnable parameters of the module."""
super().reset_parameters()

def forward(self, x: Tensor, batch: OptTensor = None) -> Tensor:
def forward(self, x: Tensor, batch: OptTensor = None,
dim_size: Optional[int] = None) -> Tensor:
r"""
Args:
x (torch.Tensor): The source tensor.
batch (torch.Tensor, optional): The batch vector
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
each element to a specific example. (default: :obj:`None`)
dim_size (int, optional): The number of examples :math:`B` in case
:obj:`batch` is given. (default: :obj:`None`)
"""
if batch is None:
out = F.instance_norm(
@@ -65,21 +70,22 @@ def forward(self, x: Tensor, batch: OptTensor = None) -> Tensor:
or not self.track_running_stats, self.momentum, self.eps)
return out.squeeze(0).t()

batch_size = int(batch.max()) + 1
if dim_size is None:
dim_size = int(batch.max()) + 1

mean = var = unbiased_var = x # Dummies.

if self.training or not self.track_running_stats:
norm = degree(batch, batch_size, dtype=x.dtype).clamp_(min=1)
norm = degree(batch, dim_size, dtype=x.dtype).clamp_(min=1)
norm = norm.view(-1, 1)
unbiased_norm = (norm - 1).clamp_(min=1)

mean = scatter(x, batch, dim=0, dim_size=batch_size,
mean = scatter(x, batch, dim=0, dim_size=dim_size,
reduce='sum') / norm

x = x - mean.index_select(0, batch)

var = scatter(x * x, batch, dim=0, dim_size=batch_size,
var = scatter(x * x, batch, dim=0, dim_size=dim_size,
reduce='sum')
unbiased_var = var / unbiased_norm
var = var / norm
@@ -94,9 +100,9 @@ def forward(self, x: Tensor, batch: OptTensor = None) -> Tensor:
) * self.running_var + momentum * unbiased_var.mean(0)
else:
if self.running_mean is not None:
mean = self.running_mean.view(1, -1).expand(batch_size, -1)
mean = self.running_mean.view(1, -1).expand(dim_size, -1)
if self.running_var is not None:
var = self.running_var.view(1, -1).expand(batch_size, -1)
var = self.running_var.view(1, -1).expand(dim_size, -1)

x = x - mean.index_select(0, batch)

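For the running-statistics path above, a small sketch of why dim_size also matters at evaluation time (the shapes, sizes and the affine/track_running_stats settings are assumptions for illustration, not taken from the commit): the stored running mean and variance are expanded to dim_size rows before being applied.

import torch
from torch_geometric.nn import InstanceNorm

norm = InstanceNorm(in_channels=64, affine=True, track_running_stats=True)
norm.eval()  # evaluation mode: the stored running statistics are used

x = torch.randn(6, 64)                    # 6 nodes with 64 features
batch = torch.tensor([0, 0, 0, 1, 1, 1])  # 2 graphs in the batch

# Passing dim_size=2 broadcasts running_mean/running_var to two rows
# without computing int(batch.max()) + 1 from the batch vector.
out = norm(x, batch, dim_size=2)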
16 changes: 11 additions & 5 deletions torch_geometric/nn/norm/layer_norm.py
@@ -1,3 +1,5 @@
from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor
@@ -62,31 +64,35 @@ def reset_parameters(self):
ones(self.weight)
zeros(self.bias)

def forward(self, x: Tensor, batch: OptTensor = None) -> Tensor:
def forward(self, x: Tensor, batch: OptTensor = None,
dim_size: Optional[int] = None) -> Tensor:
r"""
Args:
x (torch.Tensor): The source tensor.
batch (torch.Tensor, optional): The batch vector
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
each element to a specific example. (default: :obj:`None`)
dim_size (int, optional): The number of examples :math:`B` in case
:obj:`batch` is given. (default: :obj:`None`)
"""
if self.mode == 'graph':
if batch is None:
x = x - x.mean()
out = x / (x.std(unbiased=False) + self.eps)

else:
batch_size = int(batch.max()) + 1
if dim_size is None:
dim_size = int(batch.max()) + 1

norm = degree(batch, batch_size, dtype=x.dtype).clamp_(min=1)
norm = degree(batch, dim_size, dtype=x.dtype).clamp_(min=1)
norm = norm.mul_(x.size(-1)).view(-1, 1)

mean = scatter(x, batch, dim=0, dim_size=batch_size,
mean = scatter(x, batch, dim=0, dim_size=dim_size,
reduce='sum').sum(dim=-1, keepdim=True) / norm

x = x - mean.index_select(0, batch)

var = scatter(x * x, batch, dim=0, dim_size=batch_size,
var = scatter(x * x, batch, dim=0, dim_size=dim_size,
reduce='sum').sum(dim=-1, keepdim=True)
var = var / norm

11 changes: 8 additions & 3 deletions torch_geometric/nn/norm/pair_norm.py
@@ -1,3 +1,5 @@
from typing import Optional

import torch
from torch import Tensor

@@ -36,13 +38,16 @@ def __init__(self, scale: float = 1., scale_individually: bool = False,
self.scale_individually = scale_individually
self.eps = eps

def forward(self, x: Tensor, batch: OptTensor = None) -> Tensor:
def forward(self, x: Tensor, batch: OptTensor = None,
dim_size: Optional[int] = None) -> Tensor:
r"""
Args:
x (torch.Tensor): The source tensor.
batch (torch.Tensor, optional): The batch vector
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
each element to a specific example. (default: :obj:`None`)
dim_size (int, optional): The number of examples :math:`B` in case
:obj:`batch` is given. (default: :obj:`None`)
"""
scale = self.scale

@@ -55,13 +60,13 @@ def forward(self, x: Tensor, batch: OptTensor = None) -> Tensor:
return scale * x / (self.eps + x.norm(2, -1, keepdim=True))

else:
mean = scatter(x, batch, dim=0, reduce='mean')
mean = scatter(x, batch, dim=0, dim_size=dim_size, reduce='mean')
x = x - mean.index_select(0, batch)

if not self.scale_individually:
return scale * x / torch.sqrt(self.eps + scatter(
x.pow(2).sum(-1, keepdim=True), batch, dim=0,
reduce='mean').index_select(0, batch))
dim_size=dim_size, reduce='mean').index_select(0, batch))
else:
return scale * x / (self.eps + x.norm(2, -1, keepdim=True))

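Coming back to the fixed-size-dataset case from the commit message, a sketch of how dim_size can be supplied from the data-loading side (the dataset, feature size and batch size are invented for illustration):

import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import LayerNorm

dataset = [Data(x=torch.randn(10, 64)) for _ in range(128)]
loader = DataLoader(dataset, batch_size=32, shuffle=True, drop_last=True)
norm = LayerNorm(in_channels=64, mode='graph')

for data in loader:
    # drop_last=True keeps the number of graphs per mini-batch constant,
    # so it can be passed as dim_size instead of being recomputed each step.
    out = norm(data.x, data.batch, dim_size=32)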
2 changes: 1 addition & 1 deletion torch_geometric/utils/nested.py
@@ -28,7 +28,7 @@ def to_nested_tensor(
(default: :obj:`None`)
ptr (torch.Tensor, optional): Alternative representation of
:obj:`batch` in compressed format. (default: :obj:`None`)
batch_size (int, optional) The batch size :math:`B`.
batch_size (int, optional): The batch size :math:`B`.
(default: :obj:`None`)
"""
if ptr is not None:
