Skip to content

Commit

Permalink
Drop torch-scatter dependency (part 3) (#6399)
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored Jan 11, 2023
1 parent f710570 commit 6065e10
Show file tree
Hide file tree
Showing 14 changed files with 57 additions and 61 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Optimized `utils.softmax` implementation ([#6113](https://github.com/pyg-team/pytorch_geometric/pull/6113), [#6155](https://github.com/pyg-team/pytorch_geometric/pull/6155))
- Optimized `topk` implementation for large enough graphs ([#6123](https://github.com/pyg-team/pytorch_geometric/pull/6123))
### Removed
- Removed most of the `torch-scatter` dependencies ([#6394](https://github.com/pyg-team/pytorch_geometric/pull/6394),[#6395](https://github.com/pyg-team/pytorch_geometric/pull/6395))
- Removed most of the `torch-scatter` dependencies ([#6394](https://github.com/pyg-team/pytorch_geometric/pull/6394), [#6395](https://github.com/pyg-team/pytorch_geometric/pull/6395), [#6399](https://github.com/pyg-team/pytorch_geometric/pull/6399))
- Removed the deprecated classes `GNNExplainer` and `Explainer` from `nn.models` ([#6382](https://github.com/pyg-team/pytorch_geometric/pull/6382))
- Removed `target_index` argument in the `Explainer` interface ([#6270](https://github.com/pyg-team/pytorch_geometric/pull/6270))
- Removed `Aggregation.set_validate_args` option ([#6175](https://github.com/pyg-team/pytorch_geometric/pull/6175))
Expand Down
18 changes: 11 additions & 7 deletions test/transforms/test_local_degree_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,18 @@ def test_target_indegree():
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
x = torch.Tensor([[1], [1], [1], [1]]) # One isolated node.

expected = [[1, 2, 2, 2, 0], [2, 1, 1, 1, 0], [1, 2, 2, 2, 0],
[0, 0, 0, 0, 0]]
expected = torch.tensor([
[1, 2, 2, 2, 0],
[2, 1, 1, 1, 0],
[1, 2, 2, 2, 0],
[0, 0, 0, 0, 0],
], dtype=torch.float)

data = Data(edge_index=edge_index, pos=x)
data = Data(edge_index=edge_index, num_nodes=x.size(0))
data = LocalDegreeProfile()(data)
assert data.x.tolist() == expected
assert torch.allclose(data.x, expected, atol=1e-2)

data.x = x
data = Data(edge_index=edge_index, x=x)
data = LocalDegreeProfile()(data)
assert data.x[:, 1:].tolist() == expected
assert data.x[:, 0].tolist() == [1, 1, 1, 1]
assert torch.allclose(data.x[:, :1], x)
assert torch.allclose(data.x[:, 1:], expected, atol=1e-2)
10 changes: 5 additions & 5 deletions torch_geometric/transforms/gdc.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import torch
from scipy.linalg import expm
from torch import Tensor
from torch_scatter import scatter_add

from torch_geometric.data import Data
from torch_geometric.data.datapipes import functional_transform
Expand All @@ -13,6 +12,7 @@
add_self_loops,
coalesce,
is_undirected,
scatter,
to_dense_adj,
)

Expand Down Expand Up @@ -166,19 +166,19 @@ def transition_matrix(
"""
if normalization == 'sym':
row, col = edge_index
deg = scatter_add(edge_weight, col, dim=0, dim_size=num_nodes)
deg = scatter(edge_weight, col, 0, num_nodes, reduce='sum')
deg_inv_sqrt = deg.pow(-0.5)
deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
edge_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
elif normalization == 'col':
_, col = edge_index
deg = scatter_add(edge_weight, col, dim=0, dim_size=num_nodes)
deg = scatter(edge_weight, col, 0, num_nodes, reduce='sum')
deg_inv = 1. / deg
deg_inv[deg_inv == float('inf')] = 0
edge_weight = edge_weight * deg_inv[col]
elif normalization == 'row':
row, _ = edge_index
deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
deg = scatter(edge_weight, row, 0, num_nodes, reduce='sum')
deg_inv = 1. / deg
deg_inv[deg_inv == float('inf')] = 0
edge_weight = edge_weight * deg_inv[row]
Expand Down Expand Up @@ -300,7 +300,7 @@ def diffusion_matrix_approx(
if normalization == 'sym':
# Calculate original degrees.
_, col = edge_index
deg = scatter_add(edge_weight, col, dim=0, dim_size=num_nodes)
deg = scatter(edge_weight, col, 0, num_nodes, reduce='sum')

edge_index_np = edge_index.cpu().numpy()

Expand Down
4 changes: 2 additions & 2 deletions torch_geometric/transforms/generate_mesh_normals.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import torch
import torch.nn.functional as F
from torch_scatter import scatter_add

from torch_geometric.data import Data
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform
from torch_geometric.utils import scatter


@functional_transform('generate_mesh_normals')
Expand All @@ -22,7 +22,7 @@ def __call__(self, data: Data) -> Data:
idx = torch.cat([face[0], face[1], face[2]], dim=0)
face_norm = face_norm.repeat(3, 1)

norm = scatter_add(face_norm, idx, dim=0, dim_size=pos.size(0))
norm = scatter(face_norm, idx, 0, pos.size(0), reduce='sum')
norm = F.normalize(norm, p=2, dim=-1) # [N, 3]

data.norm = norm
Expand Down
6 changes: 3 additions & 3 deletions torch_geometric/transforms/grid_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
import torch
import torch.nn.functional as F
from torch import Tensor
from torch_scatter import scatter_add, scatter_mean

import torch_geometric
from torch_geometric.data import Data
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform
from torch_geometric.utils import scatter


@functional_transform('grid_sampling')
Expand Down Expand Up @@ -54,12 +54,12 @@ def __call__(self, data: Data) -> Data:
if torch.is_tensor(item) and item.size(0) == num_nodes:
if key == 'y':
item = F.one_hot(item)
item = scatter_add(item, c, dim=0)
item = scatter(item, c, dim=0, reduce='sum')
data[key] = item.argmax(dim=-1)
elif key == 'batch':
data[key] = item[perm]
else:
data[key] = scatter_mean(item, c, dim=0)
data[key] = scatter(item, c, dim=0, reduce='mean')

return data

Expand Down
13 changes: 6 additions & 7 deletions torch_geometric/transforms/line_graph.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import torch
from torch_scatter import scatter_add

from torch_geometric.data import Data
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform
from torch_geometric.utils import coalesce, remove_self_loops
from torch_geometric.utils import coalesce, remove_self_loops, scatter


@functional_transform('line_graph')
Expand Down Expand Up @@ -47,8 +46,8 @@ def __call__(self, data: Data) -> Data:
if self.force_directed or data.is_directed():
i = torch.arange(row.size(0), dtype=torch.long, device=row.device)

count = scatter_add(torch.ones_like(row), row, dim=0,
dim_size=data.num_nodes)
count = scatter(torch.ones_like(row), row, dim=0,
dim_size=data.num_nodes, reduce='sum')
cumsum = torch.cat([count.new_zeros(1), count.cumsum(0)], dim=0)

cols = [
Expand Down Expand Up @@ -79,8 +78,8 @@ def __call__(self, data: Data) -> Data:
)

# Compute new edge indices according to `i`.
count = scatter_add(torch.ones_like(row), row, dim=0,
dim_size=data.num_nodes)
count = scatter(torch.ones_like(row), row, dim=0,
dim_size=data.num_nodes, reduce='sum')
joints = torch.split(i, count.tolist())

def generate_grid(x):
Expand All @@ -95,7 +94,7 @@ def generate_grid(x):
joints = coalesce(joints, num_nodes=N)

if edge_attr is not None:
data.x = scatter_add(edge_attr, i, dim=0, dim_size=N)
data.x = scatter(edge_attr, i, dim=0, dim_size=N, reduce='sum')
data.edge_index = joints
data.num_nodes = edge_index.size(1) // 2

Expand Down
4 changes: 2 additions & 2 deletions torch_geometric/transforms/local_cartesian.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import torch
from torch_scatter import scatter_max

from torch_geometric.data import Data
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform
from torch_geometric.utils import scatter


@functional_transform('local_cartesian')
Expand All @@ -29,7 +29,7 @@ def __call__(self, data: Data) -> Data:
cart = pos[row] - pos[col]
cart = cart.view(-1, 1) if cart.dim() == 1 else cart

max_value, _ = scatter_max(cart.abs(), col, 0, dim_size=pos.size(0))
max_value = scatter(cart.abs(), col, 0, pos.size(0), reduce='max')
max_value = max_value.max(dim=-1, keepdim=True)[0]

if self.norm:
Expand Down
22 changes: 8 additions & 14 deletions torch_geometric/transforms/local_degree_profile.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import torch
from torch_scatter import scatter_max, scatter_mean, scatter_min, scatter_std

from torch_geometric.data import Data
from torch_geometric.data.datapipes import functional_transform
Expand All @@ -21,26 +20,21 @@ class LocalDegreeProfile(BaseTransform):
to the node features, where :math:`DN(i) = \{ \deg(j) \mid j \in
\mathcal{N}(i) \}`.
"""
def __init__(self):
from torch_geometric.nn.aggr.fused import FusedAggregation
self.aggr = FusedAggregation(['min', 'max', 'mean', 'std'])

def __call__(self, data: Data) -> Data:
row, col = data.edge_index
N = data.num_nodes

deg = degree(row, N, dtype=torch.float)
deg_col = deg[col]

min_deg, _ = scatter_min(deg_col, row, dim_size=N)
min_deg[min_deg > 10000] = 0
max_deg, _ = scatter_max(deg_col, row, dim_size=N)
max_deg[max_deg < -10000] = 0
mean_deg = scatter_mean(deg_col, row, dim_size=N)
std_deg = scatter_std(deg_col, row, dim_size=N)

x = torch.stack([deg, min_deg, max_deg, mean_deg, std_deg], dim=1)
deg = degree(row, N, dtype=torch.float).view(-1, 1)
xs = [deg] + self.aggr(deg[col], row, dim_size=N)

if data.x is not None:
data.x = data.x.view(-1, 1) if data.x.dim() == 1 else data.x
data.x = torch.cat([data.x, x], dim=-1)
data.x = torch.cat([data.x] + xs, dim=-1)
else:
data.x = x
data.x = torch.cat(xs, dim=-1)

return data
6 changes: 3 additions & 3 deletions torch_geometric/transforms/to_superpixels.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import torch
from torch import Tensor
from torch_scatter import scatter_mean

from torch_geometric.data import Data
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform
from torch_geometric.utils import scatter


@functional_transform('to_slic')
Expand Down Expand Up @@ -50,15 +50,15 @@ def __call__(self, img: Tensor) -> Data:
seg = slic(img.to(torch.double).numpy(), start_label=0, **self.kwargs)
seg = torch.from_numpy(seg)

x = scatter_mean(img.view(h * w, c), seg.view(h * w), dim=0)
x = scatter(img.view(h * w, c), seg.view(h * w), dim=0, reduce='mean')

pos_y = torch.arange(h, dtype=torch.float)
pos_y = pos_y.view(-1, 1).repeat(1, w).view(h * w)
pos_x = torch.arange(w, dtype=torch.float)
pos_x = pos_x.view(1, -1).repeat(h, 1).view(h * w)

pos = torch.stack([pos_x, pos_y], dim=-1)
pos = scatter_mean(pos, seg.view(h * w), dim=0)
pos = scatter(pos, seg.view(h * w), dim=0, reduce='mean')

data = Data(x=x, pos=pos)

Expand Down
5 changes: 3 additions & 2 deletions torch_geometric/utils/augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

import torch
from torch import Tensor
from torch_scatter import scatter_add

from torch_geometric.utils import scatter

from .num_nodes import maybe_num_nodes

Expand Down Expand Up @@ -58,7 +59,7 @@ def shuffle_node(x: Tensor, batch: Optional[Tensor] = None,
if batch is None:
perm = torch.randperm(x.size(0), device=x.device)
return x[perm], perm
num_nodes = scatter_add(batch.new_ones(x.size(0)), batch, dim=0)
num_nodes = scatter(batch.new_ones(x.size(0)), batch, dim=0, reduce='sum')
cumsum = torch.cat([batch.new_zeros(1), num_nodes.cumsum(dim=0)])
perm = torch.cat([
torch.randperm(n, device=x.device) + offset
Expand Down
5 changes: 2 additions & 3 deletions torch_geometric/utils/get_laplacian.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@

import torch
from torch import Tensor
from torch_scatter import scatter_add

from torch_geometric.typing import OptTensor
from torch_geometric.utils import add_self_loops, remove_self_loops
from torch_geometric.utils import add_self_loops, remove_self_loops, scatter

from .num_nodes import maybe_num_nodes

Expand Down Expand Up @@ -70,7 +69,7 @@ def get_laplacian(
num_nodes = maybe_num_nodes(edge_index, num_nodes)

row, col = edge_index[0], edge_index[1]
deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
deg = scatter(edge_weight, row, 0, dim_size=num_nodes, reduce='sum')

if normalization is None:
# L = D - A.
Expand Down
7 changes: 3 additions & 4 deletions torch_geometric/utils/get_mesh_laplacian.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@

import torch
from torch import Tensor
from torch_scatter import scatter_add

from torch_geometric.utils import add_self_loops, coalesce
from torch_geometric.utils import add_self_loops, coalesce, scatter


def get_mesh_laplacian(pos: Tensor, face: Tensor) -> Tuple[Tensor, Tensor]:
Expand Down Expand Up @@ -66,8 +65,8 @@ def add_angles(left, centre, right):

# Compute the diagonal part:
row, col = edge_index
cot_deg = scatter_add(cot_weight, row, dim=0, dim_size=num_nodes)
area_deg = scatter_add(area_weight, row, dim=0, dim_size=num_nodes)
cot_deg = scatter(cot_weight, row, 0, num_nodes, reduce='sum')
area_deg = scatter(area_weight, row, 0, num_nodes, reduce='sum')
deg = cot_deg / area_deg
edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
edge_weight = torch.cat([-cot_weight, deg], dim=0)
Expand Down
9 changes: 4 additions & 5 deletions torch_geometric/utils/homophily.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@

import torch
from torch import Tensor
from torch_scatter import scatter_mean

from torch_geometric.typing import Adj, OptTensor, SparseTensor
from torch_geometric.utils import degree
from torch_geometric.utils import degree, scatter


def homophily(edge_index: Adj, y: Tensor, batch: OptTensor = None,
Expand Down Expand Up @@ -95,16 +94,16 @@ def homophily(edge_index: Adj, y: Tensor, batch: OptTensor = None,
return float(out.mean())
else:
dim_size = int(batch.max()) + 1
return scatter_mean(out, batch[col], dim=0, dim_size=dim_size)
return scatter(out, batch[col], 0, dim_size, reduce='mean')

elif method == 'node':
out = torch.zeros(row.size(0), device=row.device)
out[y[row] == y[col]] = 1.
out = scatter_mean(out, col, 0, dim_size=y.size(0))
out = scatter(out, col, 0, dim_size=y.size(0), reduce='mean')
if batch is None:
return float(out.mean())
else:
return scatter_mean(out, batch, dim=0)
return scatter(out, batch, dim=0, reduce='mean')

elif method == 'edge_insensitive':
assert y.dim() == 1
Expand Down
7 changes: 4 additions & 3 deletions torch_geometric/utils/to_dense_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

import torch
from torch import Tensor
from torch_scatter import scatter_add

from torch_geometric.utils import scatter


def to_dense_batch(x: Tensor, batch: Optional[Tensor] = None,
Expand Down Expand Up @@ -94,8 +95,8 @@ def to_dense_batch(x: Tensor, batch: Optional[Tensor] = None,
if batch_size is None:
batch_size = int(batch.max()) + 1

num_nodes = scatter_add(batch.new_ones(x.size(0)), batch, dim=0,
dim_size=batch_size)
num_nodes = scatter(batch.new_ones(x.size(0)), batch, dim=0,
dim_size=batch_size, reduce='sum')
cum_nodes = torch.cat([batch.new_zeros(1), num_nodes.cumsum(dim=0)])

filter_nodes = False
Expand Down

0 comments on commit 6065e10

Please sign in to comment.