diff --git a/.github/workflows/latest_testing.yml b/.github/workflows/latest_testing.yml index 1df55913f049..aed4ca6b2195 100644 --- a/.github/workflows/latest_testing.yml +++ b/.github/workflows/latest_testing.yml @@ -62,6 +62,6 @@ jobs: pytest test/profile/ pytest test/sampler/ pytest test/testing/ - # pytest test/transforms/ + pytest test/transforms/ pytest test/utils/ pytest test/visualization/ diff --git a/CHANGELOG.md b/CHANGELOG.md index 073ba04f34f3..6b35aaf046e4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed +- Change `torch_sparse.SparseTensor` logic to utilize `torch.sparse_csr` instead ([#7041](https://github.com/pyg-team/pytorch_geometric/pull/7041)) + ### Removed ## [2.3.0] - 2023-03-23 diff --git a/test/transforms/test_add_metapaths.py b/test/transforms/test_add_metapaths.py index a3f2403bbc44..e16bd4ae322a 100644 --- a/test/transforms/test_add_metapaths.py +++ b/test/transforms/test_add_metapaths.py @@ -4,6 +4,7 @@ from torch import tensor from torch_geometric.data import HeteroData +from torch_geometric.testing import withPackage from torch_geometric.transforms import AddMetaPaths, AddRandomMetaPaths from torch_geometric.utils import coalesce @@ -21,6 +22,7 @@ def generate_data() -> HeteroData: return data +@withPackage('torch_sparse') def test_add_metapaths(): data = generate_data() # Test transform options: @@ -74,6 +76,7 @@ def test_add_metapaths(): assert list(meta.metapath_dict.keys()) == new_edge_types +@withPackage('torch_sparse') def test_add_metapaths_max_sample(): torch.manual_seed(12345) @@ -86,6 +89,7 @@ def test_add_metapaths_max_sample(): assert meta['metapath_0'].edge_index.size(1) < 9 +@withPackage('torch_sparse') def test_add_weighted_metapaths(): torch.manual_seed(12345) @@ -144,6 +148,7 @@ def test_add_weighted_metapaths(): assert edge_weight.tolist() == [1, 2, 2, 4] +@withPackage('torch_sparse') def test_add_random_metapaths(): data = generate_data() diff --git a/test/transforms/test_gcn_norm.py b/test/transforms/test_gcn_norm.py index 9eb2bd5debc5..cd92ad2a17b5 100644 --- a/test/transforms/test_gcn_norm.py +++ b/test/transforms/test_gcn_norm.py @@ -1,14 +1,14 @@ import torch -from torch_sparse import SparseTensor +import torch_geometric.typing from torch_geometric.data import Data from torch_geometric.transforms import GCNNorm +from torch_geometric.typing import SparseTensor def test_gcn_norm(): edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) edge_weight = torch.ones(edge_index.size(1)) - adj_t = SparseTensor.from_edge_index(edge_index, edge_weight).t() transform = GCNNorm() assert str(transform) == 'GCNNorm(add_self_loops=True)' @@ -32,14 +32,16 @@ def test_gcn_norm(): assert torch.allclose(data.edge_weight, expected_edge_weight, atol=1e-4) # For `SparseTensor`, expected outputs will be sorted: - expected_edge_index = [[0, 0, 1, 1, 1, 2, 2], [0, 1, 0, 1, 2, 1, 2]] - expected_edge_weight = torch.tensor( - [0.500, 0.4082, 0.4082, 0.3333, 0.4082, 0.4082, 0.5000]) - - data = Data(adj_t=adj_t) - data = transform(data) - assert len(data) == 1 - row, col, value = data.adj_t.coo() - assert row.tolist() == expected_edge_index[0] - assert col.tolist() == expected_edge_index[1] - assert torch.allclose(value, expected_edge_weight, atol=1e-4) + if torch_geometric.typing.WITH_TORCH_SPARSE: + expected_edge_index = [[0, 0, 1, 1, 1, 2, 2], [0, 1, 0, 1, 2, 1, 2]] + expected_edge_weight = torch.tensor( + [0.500, 0.4082, 0.4082, 0.3333, 0.4082, 0.4082, 0.5000]) + + adj_t = SparseTensor.from_edge_index(edge_index, edge_weight).t() + data = Data(adj_t=adj_t) + data = transform(data) + assert len(data) == 1 + row, col, value = data.adj_t.coo() + assert row.tolist() == expected_edge_index[0] + assert col.tolist() == expected_edge_index[1] + assert torch.allclose(value, expected_edge_weight, atol=1e-4) diff --git a/test/transforms/test_to_sparse_tensor.py b/test/transforms/test_to_sparse_tensor.py index e0ba20b13fc2..b2c85016f785 100644 --- a/test/transforms/test_to_sparse_tensor.py +++ b/test/transforms/test_to_sparse_tensor.py @@ -1,6 +1,7 @@ import pytest import torch +import torch_geometric.typing from torch_geometric.data import Data, HeteroData from torch_geometric.transforms import ToSparseTensor @@ -26,14 +27,14 @@ def test_to_sparse_tensor_basic(layout): assert torch.equal(data.edge_attr, edge_attr[perm]) assert 'adj_t' in data - if layout is None: # `torch_sparse.SparseTensor`. + if layout is None and torch_geometric.typing.WITH_TORCH_SPARSE: row, col, value = data.adj_t.coo() assert row.tolist() == [0, 1, 1, 2] assert col.tolist() == [1, 0, 2, 1] assert torch.equal(value, edge_weight[perm]) else: adj_t = data.adj_t - assert adj_t.layout == layout + assert adj_t.layout == layout or torch.sparse_csr if layout != torch.sparse_coo: adj_t = adj_t.to_sparse_coo() assert adj_t.indices().tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]] @@ -69,7 +70,7 @@ def test_hetero_to_sparse_tensor(layout): data = ToSparseTensor(layout=layout)(data) - if layout is None: # `torch_sparse.SparseTensor`. + if layout is None and torch_geometric.typing.WITH_TORCH_SPARSE: row, col, value = data['v', 'v'].adj_t.coo() assert row.tolist() == [0, 1, 1, 2] assert col.tolist() == [1, 0, 2, 1] @@ -81,14 +82,14 @@ def test_hetero_to_sparse_tensor(layout): assert value is None else: adj_t = data['v', 'v'].adj_t - assert adj_t.layout == layout + assert adj_t.layout == layout or torch.sparse_csr if layout != torch.sparse_coo: adj_t = adj_t.to_sparse_coo() assert adj_t.indices().tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]] assert adj_t.values().tolist() == [1., 1., 1., 1.] adj_t = data['v', 'w'].adj_t - assert adj_t.layout == layout + assert adj_t.layout == layout or torch.sparse_csr if layout != torch.sparse_coo: adj_t = adj_t.to_sparse_coo() assert adj_t.indices().tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]] diff --git a/torch_geometric/transforms/add_positional_encoding.py b/torch_geometric/transforms/add_positional_encoding.py index 6229f0d7f527..3fecc169b3e9 100644 --- a/torch_geometric/transforms/add_positional_encoding.py +++ b/torch_geometric/transforms/add_positional_encoding.py @@ -6,11 +6,13 @@ from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform -from torch_geometric.typing import SparseTensor from torch_geometric.utils import ( get_laplacian, get_self_loop_attr, + scatter, + to_edge_index, to_scipy_sparse_matrix, + to_torch_csr_tensor, ) @@ -116,24 +118,22 @@ def __init__( self.attr_name = attr_name def __call__(self, data: Data) -> Data: - num_nodes = data.num_nodes - edge_index, edge_weight = data.edge_index, data.edge_weight + row, col = data.edge_index + N = data.num_nodes - adj = SparseTensor.from_edge_index(edge_index, edge_weight, - sparse_sizes=(num_nodes, num_nodes)) + value = data.edge_weight + if value is None: + value = torch.ones(data.num_edges, device=row.device) + value = scatter(value, row, dim_size=N, reduce='sum').clamp(min=1)[row] + value = 1.0 / value - # Compute D^{-1} A: - deg_inv = 1.0 / adj.sum(dim=1) - deg_inv[deg_inv == float('inf')] = 0 - adj = adj * deg_inv.view(-1, 1) + adj = to_torch_csr_tensor(data.edge_index, value, size=data.size()) out = adj - row, col, value = out.coo() - pe_list = [get_self_loop_attr((row, col), value, num_nodes)] + pe_list = [get_self_loop_attr(*to_edge_index(out), num_nodes=N)] for _ in range(self.walk_length - 1): out = out @ adj - row, col, value = out.coo() - pe_list.append(get_self_loop_attr((row, col), value, num_nodes)) + pe_list.append(get_self_loop_attr(*to_edge_index(out), N)) pe = torch.stack(pe_list, dim=-1) data = add_node_attr(data, pe, attr_name=self.attr_name) diff --git a/torch_geometric/transforms/feature_propagation.py b/torch_geometric/transforms/feature_propagation.py index 426fcdf0ebf2..5a9f89b6c1e7 100644 --- a/torch_geometric/transforms/feature_propagation.py +++ b/torch_geometric/transforms/feature_propagation.py @@ -4,7 +4,7 @@ from torch_geometric.data.datapipes import functional_transform from torch_geometric.nn.conv.gcn_conv import gcn_norm from torch_geometric.transforms import BaseTransform -from torch_geometric.typing import SparseTensor +from torch_geometric.utils import is_torch_sparse_tensor, to_torch_csc_tensor @functional_transform('feature_propagation') @@ -52,14 +52,13 @@ def __call__(self, data: Data) -> Data: if 'edge_weight' in data: edge_weight = data.edge_weight edge_index = data.edge_index - adj_t = SparseTensor(row=edge_index[1], col=edge_index[0], - value=edge_weight, - sparse_sizes=data.size()[::-1], - is_sorted=False, trust_data=True) + adj_t = to_torch_csc_tensor(edge_index, edge_weight, + size=data.size()).t() + adj_t, _ = gcn_norm(adj_t, add_self_loops=False) + elif is_torch_sparse_tensor(data.adj_t): + adj_t, _ = gcn_norm(data.adj_t, add_self_loops=False) else: - adj_t = data.adj_t - - adj_t = gcn_norm(adj_t, add_self_loops=False) + adj_t = gcn_norm(data.adj_t, add_self_loops=False) x = data.x.clone() x[missing_mask] = 0. diff --git a/torch_geometric/transforms/rooted_subgraph.py b/torch_geometric/transforms/rooted_subgraph.py index 5f91fac1e47f..fde7e2eb4f4f 100644 --- a/torch_geometric/transforms/rooted_subgraph.py +++ b/torch_geometric/transforms/rooted_subgraph.py @@ -7,7 +7,7 @@ from torch_geometric.data import Data from torch_geometric.transforms import BaseTransform -from torch_geometric.typing import SparseTensor +from torch_geometric.utils import to_torch_csc_tensor class RootedSubgraphData(Data): @@ -116,11 +116,7 @@ def extract( data: Data, ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: - adj_t = SparseTensor.from_edge_index( - data.edge_index, - sparse_sizes=(data.num_nodes, data.num_nodes), - ).t() - + adj_t = to_torch_csc_tensor(data.edge_index, size=data.size()).t() n_mask = torch.eye(data.num_nodes, device=data.edge_index.device) for _ in range(self.num_hops): n_mask += adj_t @ n_mask diff --git a/torch_geometric/transforms/sign.py b/torch_geometric/transforms/sign.py index 819d6eaad8c9..c9e0d9e0b6be 100644 --- a/torch_geometric/transforms/sign.py +++ b/torch_geometric/transforms/sign.py @@ -3,7 +3,7 @@ from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform -from torch_geometric.typing import SparseTensor +from torch_geometric.utils import scatter, to_torch_csc_tensor @functional_transform('sign') @@ -37,13 +37,18 @@ def __init__(self, K: int): def __call__(self, data: Data) -> Data: assert data.edge_index is not None row, col = data.edge_index - adj_t = SparseTensor(row=col, col=row, - sparse_sizes=(data.num_nodes, data.num_nodes)) - - deg = adj_t.sum(dim=1).to(torch.float) - deg_inv_sqrt = deg.pow(-0.5) - deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 - adj_t = deg_inv_sqrt.view(-1, 1) * adj_t * deg_inv_sqrt.view(1, -1) + N = data.num_nodes + + edge_weight = data.edge_weight + if edge_weight is None: + edge_weight = torch.ones(data.num_edges, device=row.device) + + deg = scatter(edge_weight, col, dim_size=N, reduce='sum') + deg_inv_sqrt = deg.pow_(-0.5) + deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0) + edge_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] + adj = to_torch_csc_tensor(data.edge_index, edge_weight, size=(N, N)) + adj_t = adj.t() assert data.x is not None xs = [data.x] diff --git a/torch_geometric/transforms/two_hop.py b/torch_geometric/transforms/two_hop.py index a05792ec1beb..10f08ad5b37f 100644 --- a/torch_geometric/transforms/two_hop.py +++ b/torch_geometric/transforms/two_hop.py @@ -3,8 +3,12 @@ from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform -from torch_geometric.typing import SparseTensor -from torch_geometric.utils import coalesce, remove_self_loops +from torch_geometric.utils import ( + coalesce, + remove_self_loops, + to_edge_index, + to_torch_csr_tensor, +) @functional_transform('two_hop') @@ -15,11 +19,8 @@ def __call__(self, data: Data) -> Data: edge_index, edge_attr = data.edge_index, data.edge_attr N = data.num_nodes - adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=(N, N)) - - adj = adj @ adj - row, col, _ = adj.coo() - edge_index2 = torch.stack([row, col], dim=0) + adj = to_torch_csr_tensor(edge_index, size=(N, N)) + edge_index2, _ = to_edge_index(adj @ adj) edge_index2, _ = remove_self_loops(edge_index2) edge_index = torch.cat([edge_index, edge_index2], dim=1)