FusedAggregation of simple scatter reductions #6036

Merged · 10 commits · Nov 23, 2022

1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ## [2.2.0] - 2022-MM-DD
 ### Added
+- Added `FusedAggregation` of simple scatter reductions ([#6036](https://github.com/pyg-team/pytorch_geometric/pull/6036))
 - Added `HeteroData` support for `to_captum_model` and added `to_captum_input` ([#5934](https://github.com/pyg-team/pytorch_geometric/pull/5934))
 - Added `HeteroData` support in `RandomNodeLoader` ([#6007](https://github.com/pyg-team/pytorch_geometric/pull/6007))
 - Added bipartite `GraphSAGE` example ([#5834](https://github.com/pyg-team/pytorch_geometric/pull/5834))
67 changes: 67 additions & 0 deletions test/nn/aggr/test_fused.py
@@ -0,0 +1,67 @@
import pytest
import torch

from torch_geometric.nn.aggr.fused import FusedAggregation
from torch_geometric.nn.resolver import aggregation_resolver


@pytest.mark.parametrize('aggrs', [
    ['sum', 'mean', 'min', 'max', 'mul', 'var', 'std'],
    ['sum', 'min', 'max', 'mul', 'var', 'std'],
    ['min', 'max', 'mul', 'var', 'std'],
    ['mean', 'min', 'max', 'mul', 'var', 'std'],
    ['sum', 'min', 'max', 'mul', 'std'],
    ['mean', 'min', 'max', 'mul', 'std'],
    ['min', 'max', 'mul', 'std'],
])
def test_fused_aggregation(aggrs):
    aggrs = [aggregation_resolver(aggr) for aggr in aggrs]

    x = torch.randn(6, 1)
    y = x.clone()
    index = torch.tensor([0, 0, 1, 1, 1, 3])

    x.requires_grad_(True)
    y.requires_grad_(True)

    aggr = FusedAggregation(aggrs)
    assert str(aggr) == 'FusedAggregation()'
    out = aggr(x, index)

    expected = torch.cat([aggr(y, index) for aggr in aggrs], dim=-1)
    assert torch.allclose(out, expected)

    out.mean().backward()
    assert x.grad is not None
    expected.mean().backward()
    assert y.grad is not None
    assert torch.allclose(x.grad, y.grad)


if __name__ == '__main__':
    import time

    x = torch.randn(50000, 64, device='cuda')
    index = torch.randint(1000, (x.size(0), ), device='cuda')

    aggrs = ['sum', 'mean', 'max', 'std']
    aggrs = [aggregation_resolver(aggr) for aggr in aggrs]
    fused_aggr = FusedAggregation(aggrs)

    num_warmups, num_steps = (500, 1000)

    for i in range(num_warmups + num_steps):
        if i == num_warmups:
            torch.cuda.synchronize()
            t = time.perf_counter()
        torch.cat([aggr(x, index, dim_size=1000) for aggr in aggrs], dim=-1)
    torch.cuda.synchronize()
    print(f'Vanilla implementation: {time.perf_counter() - t:.4f} seconds')

    for i in range(num_warmups + num_steps):
        if i == num_warmups:
            torch.cuda.synchronize()
            t = time.perf_counter()
        fused_aggr(x, index, dim_size=1000)
    torch.cuda.synchronize()
    print(f'Fused implementation: {time.perf_counter() - t:.4f} seconds')
2 changes: 1 addition & 1 deletion torch_geometric/nn/aggr/basic.py
@@ -109,7 +109,7 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None,
                 ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                 dim: int = -2) -> Tensor:
         var = self.var_aggr(x, index, ptr, dim_size, dim)
-        return torch.sqrt(var.relu() + 1e-5)
+        return (var.relu() + 1e-5).sqrt()


 class SoftmaxAggregation(Aggregation):
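Aside on the changed line: the `relu()` + `1e-5` pattern is what keeps the standard deviation differentiable. At zero variance the bare `sqrt` backward is infinite, and chaining it through `relu` produces NaN gradients; the epsilon bounds the gradient instead. A minimal standalone sketch of the effect (illustrative, not part of the diff):

import torch

var = torch.tensor([0.0, 1e-12], requires_grad=True)
(var.relu() + 1e-5).sqrt().sum().backward()
print(var.grad)  # finite; bounded by 0.5 / sqrt(1e-5) ~= 158.1

bad = torch.tensor([0.0], requires_grad=True)
bad.relu().sqrt().sum().backward()
print(bad.grad)  # tensor([nan]): sqrt'(0) = inf meets relu'(0) = 0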
293 changes: 293 additions & 0 deletions torch_geometric/nn/aggr/fused.py
@@ -0,0 +1,293 @@
from typing import Any, List, Optional, Tuple, Union

import torch
from torch import Tensor

from torch_geometric.nn import (
    Aggregation,
    MaxAggregation,
    MeanAggregation,
    MinAggregation,
    MulAggregation,
    StdAggregation,
    SumAggregation,
    VarAggregation,
)
from torch_geometric.nn.resolver import aggregation_resolver


class FusedAggregation(Aggregation):
    r"""Helper class to fuse computation of multiple aggregations together.
    Used internally in :class:`~torch_geometric.nn.aggr.MultiAggregation` to
    speed up computation.
    Currently, the following optimizations are performed:

    * :class:`MeanAggregation` will share the output with
      :class:`SumAggregation` in case it is present as well.

    * :class:`VarAggregation` will share the output with either
      :class:`MeanAggregation` or :class:`SumAggregation` in case one of them
      is present as well.

    * :class:`StdAggregation` will share the output with either
      :class:`VarAggregation`, :class:`MeanAggregation` or
      :class:`SumAggregation` in case one of them is present as well.

    In addition, temporary values such as the count per group index or the
    mask for invalid rows are shared as well.

    Benchmarking results on PyTorch 1.12 (summed over 1000 runs):

    +------------------------------+---------+---------+
    | Aggregators                  | Vanilla | Fusion  |
    +==============================+=========+=========+
    | :obj:`[sum, mean]`           | 0.4019s | 0.1666s |
    +------------------------------+---------+---------+
    | :obj:`[sum, mean, min, max]` | 0.7841s | 0.4223s |
    +------------------------------+---------+---------+
    | :obj:`[sum, mean, var]`      | 0.9711s | 0.3614s |
    +------------------------------+---------+---------+
    | :obj:`[sum, mean, var, std]` | 1.5994s | 0.3722s |
    +------------------------------+---------+---------+

    Args:
        aggrs (list): The list of aggregation schemes to use.
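
    Example (an illustrative sketch; the outputs of the individual
    aggregators are concatenated along the last dimension):

        >>> aggr = FusedAggregation(['sum', 'mean', 'min', 'max'])
        >>> x = torch.randn(6, 16)
        >>> index = torch.tensor([0, 0, 1, 1, 1, 3])
        >>> aggr(x, index, dim_size=4).size()
        torch.Size([4, 64])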
"""
    # We can fuse all aggregations together that rely on `scatter` directives.
    FUSABLE_AGGRS = {
        SumAggregation,
        MeanAggregation,
        MinAggregation,
        MaxAggregation,
        MulAggregation,
        VarAggregation,
        StdAggregation,
    }

    # All aggregations that rely on computing the degree of indices.
    DEGREE_BASED_AGGRS = {
        MeanAggregation,
        VarAggregation,
        StdAggregation,
    }

    # All aggregations that require manual masking for invalid rows:
    MASK_REQUIRED_AGGRS = {
        MinAggregation,
        MaxAggregation,
        MulAggregation,
    }

    # Map aggregations to `reduce` options in `scatter` directives.
    REDUCE = {
        SumAggregation: 'sum',
        MeanAggregation: 'sum',
        MinAggregation: 'amin',
        MaxAggregation: 'amax',
        MulAggregation: 'prod',
        VarAggregation: 'pow_sum',
        StdAggregation: 'pow_sum',
    }
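    # Note: `'pow_sum'` is not a native `scatter_reduce_` mode. `forward()`
    # lowers it to a `'sum'` over squared inputs, from which variance is
    # recovered as `E[x^2] - E[x]^2`.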

    def __init__(self, aggrs: List[Union[Aggregation, str]]):
        super().__init__()

        if not isinstance(aggrs, (list, tuple)):
            raise ValueError(f"'aggrs' of '{self.__class__.__name__}' should "
                             f"be a list or tuple (got '{type(aggrs)}').")

        if len(aggrs) == 0:
            raise ValueError(f"'aggrs' of '{self.__class__.__name__}' should "
                             f"not be empty.")

        aggrs = [aggregation_resolver(aggr) for aggr in aggrs]
        self.aggr_cls = [aggr.__class__ for aggr in aggrs]
        self.aggr_index = {cls: i for i, cls in enumerate(self.aggr_cls)}

        for cls in self.aggr_cls:
            if cls not in self.FUSABLE_AGGRS:
                raise ValueError(f"Received aggregation '{cls.__name__}' in "
                                 f"'{self.__class__.__name__}' which is not "
                                 f"fusable")

        # Check whether we need to compute degree information:
        self.need_degree = False
        for cls in self.aggr_cls:
            if cls in self.DEGREE_BASED_AGGRS:
                self.need_degree = True

        # Check whether we need to compute mask information:
        self.requires_mask = False
        for cls in self.aggr_cls:
            if cls in self.MASK_REQUIRED_AGGRS:
                self.requires_mask = True

        # Determine which reduction to use for each aggregator:
        # An entry of `None` means that this operator re-uses intermediate
        # outputs from other aggregators.
        self.reduce_ops: List[Optional[str]] = []
        # Determine which `(Aggregator, index)` to use as intermediate output:
        self.lookup_ops: List[Optional[Tuple[Any, int]]] = []

        for cls in self.aggr_cls:
            if cls == MeanAggregation:
                # Directly use output of `SumAggregation`:
                if SumAggregation in self.aggr_index:
                    self.reduce_ops.append(None)
                    self.lookup_ops.append(
                        (SumAggregation, self.aggr_index[SumAggregation]))
                else:
                    self.reduce_ops.append(self.REDUCE[cls])
                    self.lookup_ops.append(None)

            elif cls == VarAggregation:
                if MeanAggregation in self.aggr_index:
                    self.reduce_ops.append(self.REDUCE[cls])
                    self.lookup_ops.append(
                        (MeanAggregation, self.aggr_index[MeanAggregation]))
                elif SumAggregation in self.aggr_index:
                    self.reduce_ops.append(self.REDUCE[cls])
                    self.lookup_ops.append(
                        (SumAggregation, self.aggr_index[SumAggregation]))
                else:
                    self.reduce_ops.append(self.REDUCE[cls])
                    self.lookup_ops.append(None)

            elif cls == StdAggregation:
                # Directly use output of `VarAggregation`:
                if VarAggregation in self.aggr_index:
                    self.reduce_ops.append(None)
                    self.lookup_ops.append(
                        (VarAggregation, self.aggr_index[VarAggregation]))
                elif MeanAggregation in self.aggr_index:
                    self.reduce_ops.append(self.REDUCE[cls])
                    self.lookup_ops.append(
                        (MeanAggregation, self.aggr_index[MeanAggregation]))
                elif SumAggregation in self.aggr_index:
                    self.reduce_ops.append(self.REDUCE[cls])
                    self.lookup_ops.append(
                        (SumAggregation, self.aggr_index[SumAggregation]))
                else:
                    self.reduce_ops.append(self.REDUCE[cls])
                    self.lookup_ops.append(None)

            else:
                self.reduce_ops.append(self.REDUCE[cls])
                self.lookup_ops.append(None)

    def forward(self, x: Tensor, index: Optional[Tensor] = None,
                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                dim: int = -2) -> Tensor:

        # Assert two-dimensional input for now to simplify computation:
        # TODO refactor this to support any dimension.
        self.assert_index_present(index)
        self.assert_two_dimensional_input(x, dim)

        # Per-group counts and the empty-group mask are computed once here
        # and shared by every aggregator that needs them:
        if self.need_degree:
            count = x.new_zeros(dim_size)
            count.scatter_add_(0, index, x.new_ones(x.size(0)))
            if self.requires_mask:
                mask = count == 0
            count = count.clamp_(min=1).view(-1, 1)

        elif self.requires_mask:  # Mask to set non-existing indices to zero:
            mask = x.new_ones(dim_size, dtype=torch.bool)
            mask[index] = False

        num_feats = x.size(-1)
        index = index.view(-1, 1).expand(-1, num_feats)
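        # `scatter_reduce_` needs `index` to have the same dimensionality as
        # `src`, so it is broadcast across the feature dimension above.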

        #######################################################################

        outs: List[Optional[Tensor]] = []

        # Iterate over all reduction ops to compute first results:
        for i, reduce in enumerate(self.reduce_ops):
            if reduce is None:
                outs.append(None)
                continue

            src = x * x if reduce == 'pow_sum' else x
            reduce = 'sum' if reduce == 'pow_sum' else reduce

            # Start from each reduction's identity element; rows that
            # received no inputs are reset to zero below:
            fill_value = 0.0
            if reduce == 'amin':
                fill_value = float('inf')
            elif reduce == 'amax':
                fill_value = float('-inf')
            elif reduce == 'prod':
                fill_value = 1.0

            # `include_self=True` + manual masking leads to faster runtime:
            out = x.new_full((dim_size, num_feats), fill_value)
            out.scatter_reduce_(0, index, src, reduce, include_self=True)
            if fill_value != 0.0:
                out = out.masked_fill(mask.view(-1, 1), 0.0)
            outs.append(out)

        #######################################################################

        # Compute `MeanAggregation` first to be able to re-use it:
        i = self.aggr_index.get(MeanAggregation)
        if i is not None:
            if self.lookup_ops[i] is None:
                sum_ = outs[i]
            else:
                tmp_aggr, j = self.lookup_ops[i]
                assert tmp_aggr == SumAggregation
                sum_ = outs[j]

            outs[i] = sum_ / count

        # Compute `VarAggregation` second to be able to re-use it:
        i = self.aggr_index.get(VarAggregation)
        if i is not None:
            if self.lookup_ops[i] is None:
                mean = x.new_zeros(dim_size, num_feats)
                mean.scatter_reduce_(0, index, x, 'sum', include_self=True)
                mean = mean / count
            else:
                tmp_aggr, j = self.lookup_ops[i]
                if tmp_aggr == SumAggregation:
                    mean = outs[j] / count
                elif tmp_aggr == MeanAggregation:
                    mean = outs[j]
                else:
                    raise NotImplementedError

            # Biased variance via the identity `Var[x] = E[x^2] - E[x]^2`:
            pow_sum = outs[i]
            outs[i] = (pow_sum / count) - (mean * mean)

        # Compute `StdAggregation` last:
        i = self.aggr_index.get(StdAggregation)
        if i is not None:
            var = None
            if self.lookup_ops[i] is None:
                pow_sum = outs[i]
                mean = x.new_zeros(dim_size, num_feats)
                mean.scatter_reduce_(0, index, x, 'sum', include_self=True)
                mean = mean / count
            else:
                tmp_aggr, j = self.lookup_ops[i]
                if tmp_aggr == VarAggregation:
                    var = outs[j]
                elif tmp_aggr == SumAggregation:
                    pow_sum = outs[i]
                    mean = outs[j] / count
                elif tmp_aggr == MeanAggregation:
                    pow_sum = outs[i]
                    mean = outs[j]
                else:
                    raise NotImplementedError

            if var is None:
                var = (pow_sum / count) - (mean * mean)

            # `relu()` clamps small negatives caused by floating-point
            # cancellation, and the epsilon keeps the `sqrt` gradient finite
            # at zero (mirroring `StdAggregation` in `basic.py`):
            outs[i] = (var.relu() + 1e-5).sqrt()

        #######################################################################

        out = torch.cat(outs, dim=-1)

        return out
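For intuition on the shared intermediates above: the fused path obtains variance from a single extra scatter of squared inputs, via Var[x] = E[x^2] - E[x]^2, instead of a second pass over centered values. A minimal standalone sketch of that computation with plain `scatter_add_` calls (illustrative only; the variable names are made up and not part of the diff):

import torch

x = torch.randn(6, 8)
index = torch.tensor([0, 0, 1, 1, 1, 3])
dim_size, num_feats = 4, x.size(-1)

# Per-group counts, clamped so that empty groups divide safely:
count = x.new_zeros(dim_size)
count.scatter_add_(0, index, x.new_ones(x.size(0)))
count = count.clamp_(min=1).view(-1, 1)

idx = index.view(-1, 1).expand(-1, num_feats)
sum_ = x.new_zeros(dim_size, num_feats).scatter_add_(0, idx, x)
pow_sum = x.new_zeros(dim_size, num_feats).scatter_add_(0, idx, x * x)

mean = sum_ / count
var = pow_sum / count - mean * mean  # one extra scatter instead of a second pass
std = (var.relu() + 1e-5).sqrt()

# Agrees (up to floating-point error) with the two-pass reference:
ref = x.new_zeros(dim_size, num_feats)
ref.scatter_add_(0, idx, (x - mean[index]) ** 2)
assert torch.allclose(var, ref / count, atol=1e-6)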