From 6065e105dcf02590c7fe7559b4fedbc9c2b9a1c1 Mon Sep 17 00:00:00 2001 From: Matthias Fey Date: Wed, 11 Jan 2023 16:36:22 +0100 Subject: [PATCH] Drop `torch-scatter` dependency (part 3) (#6399) --- CHANGELOG.md | 2 +- test/transforms/test_local_degree_profile.py | 18 +++++++++------ torch_geometric/transforms/gdc.py | 10 ++++----- .../transforms/generate_mesh_normals.py | 4 ++-- torch_geometric/transforms/grid_sampling.py | 6 ++--- torch_geometric/transforms/line_graph.py | 13 +++++------ torch_geometric/transforms/local_cartesian.py | 4 ++-- .../transforms/local_degree_profile.py | 22 +++++++------------ torch_geometric/transforms/to_superpixels.py | 6 ++--- torch_geometric/utils/augmentation.py | 5 +++-- torch_geometric/utils/get_laplacian.py | 5 ++--- torch_geometric/utils/get_mesh_laplacian.py | 7 +++--- torch_geometric/utils/homophily.py | 9 ++++---- torch_geometric/utils/to_dense_batch.py | 7 +++--- 14 files changed, 57 insertions(+), 61 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3ccac7dd5d61..657bf4c71180 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -70,7 +70,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Optimized `utils.softmax` implementation ([#6113](https://github.com/pyg-team/pytorch_geometric/pull/6113), [#6155](https://github.com/pyg-team/pytorch_geometric/pull/6155)) - Optimized `topk` implementation for large enough graphs ([#6123](https://github.com/pyg-team/pytorch_geometric/pull/6123)) ### Removed -- Removed most of the `torch-scatter` dependencies ([#6394](https://github.com/pyg-team/pytorch_geometric/pull/6394),[#6395](https://github.com/pyg-team/pytorch_geometric/pull/6395)) +- Removed most of the `torch-scatter` dependencies ([#6394](https://github.com/pyg-team/pytorch_geometric/pull/6394), [#6395](https://github.com/pyg-team/pytorch_geometric/pull/6395), [#6399](https://github.com/pyg-team/pytorch_geometric/pull/6399)) - Removed the deprecated classes `GNNExplainer` and `Explainer` from `nn.models` ([#6382](https://github.com/pyg-team/pytorch_geometric/pull/6382)) - Removed `target_index` argument in the `Explainer` interface ([#6270](https://github.com/pyg-team/pytorch_geometric/pull/6270)) - Removed `Aggregation.set_validate_args` option ([#6175](https://github.com/pyg-team/pytorch_geometric/pull/6175)) diff --git a/test/transforms/test_local_degree_profile.py b/test/transforms/test_local_degree_profile.py index 768d532e58eb..44a2e5bf1f3d 100644 --- a/test/transforms/test_local_degree_profile.py +++ b/test/transforms/test_local_degree_profile.py @@ -10,14 +10,18 @@ def test_target_indegree(): edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) x = torch.Tensor([[1], [1], [1], [1]]) # One isolated node. 
- expected = [[1, 2, 2, 2, 0], [2, 1, 1, 1, 0], [1, 2, 2, 2, 0], - [0, 0, 0, 0, 0]] + expected = torch.tensor([ + [1, 2, 2, 2, 0], + [2, 1, 1, 1, 0], + [1, 2, 2, 2, 0], + [0, 0, 0, 0, 0], + ], dtype=torch.float) - data = Data(edge_index=edge_index, pos=x) + data = Data(edge_index=edge_index, num_nodes=x.size(0)) data = LocalDegreeProfile()(data) - assert data.x.tolist() == expected + assert torch.allclose(data.x, expected, atol=1e-2) - data.x = x + data = Data(edge_index=edge_index, x=x) data = LocalDegreeProfile()(data) - assert data.x[:, 1:].tolist() == expected - assert data.x[:, 0].tolist() == [1, 1, 1, 1] + assert torch.allclose(data.x[:, :1], x) + assert torch.allclose(data.x[:, 1:], expected, atol=1e-2) diff --git a/torch_geometric/transforms/gdc.py b/torch_geometric/transforms/gdc.py index a2b1ee36cef5..b63ad758d877 100644 --- a/torch_geometric/transforms/gdc.py +++ b/torch_geometric/transforms/gdc.py @@ -4,7 +4,6 @@ import torch from scipy.linalg import expm from torch import Tensor -from torch_scatter import scatter_add from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform @@ -13,6 +12,7 @@ add_self_loops, coalesce, is_undirected, + scatter, to_dense_adj, ) @@ -166,19 +166,19 @@ def transition_matrix( """ if normalization == 'sym': row, col = edge_index - deg = scatter_add(edge_weight, col, dim=0, dim_size=num_nodes) + deg = scatter(edge_weight, col, 0, num_nodes, reduce='sum') deg_inv_sqrt = deg.pow(-0.5) deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 edge_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] elif normalization == 'col': _, col = edge_index - deg = scatter_add(edge_weight, col, dim=0, dim_size=num_nodes) + deg = scatter(edge_weight, col, 0, num_nodes, reduce='sum') deg_inv = 1. / deg deg_inv[deg_inv == float('inf')] = 0 edge_weight = edge_weight * deg_inv[col] elif normalization == 'row': row, _ = edge_index - deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes) + deg = scatter(edge_weight, row, 0, num_nodes, reduce='sum') deg_inv = 1. / deg deg_inv[deg_inv == float('inf')] = 0 edge_weight = edge_weight * deg_inv[row] @@ -300,7 +300,7 @@ def diffusion_matrix_approx( if normalization == 'sym': # Calculate original degrees. 
_, col = edge_index - deg = scatter_add(edge_weight, col, dim=0, dim_size=num_nodes) + deg = scatter(edge_weight, col, 0, num_nodes, reduce='sum') edge_index_np = edge_index.cpu().numpy() diff --git a/torch_geometric/transforms/generate_mesh_normals.py b/torch_geometric/transforms/generate_mesh_normals.py index 0718edc14a1b..9cf5d6c5147c 100644 --- a/torch_geometric/transforms/generate_mesh_normals.py +++ b/torch_geometric/transforms/generate_mesh_normals.py @@ -1,10 +1,10 @@ import torch import torch.nn.functional as F -from torch_scatter import scatter_add from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform +from torch_geometric.utils import scatter @functional_transform('generate_mesh_normals') @@ -22,7 +22,7 @@ def __call__(self, data: Data) -> Data: idx = torch.cat([face[0], face[1], face[2]], dim=0) face_norm = face_norm.repeat(3, 1) - norm = scatter_add(face_norm, idx, dim=0, dim_size=pos.size(0)) + norm = scatter(face_norm, idx, 0, pos.size(0), reduce='sum') norm = F.normalize(norm, p=2, dim=-1) # [N, 3] data.norm = norm diff --git a/torch_geometric/transforms/grid_sampling.py b/torch_geometric/transforms/grid_sampling.py index abcd678d2c07..b549dc19f6ce 100644 --- a/torch_geometric/transforms/grid_sampling.py +++ b/torch_geometric/transforms/grid_sampling.py @@ -4,12 +4,12 @@ import torch import torch.nn.functional as F from torch import Tensor -from torch_scatter import scatter_add, scatter_mean import torch_geometric from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform +from torch_geometric.utils import scatter @functional_transform('grid_sampling') @@ -54,12 +54,12 @@ def __call__(self, data: Data) -> Data: if torch.is_tensor(item) and item.size(0) == num_nodes: if key == 'y': item = F.one_hot(item) - item = scatter_add(item, c, dim=0) + item = scatter(item, c, dim=0, reduce='sum') data[key] = item.argmax(dim=-1) elif key == 'batch': data[key] = item[perm] else: - data[key] = scatter_mean(item, c, dim=0) + data[key] = scatter(item, c, dim=0, reduce='mean') return data diff --git a/torch_geometric/transforms/line_graph.py b/torch_geometric/transforms/line_graph.py index 142e3c137bf8..9d63ae388a06 100644 --- a/torch_geometric/transforms/line_graph.py +++ b/torch_geometric/transforms/line_graph.py @@ -1,10 +1,9 @@ import torch -from torch_scatter import scatter_add from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform -from torch_geometric.utils import coalesce, remove_self_loops +from torch_geometric.utils import coalesce, remove_self_loops, scatter @functional_transform('line_graph') @@ -47,8 +46,8 @@ def __call__(self, data: Data) -> Data: if self.force_directed or data.is_directed(): i = torch.arange(row.size(0), dtype=torch.long, device=row.device) - count = scatter_add(torch.ones_like(row), row, dim=0, - dim_size=data.num_nodes) + count = scatter(torch.ones_like(row), row, dim=0, + dim_size=data.num_nodes, reduce='sum') cumsum = torch.cat([count.new_zeros(1), count.cumsum(0)], dim=0) cols = [ @@ -79,8 +78,8 @@ def __call__(self, data: Data) -> Data: ) # Compute new edge indices according to `i`. 
- count = scatter_add(torch.ones_like(row), row, dim=0, - dim_size=data.num_nodes) + count = scatter(torch.ones_like(row), row, dim=0, + dim_size=data.num_nodes, reduce='sum') joints = torch.split(i, count.tolist()) def generate_grid(x): @@ -95,7 +94,7 @@ def generate_grid(x): joints = coalesce(joints, num_nodes=N) if edge_attr is not None: - data.x = scatter_add(edge_attr, i, dim=0, dim_size=N) + data.x = scatter(edge_attr, i, dim=0, dim_size=N, reduce='sum') data.edge_index = joints data.num_nodes = edge_index.size(1) // 2 diff --git a/torch_geometric/transforms/local_cartesian.py b/torch_geometric/transforms/local_cartesian.py index 69e6014f4f6b..fdc715a4bc00 100644 --- a/torch_geometric/transforms/local_cartesian.py +++ b/torch_geometric/transforms/local_cartesian.py @@ -1,9 +1,9 @@ import torch -from torch_scatter import scatter_max from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform +from torch_geometric.utils import scatter @functional_transform('local_cartesian') @@ -29,7 +29,7 @@ def __call__(self, data: Data) -> Data: cart = pos[row] - pos[col] cart = cart.view(-1, 1) if cart.dim() == 1 else cart - max_value, _ = scatter_max(cart.abs(), col, 0, dim_size=pos.size(0)) + max_value = scatter(cart.abs(), col, 0, pos.size(0), reduce='max') max_value = max_value.max(dim=-1, keepdim=True)[0] if self.norm: diff --git a/torch_geometric/transforms/local_degree_profile.py b/torch_geometric/transforms/local_degree_profile.py index 6fa0a47f4c41..f517ff9c45cb 100644 --- a/torch_geometric/transforms/local_degree_profile.py +++ b/torch_geometric/transforms/local_degree_profile.py @@ -1,5 +1,4 @@ import torch -from torch_scatter import scatter_max, scatter_mean, scatter_min, scatter_std from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform @@ -21,26 +20,21 @@ class LocalDegreeProfile(BaseTransform): to the node features, where :math:`DN(i) = \{ \deg(j) \mid j \in \mathcal{N}(i) \}`. 
""" + def __init__(self): + from torch_geometric.nn.aggr.fused import FusedAggregation + self.aggr = FusedAggregation(['min', 'max', 'mean', 'std']) + def __call__(self, data: Data) -> Data: row, col = data.edge_index N = data.num_nodes - deg = degree(row, N, dtype=torch.float) - deg_col = deg[col] - - min_deg, _ = scatter_min(deg_col, row, dim_size=N) - min_deg[min_deg > 10000] = 0 - max_deg, _ = scatter_max(deg_col, row, dim_size=N) - max_deg[max_deg < -10000] = 0 - mean_deg = scatter_mean(deg_col, row, dim_size=N) - std_deg = scatter_std(deg_col, row, dim_size=N) - - x = torch.stack([deg, min_deg, max_deg, mean_deg, std_deg], dim=1) + deg = degree(row, N, dtype=torch.float).view(-1, 1) + xs = [deg] + self.aggr(deg[col], row, dim_size=N) if data.x is not None: data.x = data.x.view(-1, 1) if data.x.dim() == 1 else data.x - data.x = torch.cat([data.x, x], dim=-1) + data.x = torch.cat([data.x] + xs, dim=-1) else: - data.x = x + data.x = torch.cat(xs, dim=-1) return data diff --git a/torch_geometric/transforms/to_superpixels.py b/torch_geometric/transforms/to_superpixels.py index 7ce90f139a65..72baaa020eae 100644 --- a/torch_geometric/transforms/to_superpixels.py +++ b/torch_geometric/transforms/to_superpixels.py @@ -1,10 +1,10 @@ import torch from torch import Tensor -from torch_scatter import scatter_mean from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform +from torch_geometric.utils import scatter @functional_transform('to_slic') @@ -50,7 +50,7 @@ def __call__(self, img: Tensor) -> Data: seg = slic(img.to(torch.double).numpy(), start_label=0, **self.kwargs) seg = torch.from_numpy(seg) - x = scatter_mean(img.view(h * w, c), seg.view(h * w), dim=0) + x = scatter(img.view(h * w, c), seg.view(h * w), dim=0, reduce='mean') pos_y = torch.arange(h, dtype=torch.float) pos_y = pos_y.view(-1, 1).repeat(1, w).view(h * w) @@ -58,7 +58,7 @@ def __call__(self, img: Tensor) -> Data: pos_x = pos_x.view(1, -1).repeat(h, 1).view(h * w) pos = torch.stack([pos_x, pos_y], dim=-1) - pos = scatter_mean(pos, seg.view(h * w), dim=0) + pos = scatter(pos, seg.view(h * w), dim=0, reduce='mean') data = Data(x=x, pos=pos) diff --git a/torch_geometric/utils/augmentation.py b/torch_geometric/utils/augmentation.py index fb5ba5f96fa1..4e9e47ddea0c 100644 --- a/torch_geometric/utils/augmentation.py +++ b/torch_geometric/utils/augmentation.py @@ -2,7 +2,8 @@ import torch from torch import Tensor -from torch_scatter import scatter_add + +from torch_geometric.utils import scatter from .num_nodes import maybe_num_nodes @@ -58,7 +59,7 @@ def shuffle_node(x: Tensor, batch: Optional[Tensor] = None, if batch is None: perm = torch.randperm(x.size(0), device=x.device) return x[perm], perm - num_nodes = scatter_add(batch.new_ones(x.size(0)), batch, dim=0) + num_nodes = scatter(batch.new_ones(x.size(0)), batch, dim=0, reduce='sum') cumsum = torch.cat([batch.new_zeros(1), num_nodes.cumsum(dim=0)]) perm = torch.cat([ torch.randperm(n, device=x.device) + offset diff --git a/torch_geometric/utils/get_laplacian.py b/torch_geometric/utils/get_laplacian.py index bb1a4af7c417..d5cd4b4ae554 100644 --- a/torch_geometric/utils/get_laplacian.py +++ b/torch_geometric/utils/get_laplacian.py @@ -2,10 +2,9 @@ import torch from torch import Tensor -from torch_scatter import scatter_add from torch_geometric.typing import OptTensor -from torch_geometric.utils import add_self_loops, remove_self_loops +from torch_geometric.utils import add_self_loops, 
remove_self_loops, scatter from .num_nodes import maybe_num_nodes @@ -70,7 +69,7 @@ def get_laplacian( num_nodes = maybe_num_nodes(edge_index, num_nodes) row, col = edge_index[0], edge_index[1] - deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes) + deg = scatter(edge_weight, row, 0, dim_size=num_nodes, reduce='sum') if normalization is None: # L = D - A. diff --git a/torch_geometric/utils/get_mesh_laplacian.py b/torch_geometric/utils/get_mesh_laplacian.py index 2a44218ee881..ef68a2673d65 100644 --- a/torch_geometric/utils/get_mesh_laplacian.py +++ b/torch_geometric/utils/get_mesh_laplacian.py @@ -2,9 +2,8 @@ import torch from torch import Tensor -from torch_scatter import scatter_add -from torch_geometric.utils import add_self_loops, coalesce +from torch_geometric.utils import add_self_loops, coalesce, scatter def get_mesh_laplacian(pos: Tensor, face: Tensor) -> Tuple[Tensor, Tensor]: @@ -66,8 +65,8 @@ def add_angles(left, centre, right): # Compute the diagonal part: row, col = edge_index - cot_deg = scatter_add(cot_weight, row, dim=0, dim_size=num_nodes) - area_deg = scatter_add(area_weight, row, dim=0, dim_size=num_nodes) + cot_deg = scatter(cot_weight, row, 0, num_nodes, reduce='sum') + area_deg = scatter(area_weight, row, 0, num_nodes, reduce='sum') deg = cot_deg / area_deg edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes) edge_weight = torch.cat([-cot_weight, deg], dim=0) diff --git a/torch_geometric/utils/homophily.py b/torch_geometric/utils/homophily.py index 334b0352d1a1..ee05b56da8dc 100644 --- a/torch_geometric/utils/homophily.py +++ b/torch_geometric/utils/homophily.py @@ -2,10 +2,9 @@ import torch from torch import Tensor -from torch_scatter import scatter_mean from torch_geometric.typing import Adj, OptTensor, SparseTensor -from torch_geometric.utils import degree +from torch_geometric.utils import degree, scatter def homophily(edge_index: Adj, y: Tensor, batch: OptTensor = None, @@ -95,16 +94,16 @@ def homophily(edge_index: Adj, y: Tensor, batch: OptTensor = None, return float(out.mean()) else: dim_size = int(batch.max()) + 1 - return scatter_mean(out, batch[col], dim=0, dim_size=dim_size) + return scatter(out, batch[col], 0, dim_size, reduce='mean') elif method == 'node': out = torch.zeros(row.size(0), device=row.device) out[y[row] == y[col]] = 1. - out = scatter_mean(out, col, 0, dim_size=y.size(0)) + out = scatter(out, col, 0, dim_size=y.size(0), reduce='mean') if batch is None: return float(out.mean()) else: - return scatter_mean(out, batch, dim=0) + return scatter(out, batch, dim=0, reduce='mean') elif method == 'edge_insensitive': assert y.dim() == 1 diff --git a/torch_geometric/utils/to_dense_batch.py b/torch_geometric/utils/to_dense_batch.py index 608d894fb0bf..a13bc06db1d6 100644 --- a/torch_geometric/utils/to_dense_batch.py +++ b/torch_geometric/utils/to_dense_batch.py @@ -2,7 +2,8 @@ import torch from torch import Tensor -from torch_scatter import scatter_add + +from torch_geometric.utils import scatter def to_dense_batch(x: Tensor, batch: Optional[Tensor] = None, @@ -94,8 +95,8 @@ def to_dense_batch(x: Tensor, batch: Optional[Tensor] = None, if batch_size is None: batch_size = int(batch.max()) + 1 - num_nodes = scatter_add(batch.new_ones(x.size(0)), batch, dim=0, - dim_size=batch_size) + num_nodes = scatter(batch.new_ones(x.size(0)), batch, dim=0, + dim_size=batch_size, reduce='sum') cum_nodes = torch.cat([batch.new_zeros(1), num_nodes.cumsum(dim=0)]) filter_nodes = False
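
For reference, the change applied across the files above follows one migration pattern: the reduction-specific `torch_scatter` helpers (`scatter_add`, `scatter_mean`, `scatter_max`, ...) are replaced by `torch_geometric.utils.scatter` with an explicit `reduce` argument. A minimal standalone sketch of that pattern (not part of the diff; the tensor names and sizes are illustrative only):

    import torch
    from torch_geometric.utils import scatter

    src = torch.randn(6, 4)                   # one feature row per edge
    index = torch.tensor([0, 0, 1, 1, 2, 2])  # destination node of each edge

    # Previously: torch_scatter.scatter_add(src, index, dim=0, dim_size=3)
    out_sum = scatter(src, index, dim=0, dim_size=3, reduce='sum')

    # Previously: torch_scatter.scatter_mean(src, index, dim=0, dim_size=3)
    out_mean = scatter(src, index, dim=0, dim_size=3, reduce='mean')

    # Previously: torch_scatter.scatter_max(...) returned a (values, argmax)
    # tuple; scatter(..., reduce='max') returns only the reduced values, which
    # is why calls such as `max_value, _ = scatter_max(...)` in the patch
    # become `max_value = scatter(..., reduce='max')`.
    out_max = scatter(src, index, dim=0, dim_size=3, reduce='max')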