From 1e5cc7077c7de0b5089d988abbdd67741c185ead Mon Sep 17 00:00:00 2001
From: Kamil Andrzejewski
Date: Fri, 7 Apr 2023 13:21:24 +0100
Subject: [PATCH 1/2] Add dim_size arg for norm layers

It can speed up runtime because:
1) We do not need to go through the batch dimension and look for the
   maximum value.
2) We do not have to read a tensor value that is placed on the device.
Besides, dim_size can be used if a user works with fixed-size datasets.
---
 CHANGELOG.md                               |  1 +
 torch_geometric/nn/norm/graph_norm.py      | 13 +++++++++----
 torch_geometric/nn/norm/graph_size_norm.py | 10 ++++++++--
 torch_geometric/nn/norm/instance_norm.py   | 21 +++++++++++++--------
 torch_geometric/nn/norm/layer_norm.py      | 16 +++++++++++-----
 torch_geometric/nn/norm/pair_norm.py       | 11 ++++++++---
 torch_geometric/utils/nested.py            |  2 +-
 7 files changed, 51 insertions(+), 23 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 439ba66cb419..9ef5ada2f02d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
+- Added optional argument `dim_size` for `LayerNorm`, `GraphNorm`, `InstanceNorm`, `GraphSizeNorm` and `PairNorm` ([#7135](https://github.com/pyg-team/pytorch_geometric/pull/7135))
 - Added support for `Data.num_edges` for native `torch.sparse.Tensor` adjacency matrices ([#7104](https://github.com/pyg-team/pytorch_geometric/pull/7104))
 - Fixed crash of heterogeneous data loaders if node or edge types are missing ([#7060](https://github.com/pyg-team/pytorch_geometric/pull/7060), [#7087](https://github.com/pyg-team/pytorch_geometric/pull/7087))
 - Accelerated attention-based `MultiAggregation` ([#7077](https://github.com/pyg-team/pytorch_geometric/pull/7077))
diff --git a/torch_geometric/nn/norm/graph_norm.py b/torch_geometric/nn/norm/graph_norm.py
index c91bc01b8160..2ae1c40a593e 100644
--- a/torch_geometric/nn/norm/graph_norm.py
+++ b/torch_geometric/nn/norm/graph_norm.py
@@ -4,6 +4,7 @@
 from torch import Tensor
 
 from torch_geometric.nn.inits import ones, zeros
+from torch_geometric.typing import OptTensor
 from torch_geometric.utils import scatter
 
 
@@ -44,22 +45,26 @@ def reset_parameters(self):
         zeros(self.bias)
         ones(self.mean_scale)
 
-    def forward(self, x: Tensor, batch: Optional[Tensor] = None) -> Tensor:
+    def forward(self, x: Tensor, batch: OptTensor = None,
+                dim_size: Optional[int] = None) -> Tensor:
         r"""
         Args:
             x (torch.Tensor): The source tensor.
             batch (torch.Tensor, optional): The batch vector
                 :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
                 each element to a specific example. (default: :obj:`None`)
+            dim_size (int, optional): The number of examples :math:`B` in case
+                :obj:`batch` is given. (default: :obj:`None`)
         """
         if batch is None:
             batch = x.new_zeros(x.size(0), dtype=torch.long)
 
-        batch_size = int(batch.max()) + 1
+        if dim_size is None:
+            dim_size = int(batch.max()) + 1
 
-        mean = scatter(x, batch, 0, batch_size, reduce='mean')
+        mean = scatter(x, batch, 0, dim_size, reduce='mean')
         out = x - mean.index_select(0, batch) * self.mean_scale
-        var = scatter(out.pow(2), batch, 0, batch_size, reduce='mean')
+        var = scatter(out.pow(2), batch, 0, dim_size, reduce='mean')
         std = (var + self.eps).sqrt().index_select(0, batch)
         return self.weight * out / std + self.bias
diff --git a/torch_geometric/nn/norm/graph_size_norm.py b/torch_geometric/nn/norm/graph_size_norm.py
index 6bf8899e5f22..5c8cb102e300 100644
--- a/torch_geometric/nn/norm/graph_size_norm.py
+++ b/torch_geometric/nn/norm/graph_size_norm.py
@@ -1,3 +1,5 @@
+from typing import Optional
+
 import torch
 import torch.nn as nn
 from torch import Tensor
@@ -18,18 +20,22 @@ class GraphSizeNorm(nn.Module):
     def __init__(self):
         super().__init__()
 
-    def forward(self, x: Tensor, batch: OptTensor = None) -> Tensor:
+    def forward(self, x: Tensor, batch: OptTensor = None,
+                dim_size: Optional[int] = None) -> Tensor:
         r"""
         Args:
             x (torch.Tensor): The source tensor.
             batch (torch.Tensor, optional): The batch vector
                 :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
                 each element to a specific example. (default: :obj:`None`)
+            dim_size (int, optional): The number of examples :math:`B` in case
+                :obj:`batch` is given. (default: :obj:`None`)
         """
         if batch is None:
             batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
 
-        inv_sqrt_deg = degree(batch, dtype=x.dtype).pow(-0.5)
+        inv_sqrt_deg = degree(batch, num_nodes=dim_size,
+                              dtype=x.dtype).pow(-0.5)
         return x * inv_sqrt_deg.index_select(0, batch).view(-1, 1)
 
     def __repr__(self) -> str:
diff --git a/torch_geometric/nn/norm/instance_norm.py b/torch_geometric/nn/norm/instance_norm.py
index 8dde6bcd8917..ac3bfe12a366 100644
--- a/torch_geometric/nn/norm/instance_norm.py
+++ b/torch_geometric/nn/norm/instance_norm.py
@@ -1,3 +1,5 @@
+from typing import Optional
+
 import torch.nn.functional as F
 from torch import Tensor
 from torch.nn.modules.instancenorm import _InstanceNorm
@@ -50,13 +52,16 @@ def reset_parameters(self):
         r"""Resets all learnable parameters of the module."""
         super().reset_parameters()
 
-    def forward(self, x: Tensor, batch: OptTensor = None) -> Tensor:
+    def forward(self, x: Tensor, batch: OptTensor = None,
+                dim_size: Optional[int] = None) -> Tensor:
         r"""
         Args:
             x (torch.Tensor): The source tensor.
             batch (torch.Tensor, optional): The batch vector
                 :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
                 each element to a specific example. (default: :obj:`None`)
+            dim_size (int, optional): The number of examples :math:`B` in case
+                :obj:`batch` is given. (default: :obj:`None`)
         """
         if batch is None:
             out = F.instance_norm(
@@ -65,22 +70,22 @@ def forward(self, x: Tensor, batch: OptTensor = None) -> Tensor:
                 or not self.track_running_stats, self.momentum, self.eps)
             return out.squeeze(0).t()
 
-        batch_size = int(batch.max()) + 1
+        if dim_size is None:
+            dim_size = int(batch.max()) + 1
 
         mean = var = unbiased_var = x  # Dummies.
 
         if self.training or not self.track_running_stats:
-            norm = degree(batch, batch_size, dtype=x.dtype).clamp_(min=1)
+            norm = degree(batch, dim_size, dtype=x.dtype).clamp_(min=1)
             norm = norm.view(-1, 1)
             unbiased_norm = (norm - 1).clamp_(min=1)
 
-            mean = scatter(x, batch, dim=0, dim_size=batch_size,
+            mean = scatter(x, batch, dim=0, dim_size=dim_size,
                            reduce='sum') / norm
 
             x = x - mean.index_select(0, batch)
 
-            var = scatter(x * x, batch, dim=0, dim_size=batch_size,
-                          reduce='sum')
+            var = scatter(x * x, batch, dim=0, dim_size=dim_size, reduce='sum')
             unbiased_var = var / unbiased_norm
             var = var / norm
 
@@ -94,9 +99,9 @@ def forward(self, x: Tensor, batch: OptTensor = None) -> Tensor:
             ) * self.running_var + momentum * unbiased_var.mean(0)
         else:
             if self.running_mean is not None:
-                mean = self.running_mean.view(1, -1).expand(batch_size, -1)
+                mean = self.running_mean.view(1, -1).expand(dim_size, -1)
             if self.running_var is not None:
-                var = self.running_var.view(1, -1).expand(batch_size, -1)
+                var = self.running_var.view(1, -1).expand(dim_size, -1)
 
         x = x - mean.index_select(0, batch)
 
diff --git a/torch_geometric/nn/norm/layer_norm.py b/torch_geometric/nn/norm/layer_norm.py
index 6dddcc6029d7..863cee65ffd9 100644
--- a/torch_geometric/nn/norm/layer_norm.py
+++ b/torch_geometric/nn/norm/layer_norm.py
@@ -1,3 +1,5 @@
+from typing import Optional
+
 import torch
 import torch.nn.functional as F
 from torch import Tensor
@@ -62,13 +64,16 @@ def reset_parameters(self):
         ones(self.weight)
         zeros(self.bias)
 
-    def forward(self, x: Tensor, batch: OptTensor = None) -> Tensor:
+    def forward(self, x: Tensor, batch: OptTensor = None,
+                dim_size: Optional[int] = None) -> Tensor:
         r"""
         Args:
             x (torch.Tensor): The source tensor.
             batch (torch.Tensor, optional): The batch vector
                 :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
                 each element to a specific example. (default: :obj:`None`)
+            dim_size (int, optional): The number of examples :math:`B` in case
+                :obj:`batch` is given. (default: :obj:`None`)
         """
         if self.mode == 'graph':
             if batch is None:
@@ -76,17 +81,18 @@ def forward(self, x: Tensor, batch: OptTensor = None) -> Tensor:
 
                 out = x / (x.std(unbiased=False) + self.eps)
             else:
-                batch_size = int(batch.max()) + 1
+                if dim_size is None:
+                    dim_size = int(batch.max()) + 1
 
-                norm = degree(batch, batch_size, dtype=x.dtype).clamp_(min=1)
+                norm = degree(batch, dim_size, dtype=x.dtype).clamp_(min=1)
                 norm = norm.mul_(x.size(-1)).view(-1, 1)
 
-                mean = scatter(x, batch, dim=0, dim_size=batch_size,
+                mean = scatter(x, batch, dim=0, dim_size=dim_size,
                                reduce='sum').sum(dim=-1, keepdim=True) / norm
 
                 x = x - mean.index_select(0, batch)
 
-                var = scatter(x * x, batch, dim=0, dim_size=batch_size,
+                var = scatter(x * x, batch, dim=0, dim_size=dim_size,
                                reduce='sum').sum(dim=-1, keepdim=True)
                 var = var / norm
 
diff --git a/torch_geometric/nn/norm/pair_norm.py b/torch_geometric/nn/norm/pair_norm.py
index 49a32cc3b49d..5f5874eb6c98 100644
--- a/torch_geometric/nn/norm/pair_norm.py
+++ b/torch_geometric/nn/norm/pair_norm.py
@@ -1,3 +1,5 @@
+from typing import Optional
+
 import torch
 from torch import Tensor
 
@@ -36,13 +38,16 @@ def __init__(self, scale: float = 1., scale_individually: bool = False,
         self.scale_individually = scale_individually
         self.eps = eps
 
-    def forward(self, x: Tensor, batch: OptTensor = None) -> Tensor:
+    def forward(self, x: Tensor, batch: OptTensor = None,
+                dim_size: Optional[int] = None) -> Tensor:
         r"""
         Args:
             x (torch.Tensor): The source tensor.
             batch (torch.Tensor, optional): The batch vector
                 :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
                 each element to a specific example. (default: :obj:`None`)
+            dim_size (int, optional): The number of examples :math:`B` in case
+                :obj:`batch` is given. (default: :obj:`None`)
         """
         scale = self.scale
 
@@ -55,13 +60,13 @@ def forward(self, x: Tensor, batch: OptTensor = None) -> Tensor:
             return scale * x / (self.eps + x.norm(2, -1, keepdim=True))
 
         else:
-            mean = scatter(x, batch, dim=0, reduce='mean')
+            mean = scatter(x, batch, dim=0, dim_size=dim_size, reduce='mean')
             x = x - mean.index_select(0, batch)
 
             if not self.scale_individually:
                 return scale * x / torch.sqrt(self.eps + scatter(
                     x.pow(2).sum(-1, keepdim=True), batch, dim=0,
-                    reduce='mean').index_select(0, batch))
+                    dim_size=dim_size, reduce='mean').index_select(0, batch))
             else:
                 return scale * x / (self.eps + x.norm(2, -1, keepdim=True))
 
diff --git a/torch_geometric/utils/nested.py b/torch_geometric/utils/nested.py
index 03c3798b70c4..927e0d22e4e2 100644
--- a/torch_geometric/utils/nested.py
+++ b/torch_geometric/utils/nested.py
@@ -28,7 +28,7 @@ def to_nested_tensor(
             (default: :obj:`None`)
         ptr (torch.Tensor, optional): Alternative representation of
             :obj:`batch` in compressed format. (default: :obj:`None`)
-        batch_size (int, optional) The batch size :math:`B`.
+        batch_size (int, optional): The batch size :math:`B`.
             (default: :obj:`None`)
     """
     if ptr is not None:

From 2373600d70bbd1ea0dac46fdbe355cc3bf600208 Mon Sep 17 00:00:00 2001
From: rusty1s
Date: Sun, 9 Apr 2023 22:13:32 +0000
Subject: [PATCH 2/2] update

---
 CHANGELOG.md                               |  2 +-
 torch_geometric/nn/norm/graph_norm.py      | 15 ++++++++-------
 torch_geometric/nn/norm/graph_size_norm.py | 10 +++++-----
 torch_geometric/nn/norm/instance_norm.py   | 21 +++++++++++----------
 torch_geometric/nn/norm/layer_norm.py      | 16 ++++++++--------
 torch_geometric/nn/norm/pair_norm.py       | 10 +++++-----
 6 files changed, 38 insertions(+), 36 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9ef5ada2f02d..8e72a4a37d81 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -16,7 +16,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
-- Added optional argument `dim_size` for `LayerNorm`, `GraphNorm`, `InstanceNorm`, `GraphSizeNorm` and `PairNorm` ([#7135](https://github.com/pyg-team/pytorch_geometric/pull/7135))
+- Added an optional `batch_size` argument to `LayerNorm`, `GraphNorm`, `InstanceNorm`, `GraphSizeNorm` and `PairNorm` ([#7135](https://github.com/pyg-team/pytorch_geometric/pull/7135))
 - Added support for `Data.num_edges` for native `torch.sparse.Tensor` adjacency matrices ([#7104](https://github.com/pyg-team/pytorch_geometric/pull/7104))
 - Fixed crash of heterogeneous data loaders if node or edge types are missing ([#7060](https://github.com/pyg-team/pytorch_geometric/pull/7060), [#7087](https://github.com/pyg-team/pytorch_geometric/pull/7087))
 - Accelerated attention-based `MultiAggregation` ([#7077](https://github.com/pyg-team/pytorch_geometric/pull/7077))
diff --git a/torch_geometric/nn/norm/graph_norm.py b/torch_geometric/nn/norm/graph_norm.py
index 2ae1c40a593e..02d37a860911 100644
--- a/torch_geometric/nn/norm/graph_norm.py
+++ b/torch_geometric/nn/norm/graph_norm.py
@@ -46,25 +46,26 @@ def reset_parameters(self):
         ones(self.mean_scale)
 
     def forward(self, x: Tensor, batch: OptTensor = None,
-                dim_size: Optional[int] = None) -> Tensor:
+                batch_size: Optional[int] = None) -> Tensor:
         r"""
         Args:
             x (torch.Tensor): The source tensor.
             batch (torch.Tensor, optional): The batch vector
                 :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
                 each element to a specific example. (default: :obj:`None`)
-            dim_size (int, optional): The number of examples :math:`B` in case
-                :obj:`batch` is given. (default: :obj:`None`)
+            batch_size (int, optional): The number of examples :math:`B`.
+                Automatically calculated if not given. (default: :obj:`None`)
         """
         if batch is None:
             batch = x.new_zeros(x.size(0), dtype=torch.long)
+            batch_size = 1
 
-        if dim_size is None:
-            dim_size = int(batch.max()) + 1
+        if batch_size is None:
+            batch_size = int(batch.max()) + 1
 
-        mean = scatter(x, batch, 0, dim_size, reduce='mean')
+        mean = scatter(x, batch, 0, batch_size, reduce='mean')
         out = x - mean.index_select(0, batch) * self.mean_scale
-        var = scatter(out.pow(2), batch, 0, dim_size, reduce='mean')
+        var = scatter(out.pow(2), batch, 0, batch_size, reduce='mean')
         std = (var + self.eps).sqrt().index_select(0, batch)
         return self.weight * out / std + self.bias
diff --git a/torch_geometric/nn/norm/graph_size_norm.py b/torch_geometric/nn/norm/graph_size_norm.py
index 5c8cb102e300..243147d86912 100644
--- a/torch_geometric/nn/norm/graph_size_norm.py
+++ b/torch_geometric/nn/norm/graph_size_norm.py
@@ -21,21 +21,21 @@ def __init__(self):
         super().__init__()
 
     def forward(self, x: Tensor, batch: OptTensor = None,
-                dim_size: Optional[int] = None) -> Tensor:
+                batch_size: Optional[int] = None) -> Tensor:
         r"""
         Args:
             x (torch.Tensor): The source tensor.
             batch (torch.Tensor, optional): The batch vector
                 :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
                 each element to a specific example. (default: :obj:`None`)
-            dim_size (int, optional): The number of examples :math:`B` in case
-                :obj:`batch` is given. (default: :obj:`None`)
+            batch_size (int, optional): The number of examples :math:`B`.
+                Automatically calculated if not given. (default: :obj:`None`)
         """
         if batch is None:
             batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
+            batch_size = 1
 
-        inv_sqrt_deg = degree(batch, num_nodes=dim_size,
-                              dtype=x.dtype).pow(-0.5)
+        inv_sqrt_deg = degree(batch, batch_size, dtype=x.dtype).pow(-0.5)
         return x * inv_sqrt_deg.index_select(0, batch).view(-1, 1)
 
     def __repr__(self) -> str:
diff --git a/torch_geometric/nn/norm/instance_norm.py b/torch_geometric/nn/norm/instance_norm.py
index ac3bfe12a366..00a892ec820a 100644
--- a/torch_geometric/nn/norm/instance_norm.py
+++ b/torch_geometric/nn/norm/instance_norm.py
@@ -53,15 +53,15 @@ def reset_parameters(self):
         super().reset_parameters()
 
     def forward(self, x: Tensor, batch: OptTensor = None,
-                dim_size: Optional[int] = None) -> Tensor:
+                batch_size: Optional[int] = None) -> Tensor:
         r"""
         Args:
             x (torch.Tensor): The source tensor.
             batch (torch.Tensor, optional): The batch vector
                 :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
                 each element to a specific example. (default: :obj:`None`)
-            dim_size (int, optional): The number of examples :math:`B` in case
-                :obj:`batch` is given. (default: :obj:`None`)
+            batch_size (int, optional): The number of examples :math:`B`.
+                Automatically calculated if not given. (default: :obj:`None`)
         """
         if batch is None:
             out = F.instance_norm(
@@ -70,22 +70,23 @@ def forward(self, x: Tensor, batch: OptTensor = None,
                 or not self.track_running_stats, self.momentum, self.eps)
             return out.squeeze(0).t()
 
-        if dim_size is None:
-            dim_size = int(batch.max()) + 1
+        if batch_size is None:
+            batch_size = int(batch.max()) + 1
 
         mean = var = unbiased_var = x  # Dummies.
 
         if self.training or not self.track_running_stats:
-            norm = degree(batch, dim_size, dtype=x.dtype).clamp_(min=1)
+            norm = degree(batch, batch_size, dtype=x.dtype).clamp_(min=1)
             norm = norm.view(-1, 1)
             unbiased_norm = (norm - 1).clamp_(min=1)
 
-            mean = scatter(x, batch, dim=0, dim_size=dim_size,
+            mean = scatter(x, batch, dim=0, dim_size=batch_size,
                            reduce='sum') / norm
 
             x = x - mean.index_select(0, batch)
 
-            var = scatter(x * x, batch, dim=0, dim_size=dim_size, reduce='sum')
+            var = scatter(x * x, batch, dim=0, dim_size=batch_size,
+                          reduce='sum')
             unbiased_var = var / unbiased_norm
             var = var / norm
 
@@ -99,9 +100,9 @@ def forward(self, x: Tensor, batch: OptTensor = None,
             ) * self.running_var + momentum * unbiased_var.mean(0)
         else:
             if self.running_mean is not None:
-                mean = self.running_mean.view(1, -1).expand(dim_size, -1)
+                mean = self.running_mean.view(1, -1).expand(batch_size, -1)
             if self.running_var is not None:
-                var = self.running_var.view(1, -1).expand(dim_size, -1)
+                var = self.running_var.view(1, -1).expand(batch_size, -1)
 
         x = x - mean.index_select(0, batch)
 
diff --git a/torch_geometric/nn/norm/layer_norm.py b/torch_geometric/nn/norm/layer_norm.py
index 863cee65ffd9..c40d3cc9cfd9 100644
--- a/torch_geometric/nn/norm/layer_norm.py
+++ b/torch_geometric/nn/norm/layer_norm.py
@@ -65,15 +65,15 @@ def reset_parameters(self):
         zeros(self.bias)
 
     def forward(self, x: Tensor, batch: OptTensor = None,
-                dim_size: Optional[int] = None) -> Tensor:
+                batch_size: Optional[int] = None) -> Tensor:
         r"""
         Args:
             x (torch.Tensor): The source tensor.
             batch (torch.Tensor, optional): The batch vector
                 :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
                 each element to a specific example. (default: :obj:`None`)
-            dim_size (int, optional): The number of examples :math:`B` in case
-                :obj:`batch` is given. (default: :obj:`None`)
+            batch_size (int, optional): The number of examples :math:`B`.
+                Automatically calculated if not given. (default: :obj:`None`)
         """
         if self.mode == 'graph':
             if batch is None:
@@ -81,18 +81,18 @@ def forward(self, x: Tensor, batch: OptTensor = None,
 
                 out = x / (x.std(unbiased=False) + self.eps)
             else:
-                if dim_size is None:
-                    dim_size = int(batch.max()) + 1
+                if batch_size is None:
+                    batch_size = int(batch.max()) + 1
 
-                norm = degree(batch, dim_size, dtype=x.dtype).clamp_(min=1)
+                norm = degree(batch, batch_size, dtype=x.dtype).clamp_(min=1)
                 norm = norm.mul_(x.size(-1)).view(-1, 1)
 
-                mean = scatter(x, batch, dim=0, dim_size=dim_size,
+                mean = scatter(x, batch, dim=0, dim_size=batch_size,
                                reduce='sum').sum(dim=-1, keepdim=True) / norm
 
                 x = x - mean.index_select(0, batch)
 
-                var = scatter(x * x, batch, dim=0, dim_size=dim_size,
+                var = scatter(x * x, batch, dim=0, dim_size=batch_size,
                                reduce='sum').sum(dim=-1, keepdim=True)
                 var = var / norm
 
diff --git a/torch_geometric/nn/norm/pair_norm.py b/torch_geometric/nn/norm/pair_norm.py
index 5f5874eb6c98..c548ddc6c242 100644
--- a/torch_geometric/nn/norm/pair_norm.py
+++ b/torch_geometric/nn/norm/pair_norm.py
@@ -39,15 +39,15 @@ def __init__(self, scale: float = 1., scale_individually: bool = False,
         self.eps = eps
 
     def forward(self, x: Tensor, batch: OptTensor = None,
-                dim_size: Optional[int] = None) -> Tensor:
+                batch_size: Optional[int] = None) -> Tensor:
         r"""
         Args:
             x (torch.Tensor): The source tensor.
             batch (torch.Tensor, optional): The batch vector
                 :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
                 each element to a specific example. (default: :obj:`None`)
-            dim_size (int, optional): The number of examples :math:`B` in case
-                :obj:`batch` is given. (default: :obj:`None`)
+            batch_size (int, optional): The number of examples :math:`B`.
+                Automatically calculated if not given. (default: :obj:`None`)
         """
         scale = self.scale
 
@@ -60,13 +60,13 @@ def forward(self, x: Tensor, batch: OptTensor = None,
             return scale * x / (self.eps + x.norm(2, -1, keepdim=True))
 
         else:
-            mean = scatter(x, batch, dim=0, dim_size=dim_size, reduce='mean')
+            mean = scatter(x, batch, dim=0, dim_size=batch_size, reduce='mean')
             x = x - mean.index_select(0, batch)
 
             if not self.scale_individually:
                 return scale * x / torch.sqrt(self.eps + scatter(
                     x.pow(2).sum(-1, keepdim=True), batch, dim=0,
-                    dim_size=dim_size, reduce='mean').index_select(0, batch))
+                    dim_size=batch_size, reduce='mean').index_select(0, batch))
             else:
                 return scale * x / (self.eps + x.norm(2, -1, keepdim=True))
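
For reference, a minimal usage sketch of the new argument (not part of the
patch itself; it assumes a PyG build with both commits applied, and the
tensor shapes below are illustrative):

    import torch
    from torch_geometric.nn import GraphNorm

    norm = GraphNorm(in_channels=64)

    x = torch.randn(6, 64)                    # 6 nodes with 64 features each
    batch = torch.tensor([0, 0, 0, 1, 1, 2])  # 3 graphs in the mini-batch

    # Without `batch_size`, the layer falls back to `int(batch.max()) + 1`,
    # which reads a scalar back from the device and synchronizes with it.
    out = norm(x, batch)

    # Passing `batch_size` skips that device read entirely - the speedup
    # described in the first commit message - and is convenient for
    # fixed-size datasets where B is known ahead of time.
    out = norm(x, batch, batch_size=3)

    assert out.size() == (6, 64)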