diff --git a/CHANGELOG.md b/CHANGELOG.md index 4e975eb4afd7..45d299bf2c6d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -113,7 +113,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Removed -- `torch-sparse` is now an optional dependency ([#6625](https://github.com/pyg-team/pytorch_geometric/pull/6625), [#6626](https://github.com/pyg-team/pytorch_geometric/pull/6626)) +- `torch-sparse` is now an optional dependency ([#6625](https://github.com/pyg-team/pytorch_geometric/pull/6625), [#6626](https://github.com/pyg-team/pytorch_geometric/pull/6626), [#6627](https://github.com/pyg-team/pytorch_geometric/pull/6627)) - Removed most of the `torch-scatter` dependencies ([#6394](https://github.com/pyg-team/pytorch_geometric/pull/6394), [#6395](https://github.com/pyg-team/pytorch_geometric/pull/6395), [#6399](https://github.com/pyg-team/pytorch_geometric/pull/6399), [#6400](https://github.com/pyg-team/pytorch_geometric/pull/6400), [#6615](https://github.com/pyg-team/pytorch_geometric/pull/6615), [#6617](https://github.com/pyg-team/pytorch_geometric/pull/6617)) - Removed the deprecated classes `GNNExplainer` and `Explainer` from `nn.models` ([#6382](https://github.com/pyg-team/pytorch_geometric/pull/6382)) - Removed `target_index` argument in the `Explainer` interface ([#6270](https://github.com/pyg-team/pytorch_geometric/pull/6270)) diff --git a/docs/source/advanced/sparse_tensor.rst b/docs/source/advanced/sparse_tensor.rst index 41c9aed13739..abb703be1d79 100644 --- a/docs/source/advanced/sparse_tensor.rst +++ b/docs/source/advanced/sparse_tensor.rst @@ -113,7 +113,7 @@ With it, the :class:`~torch_geometric.nn.conv.GINConv` layer can now be implemen .. code-block:: python - from torch_sparse import matmul + import torch_sparse class GINConv(MessagePassing): def __init__(self): @@ -127,7 +127,7 @@ With it, the :class:`~torch_geometric.nn.conv.GINConv` layer can now be implemen return x_j def message_and_aggregate(self, adj_t, x): - return matmul(adj_t, x, reduce=self.aggr) + return torch_sparse.matmul(adj_t, x, reduce=self.aggr) Playing around with the new :class:`SparseTensor` format is straightforward since all of our GNNs work with it out-of-the-box. To convert the :obj:`edge_index` format to the newly introduced :class:`SparseTensor` format, you can make use of the :class:`torch_geometric.transforms.ToSparseTensor` transform: diff --git a/torch_geometric/nn/models/graph_unet.py b/torch_geometric/nn/models/graph_unet.py index 37ac88ae51ec..d0748503138d 100644 --- a/torch_geometric/nn/models/graph_unet.py +++ b/torch_geometric/nn/models/graph_unet.py @@ -3,15 +3,10 @@ import torch import torch.nn.functional as F from torch import Tensor -from torch_sparse import spspmm from torch_geometric.nn import GCNConv, TopKPooling -from torch_geometric.typing import OptTensor, PairTensor -from torch_geometric.utils import ( - add_self_loops, - remove_self_loops, - sort_edge_index, -) +from torch_geometric.typing import OptTensor, PairTensor, SparseTensor +from torch_geometric.utils import add_self_loops, remove_self_loops from torch_geometric.utils.repeat import repeat @@ -131,11 +126,11 @@ def augment_adj(self, edge_index: Tensor, edge_weight: Tensor, edge_index, edge_weight = remove_self_loops(edge_index, edge_weight) edge_index, edge_weight = add_self_loops(edge_index, edge_weight, num_nodes=num_nodes) - edge_index, edge_weight = sort_edge_index(edge_index, edge_weight, - num_nodes) - edge_index, edge_weight = spspmm(edge_index, edge_weight, edge_index, - edge_weight, num_nodes, num_nodes, - num_nodes) + adj = SparseTensor.from_edge_index(edge_index, edge_weight, + sparse_sizes=(num_nodes, num_nodes)) + adj = adj @ adj + row, col, edge_weight = adj.coo() + edge_index = torch.stack([row, col], dim=0) edge_index, edge_weight = remove_self_loops(edge_index, edge_weight) return edge_index, edge_weight diff --git a/torch_geometric/nn/models/signed_gcn.py b/torch_geometric/nn/models/signed_gcn.py index dfa5e48f5f99..51b893d30194 100644 --- a/torch_geometric/nn/models/signed_gcn.py +++ b/torch_geometric/nn/models/signed_gcn.py @@ -4,10 +4,10 @@ import torch import torch.nn.functional as F from torch import Tensor -from torch_sparse import coalesce from torch_geometric.nn import SignedConv from torch_geometric.utils import ( + coalesce, negative_sampling, structured_negative_sampling, ) @@ -111,7 +111,7 @@ def create_spectral_features( edge_index = torch.cat([edge_index, torch.stack([col, row])], dim=1) val = torch.cat([val, val], dim=0) - edge_index, val = coalesce(edge_index, val, N, N) + edge_index, val = coalesce(edge_index, val, num_nodes=N) val = val - 1 # Borrowed from: diff --git a/torch_geometric/nn/pool/pool.py b/torch_geometric/nn/pool/pool.py index 04526cafdce5..da076bc1cbef 100644 --- a/torch_geometric/nn/pool/pool.py +++ b/torch_geometric/nn/pool/pool.py @@ -1,9 +1,8 @@ from typing import Optional import torch -from torch_sparse import coalesce -from torch_geometric.utils import remove_self_loops, scatter +from torch_geometric.utils import coalesce, remove_self_loops, scatter def pool_edge(cluster, edge_index, edge_attr: Optional[torch.Tensor] = None): @@ -11,8 +10,7 @@ def pool_edge(cluster, edge_index, edge_attr: Optional[torch.Tensor] = None): edge_index = cluster[edge_index.view(-1)].view(2, -1) edge_index, edge_attr = remove_self_loops(edge_index, edge_attr) if edge_index.numel() > 0: - edge_index, edge_attr = coalesce(edge_index, edge_attr, num_nodes, - num_nodes) + edge_index, edge_attr = coalesce(edge_index, edge_attr, num_nodes) return edge_index, edge_attr diff --git a/torch_geometric/transforms/two_hop.py b/torch_geometric/transforms/two_hop.py index 86937cbdf97d..16ff5abb86f2 100644 --- a/torch_geometric/transforms/two_hop.py +++ b/torch_geometric/transforms/two_hop.py @@ -3,6 +3,7 @@ from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform +from torch_geometric.typing import SparseTensor from torch_geometric.utils import coalesce, remove_self_loops @@ -11,25 +12,25 @@ class TwoHop(BaseTransform): r"""Adds the two hop edges to the edge indices (functional name: :obj:`two_hop`).""" def __call__(self, data: Data) -> Data: - from torch_sparse import spspmm - edge_index, edge_attr = data.edge_index, data.edge_attr N = data.num_nodes - value = edge_index.new_ones((edge_index.size(1), ), dtype=torch.float) + adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=(N, N)) - index, value = spspmm(edge_index, value, edge_index, value, N, N, N) - value.fill_(0) - index, value = remove_self_loops(index, value) + adj = adj @ adj + row, col, _ = adj.coo() + edge_index2 = torch.stack([row, col], dim=0) + edge_index2, _ = remove_self_loops(edge_index2) - edge_index = torch.cat([edge_index, index], dim=1) + edge_index = torch.cat([edge_index, edge_index2], dim=1) if edge_attr is None: data.edge_index = coalesce(edge_index, num_nodes=N) else: - value = value.view(-1, *[1 for _ in range(edge_attr.dim() - 1)]) - value = value.expand(-1, *list(edge_attr.size())[1:]) - edge_attr = torch.cat([edge_attr, value], dim=0) - data.edge_index, edge_attr = coalesce(edge_index, edge_attr, N) - data.edge_attr = edge_attr + # We treat newly added edge features as "zero-features": + edge_attr2 = edge_attr.new_zeros(edge_index2.size(1), + *edge_attr.size()[1:]) + edge_attr = torch.cat([edge_attr, edge_attr2], dim=0) + edge_index, edge_attr = coalesce(edge_index, edge_attr, N) + data.edge_index, data.edge_attr = edge_index, edge_attr return data