Skip to content

Commit

Permalink
EdgeIndex.matmul: Use torch-sparse if available (pyg-team#8492)
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored Nov 30, 2023
1 parent 98af585 commit 5fc014e
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 20 deletions.
13 changes: 11 additions & 2 deletions test/data/test_edge_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
withCUDA,
withPackage,
)
from torch_geometric.typing import SparseTensor
from torch_geometric.utils import scatter


Expand Down Expand Up @@ -362,6 +363,16 @@ def test_matmul_backward():
assert torch.allclose(x1.grad, x2.grad)


@withPackage('torch_sparse')
def test_to_sparse_tensor():
adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]]).to_sparse_tensor()
assert isinstance(adj, SparseTensor)
assert adj.sizes() == [3, 3]
row, col, _ = adj.coo()
assert torch.equal(row, torch.tensor([0, 1, 1, 2]))
assert torch.equal(col, torch.tensor([1, 0, 2, 1]))


def test_save_and_load(tmp_path):
adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row')
adj.fill_cache_()
Expand Down Expand Up @@ -473,8 +484,6 @@ def forward(self, x: Tensor, edge_index: EdgeIndex) -> Tensor:
if __name__ == '__main__':
import argparse

from torch_sparse import SparseTensor

warnings.filterwarnings('ignore', ".*Sparse CSR tensor support.*")

parser = argparse.ArgumentParser()
Expand Down
85 changes: 67 additions & 18 deletions torch_geometric/data/edge_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from torch import Tensor

import torch_geometric.typing
from torch_geometric.typing import SparseTensor
from torch_geometric.utils import index_sort

HANDLED_FUNCTIONS: Dict[Callable, Callable] = {}
Expand Down Expand Up @@ -366,6 +367,30 @@ def sort_by(

return torch.return_types.sort([out, perm])

def to_sparse_tensor(
self,
value: Optional[Tensor] = None,
) -> SparseTensor:
r"""Converts the :class:`EdgeIndex` representation to a
:class:`torch_sparse.SparseTensor`. Requires that :obj:`torch-sparse`
is installed.
Args:
value (torch.Tensor, optional): The values of non-zero indices.
(default: :obj:`None`)
"""
is_sorted = self._sort_order == SortOrder.ROW

return SparseTensor(
row=self[0],
col=self[1],
rowptr=self.get_rowptr() if is_sorted else None,
value=value,
sparse_sizes=self.get_sparse_size(),
is_sorted=is_sorted,
trust_data=True,
)

@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
# `EdgeIndex` should be treated as a regular PyTorch tensor for all
Expand Down Expand Up @@ -710,13 +735,29 @@ def forward(
reduce: ReduceType = 'sum',
) -> Tensor:

assert reduce in ['sum']
if input._sort_order != SortOrder.ROW:
raise ValueError("Sparse-dense matrix multiplication requires "
"'EdgeIndex' to be sorted by rows")

if reduce not in ReduceType.__args__:
raise NotImplementedError("`reduce='{reduce}'` not yet supported")

if other.requires_grad:
ctx.save_for_backward(input, input_value)

other = other.detach()
if input_value is not None:
input_value = input_value.detach()

if torch_geometric.typing.WITH_TORCH_SPARSE:
# If `torch-sparse` is available, it still provides a faster
# sparse-dense matmul code path (after all these years...):
rowptr, col = input.get_rowptr(), input[1]
return torch.ops.torch_sparse.spmm_sum( #
None, rowptr, col, input_value, None, None, other)

input_value = input.get_value() if input_value is None else input_value
adj = to_sparse_csr(input, input_value) # Requires `sort_order="row"`.
adj = to_sparse_csr(input, input_value)

return adj @ other

Expand All @@ -732,28 +773,35 @@ def backward(

# We need to compute `adj.t() @ out_grad`. For the transpose, we
# first sort by column and then create a CSR matrix from it.
# We can call `input.flip(0)` here and create a sparse CSR matrix
# from it, but creating it via `colptr` directly is more efficient:
# Note that the sort result is cached across multiple applications.
input, perm = input.sort_by(SortOrder.COL)
colptr, row = input.get_colptr(), input[0]
if input_value is not None:
input_value = input_value[perm]
input_value = input_value.detach()[perm]

if torch_geometric.typing.WITH_TORCH_CLUSTER:
other_grad = torch.ops.torch_sparse.spmm_sum( #
None, colptr, row, input_value, None, None, out_grad)

else:
input_value = input.get_value()

# We can call `input.flip(0)` here and call `to_sparse_csr`, but
# creating the sparse CSR matrix from `colptr` directly is a bit
# more efficient:
adj_t = torch.sparse_csr_tensor(
crow_indices=input.get_colptr(),
col_indices=input[0],
values=input_value,
size=input.get_sparse_size()[::-1],
device=input.device,
)
if input_value is None:
input_value = input.get_value()

adj_t = torch.sparse_csr_tensor(
crow_indices=input.get_colptr(),
col_indices=input[0],
values=input_value,
size=input.get_sparse_size()[::-1],
device=input.device,
)

other_grad = adj_t @ out_grad
other_grad = adj_t @ out_grad

if ctx.needs_input_grad[2]:
raise NotImplementedError
raise NotImplementedError("Gradient computation for 'input_value' "
"not yet supported")

return None, other_grad, None, None

Expand All @@ -768,7 +816,8 @@ def matmul(
reduce: ReduceType = 'sum',
) -> Union[Tensor, Tuple[EdgeIndex, Tensor]]:

assert reduce in ['sum']
if reduce not in ReduceType.__args__:
raise NotImplementedError("`reduce='{reduce}'` not yet supported")

if not isinstance(other, EdgeIndex):
if other_value is not None:
Expand Down

0 comments on commit 5fc014e

Please sign in to comment.