Skip to content

Commit

Permalink
Accelerated sparse tensor conversions (#7042)
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored Mar 26, 2023
1 parent 0ea43bb commit 93f9f59
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 8 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Accelerated sparse tensor conversion routiens ([#7042](https://github.com/pyg-team/pytorch_geometric/pull/7042))
- Change `torch_sparse.SparseTensor` logic to utilize `torch.sparse_csr` instead ([#7041](https://github.com/pyg-team/pytorch_geometric/pull/7041))

### Removed
Expand Down
23 changes: 23 additions & 0 deletions test/utils/test_sparse.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch

import torch_geometric.typing
from torch_geometric.profile import benchmark
from torch_geometric.testing import is_full_test
from torch_geometric.typing import SparseTensor
from torch_geometric.utils import (
Expand Down Expand Up @@ -195,3 +196,25 @@ def test_to_edge_index():
edge_index, edge_attr = jit(adj)
assert edge_index.tolist() == [[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]]
assert edge_attr.tolist() == [1., 1., 1., 1., 1., 1.]


if __name__ == '__main__':
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--device', type=str, default='cuda')
args = parser.parse_args()

num_nodes, num_edges = 10_000, 200_000
edge_index = torch.randint(num_nodes, (2, num_edges), device=args.device)

benchmark(
funcs=[
SparseTensor.from_edge_index, to_torch_coo_tensor,
to_torch_csr_tensor, to_torch_csc_tensor
],
func_names=['SparseTensor', 'To COO', 'To CSR', 'To CSC'],
args=(edge_index, None, (num_nodes, num_nodes)),
num_steps=50 if args.device == 'cpu' else 500,
num_warmups=10 if args.device == 'cpu' else 100,
)
4 changes: 2 additions & 2 deletions torch_geometric/utils/num_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import torch
from torch import Tensor

import torch_geometric
from torch_geometric.typing import SparseTensor # noqa
from torch_geometric.utils.sparse import is_torch_sparse_tensor


@torch.jit._overload
Expand All @@ -24,7 +24,7 @@ def maybe_num_nodes(edge_index, num_nodes=None):
if num_nodes is not None:
return num_nodes
elif isinstance(edge_index, Tensor):
if is_torch_sparse_tensor(edge_index):
if torch_geometric.utils.is_torch_sparse_tensor(edge_index):
return max(edge_index.size(0), edge_index.size(1))
return int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0
else:
Expand Down
35 changes: 29 additions & 6 deletions torch_geometric/utils/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from torch import Tensor

from torch_geometric.typing import SparseTensor
from torch_geometric.utils import coalesce


def dense_to_sparse(adj: Tensor) -> Tuple[Tensor, Tensor]:
Expand Down Expand Up @@ -122,18 +123,21 @@ def to_torch_coo_tensor(
if not isinstance(size, (tuple, list)):
size = (size, size)

if not is_coalesced:
edge_index, edge_attr = coalesce(edge_index, edge_attr, max(size))

if edge_attr is None:
edge_attr = torch.ones(edge_index.size(1), device=edge_index.device)

size = tuple(size) + edge_attr.size()[1:]

adj = torch.sparse_coo_tensor(
indices=edge_index,
values=edge_attr,
size=size,
size=tuple(size) + edge_attr.size()[1:],
device=edge_index.device,
)
return adj._coalesced_(True) if is_coalesced else adj.coalesce()
adj = adj._coalesced_(True)

return adj


def to_torch_csr_tensor(
Expand Down Expand Up @@ -212,8 +216,27 @@ def to_torch_csc_tensor(
size=(4, 4), nnz=6, layout=torch.sparse_csc)
"""
adj = to_torch_coo_tensor(edge_index, edge_attr, size, is_coalesced)
return adj.to_sparse_csc()
if size is None:
size = int(edge_index.max()) + 1
if not isinstance(size, (tuple, list)):
size = (size, size)

if not is_coalesced:
edge_index, edge_attr = coalesce(edge_index, edge_attr, max(size),
sort_by_row=False)

if edge_attr is None:
edge_attr = torch.ones(edge_index.size(1), device=edge_index.device)

adj = torch.sparse_csc_tensor(
ccol_indices=index2ptr(edge_index[1], size[1]),
row_indices=edge_index[0],
values=edge_attr,
size=tuple(size) + edge_attr.size()[1:],
device=edge_index.device,
)

return adj


def to_edge_index(adj: Union[Tensor, SparseTensor]) -> Tuple[Tensor, Tensor]:
Expand Down

0 comments on commit 93f9f59

Please sign in to comment.