Drop torch_sparse dependency in tests (1/n) (#7041)
rusty1s authored Mar 26, 2023
1 parent c78c5b2 commit 0ea43bb
Showing 10 changed files with 72 additions and 61 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/latest_testing.yml
@@ -62,6 +62,6 @@ jobs:
pytest test/profile/
pytest test/sampler/
pytest test/testing/
-# pytest test/transforms/
+pytest test/transforms/
pytest test/utils/
pytest test/visualization/
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -9,6 +9,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

+- Changed `torch_sparse.SparseTensor` logic to utilize `torch.sparse_csr` instead ([#7041](https://github.com/pyg-team/pytorch_geometric/pull/7041))

### Removed

## [2.3.0] - 2023-03-23
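For reference, the migration pattern behind this changelog entry, applied throughout the PR — a minimal sketch (the 3-node toy graph is illustrative, not from the commit):

```python
import torch

from torch_geometric.utils import to_edge_index, to_torch_csr_tensor

edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
edge_weight = torch.ones(edge_index.size(1))

# Before: torch_sparse.SparseTensor.from_edge_index(edge_index, edge_weight,
#                                                   sparse_sizes=(3, 3))
# After: a plain PyTorch tensor with a sparse CSR layout.
adj = to_torch_csr_tensor(edge_index, edge_weight, size=(3, 3))
assert adj.layout == torch.sparse_csr

# The `row, col, value = adj.coo()` idiom becomes:
edge_index2, edge_weight2 = to_edge_index(adj)
```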
5 changes: 5 additions & 0 deletions test/transforms/test_add_metapaths.py
@@ -4,6 +4,7 @@
from torch import tensor

from torch_geometric.data import HeteroData
+from torch_geometric.testing import withPackage
from torch_geometric.transforms import AddMetaPaths, AddRandomMetaPaths
from torch_geometric.utils import coalesce

@@ -21,6 +22,7 @@ def generate_data() -> HeteroData:
return data


+@withPackage('torch_sparse')
def test_add_metapaths():
data = generate_data()
# Test transform options:
@@ -74,6 +76,7 @@ def test_add_metapaths():
assert list(meta.metapath_dict.keys()) == new_edge_types


+@withPackage('torch_sparse')
def test_add_metapaths_max_sample():
torch.manual_seed(12345)

@@ -86,6 +89,7 @@ def test_add_metapaths_max_sample():
assert meta['metapath_0'].edge_index.size(1) < 9


+@withPackage('torch_sparse')
def test_add_weighted_metapaths():
torch.manual_seed(12345)

@@ -144,6 +148,7 @@ def test_add_weighted_metapaths():
assert edge_weight.tolist() == [1, 2, 2, 4]


+@withPackage('torch_sparse')
def test_add_random_metapaths():
data = generate_data()

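The gating idiom used above, in isolation — a sketch assuming a hypothetical test body:

```python
from torch_geometric.testing import withPackage


@withPackage('torch_sparse')  # Test is skipped when `torch_sparse` is missing.
def test_needs_torch_sparse():
    from torch_sparse import SparseTensor  # Safe to import inside the test.
    assert SparseTensor is not None
```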
28 changes: 15 additions & 13 deletions test/transforms/test_gcn_norm.py
@@ -1,14 +1,14 @@
import torch
-from torch_sparse import SparseTensor

+import torch_geometric.typing
from torch_geometric.data import Data
from torch_geometric.transforms import GCNNorm
+from torch_geometric.typing import SparseTensor


def test_gcn_norm():
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
edge_weight = torch.ones(edge_index.size(1))
-    adj_t = SparseTensor.from_edge_index(edge_index, edge_weight).t()

transform = GCNNorm()
assert str(transform) == 'GCNNorm(add_self_loops=True)'
@@ -32,14 +32,16 @@ def test_gcn_norm():
assert torch.allclose(data.edge_weight, expected_edge_weight, atol=1e-4)

# For `SparseTensor`, expected outputs will be sorted:
-    expected_edge_index = [[0, 0, 1, 1, 1, 2, 2], [0, 1, 0, 1, 2, 1, 2]]
-    expected_edge_weight = torch.tensor(
-        [0.500, 0.4082, 0.4082, 0.3333, 0.4082, 0.4082, 0.5000])
-
-    data = Data(adj_t=adj_t)
-    data = transform(data)
-    assert len(data) == 1
-    row, col, value = data.adj_t.coo()
-    assert row.tolist() == expected_edge_index[0]
-    assert col.tolist() == expected_edge_index[1]
-    assert torch.allclose(value, expected_edge_weight, atol=1e-4)
+    if torch_geometric.typing.WITH_TORCH_SPARSE:
+        expected_edge_index = [[0, 0, 1, 1, 1, 2, 2], [0, 1, 0, 1, 2, 1, 2]]
+        expected_edge_weight = torch.tensor(
+            [0.500, 0.4082, 0.4082, 0.3333, 0.4082, 0.4082, 0.5000])
+
+        adj_t = SparseTensor.from_edge_index(edge_index, edge_weight).t()
+        data = Data(adj_t=adj_t)
+        data = transform(data)
+        assert len(data) == 1
+        row, col, value = data.adj_t.coo()
+        assert row.tolist() == expected_edge_index[0]
+        assert col.tolist() == expected_edge_index[1]
+        assert torch.allclose(value, expected_edge_weight, atol=1e-4)
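The expected values in this test follow from symmetric GCN normalization, D^{-1/2}(A + I)D^{-1/2}. A quick way to reproduce them by hand (a sketch, not part of the commit):

```python
import torch

from torch_geometric.utils import add_self_loops, scatter

edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
edge_index, _ = add_self_loops(edge_index, num_nodes=3)
row, col = edge_index

deg = scatter(torch.ones(row.numel()), row, dim_size=3, reduce='sum')  # [2, 3, 2]
weight = deg[row].pow(-0.5) * deg[col].pow(-0.5)
# Self-loop (0, 0): 1/2 = 0.5; edge (0, 1): 1/sqrt(2 * 3) ≈ 0.4082;
# self-loop (1, 1): 1/3 ≈ 0.3333 -- matching `expected_edge_weight` above.
```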
11 changes: 6 additions & 5 deletions test/transforms/test_to_sparse_tensor.py
@@ -1,6 +1,7 @@
import pytest
import torch

+import torch_geometric.typing
from torch_geometric.data import Data, HeteroData
from torch_geometric.transforms import ToSparseTensor

@@ -26,14 +27,14 @@ def test_to_sparse_tensor_basic(layout):
assert torch.equal(data.edge_attr, edge_attr[perm])
assert 'adj_t' in data

-    if layout is None:  # `torch_sparse.SparseTensor`.
+    if layout is None and torch_geometric.typing.WITH_TORCH_SPARSE:
row, col, value = data.adj_t.coo()
assert row.tolist() == [0, 1, 1, 2]
assert col.tolist() == [1, 0, 2, 1]
assert torch.equal(value, edge_weight[perm])
else:
adj_t = data.adj_t
-        assert adj_t.layout == layout
+        assert adj_t.layout == (layout or torch.sparse_csr)
if layout != torch.sparse_coo:
adj_t = adj_t.to_sparse_coo()
assert adj_t.indices().tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]]
@@ -69,7 +70,7 @@ def test_hetero_to_sparse_tensor(layout):

data = ToSparseTensor(layout=layout)(data)

-    if layout is None:  # `torch_sparse.SparseTensor`.
+    if layout is None and torch_geometric.typing.WITH_TORCH_SPARSE:
row, col, value = data['v', 'v'].adj_t.coo()
assert row.tolist() == [0, 1, 1, 2]
assert col.tolist() == [1, 0, 2, 1]
@@ -81,14 +82,14 @@ def test_hetero_to_sparse_tensor(layout):
assert value is None
else:
adj_t = data['v', 'v'].adj_t
-        assert adj_t.layout == layout
+        assert adj_t.layout == (layout or torch.sparse_csr)
if layout != torch.sparse_coo:
adj_t = adj_t.to_sparse_coo()
assert adj_t.indices().tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]]
assert adj_t.values().tolist() == [1., 1., 1., 1.]

adj_t = data['v', 'w'].adj_t
-        assert adj_t.layout == layout
+        assert adj_t.layout == (layout or torch.sparse_csr)
if layout != torch.sparse_coo:
adj_t = adj_t.to_sparse_coo()
assert adj_t.indices().tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]]
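For reference, the fallback these tests now encode: `layout=None` yields a `torch_sparse.SparseTensor` only when the package is installed, and a `torch.sparse_csr` tensor otherwise. A minimal sketch (the toy graph is illustrative):

```python
import torch

from torch_geometric.data import Data
from torch_geometric.transforms import ToSparseTensor

edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
data = Data(edge_index=edge_index, num_nodes=3)

# `layout=None`: `torch_sparse.SparseTensor` if available, else CSR.
data = ToSparseTensor()(data)

# An explicit layout side-steps the optional dependency entirely:
# data = ToSparseTensor(layout=torch.sparse_csr)(data)
```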
26 changes: 13 additions & 13 deletions torch_geometric/transforms/add_positional_encoding.py
@@ -6,11 +6,13 @@
from torch_geometric.data import Data
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform
-from torch_geometric.typing import SparseTensor
from torch_geometric.utils import (
get_laplacian,
get_self_loop_attr,
+    scatter,
+    to_edge_index,
to_scipy_sparse_matrix,
+    to_torch_csr_tensor,
)


@@ -116,24 +118,22 @@ def __init__(
self.attr_name = attr_name

def __call__(self, data: Data) -> Data:
-        num_nodes = data.num_nodes
-        edge_index, edge_weight = data.edge_index, data.edge_weight
+        row, col = data.edge_index
+        N = data.num_nodes

-        adj = SparseTensor.from_edge_index(edge_index, edge_weight,
-                                           sparse_sizes=(num_nodes, num_nodes))
+        value = data.edge_weight
+        if value is None:
+            value = torch.ones(data.num_edges, device=row.device)
+        value = scatter(value, row, dim_size=N, reduce='sum').clamp(min=1)[row]
+        value = 1.0 / value

-        # Compute D^{-1} A:
-        deg_inv = 1.0 / adj.sum(dim=1)
-        deg_inv[deg_inv == float('inf')] = 0
-        adj = adj * deg_inv.view(-1, 1)
+        adj = to_torch_csr_tensor(data.edge_index, value, size=data.size())

out = adj
-        row, col, value = out.coo()
-        pe_list = [get_self_loop_attr((row, col), value, num_nodes)]
+        pe_list = [get_self_loop_attr(*to_edge_index(out), num_nodes=N)]
for _ in range(self.walk_length - 1):
out = out @ adj
-            row, col, value = out.coo()
-            pe_list.append(get_self_loop_attr((row, col), value, num_nodes))
+            pe_list.append(get_self_loop_attr(*to_edge_index(out), N))
pe = torch.stack(pe_list, dim=-1)

data = add_node_attr(data, pe, attr_name=self.attr_name)
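Condensed, the new random-walk PE computation above: row-normalize the adjacency into D^{-1}A with a single scatter, then read off the diagonal of each power. A standalone sketch on a toy graph (graph and walk length are illustrative):

```python
import torch

from torch_geometric.utils import (
    get_self_loop_attr,
    scatter,
    to_edge_index,
    to_torch_csr_tensor,
)

edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
row, _ = edge_index
N, walk_length = 3, 4

# value[e] = 1 / deg(row[e]) turns the CSR matrix into D^{-1} A.
value = torch.ones(edge_index.size(1))
value = scatter(value, row, dim_size=N, reduce='sum').clamp(min=1)[row]
value = 1.0 / value

adj = to_torch_csr_tensor(edge_index, value, size=(N, N))

# PE entry k holds diag((D^{-1} A)^(k+1)), the (k+1)-step return probabilities.
out = adj
pe_list = [get_self_loop_attr(*to_edge_index(out), num_nodes=N)]
for _ in range(walk_length - 1):
    out = out @ adj
    pe_list.append(get_self_loop_attr(*to_edge_index(out), num_nodes=N))
pe = torch.stack(pe_list, dim=-1)  # Shape: [N, walk_length]
```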
15 changes: 7 additions & 8 deletions torch_geometric/transforms/feature_propagation.py
@@ -4,7 +4,7 @@
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_geometric.transforms import BaseTransform
-from torch_geometric.typing import SparseTensor
+from torch_geometric.utils import is_torch_sparse_tensor, to_torch_csc_tensor


@functional_transform('feature_propagation')
@@ -52,14 +52,13 @@ def __call__(self, data: Data) -> Data:
if 'edge_weight' in data:
edge_weight = data.edge_weight
edge_index = data.edge_index
-            adj_t = SparseTensor(row=edge_index[1], col=edge_index[0],
-                                 value=edge_weight,
-                                 sparse_sizes=data.size()[::-1],
-                                 is_sorted=False, trust_data=True)
+            adj_t = to_torch_csc_tensor(edge_index, edge_weight,
+                                        size=data.size()).t()
+            adj_t, _ = gcn_norm(adj_t, add_self_loops=False)
+        elif is_torch_sparse_tensor(data.adj_t):
+            adj_t, _ = gcn_norm(data.adj_t, add_self_loops=False)
else:
-            adj_t = data.adj_t
-
-        adj_t = gcn_norm(adj_t, add_self_loops=False)
+            adj_t = gcn_norm(data.adj_t, add_self_loops=False)

x = data.x.clone()
x[missing_mask] = 0.
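The replacement construction in a nutshell: transposing a CSC tensor yields the transposed adjacency in row-compressed (CSR) form, which `gcn_norm` then normalizes. A sketch under that assumption, on a toy graph:

```python
import torch

from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_geometric.utils import to_torch_csc_tensor

edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
edge_weight = torch.ones(edge_index.size(1))

# Transposing the CSC tensor gives A^T without re-sorting edges:
adj_t = to_torch_csc_tensor(edge_index, edge_weight, size=(3, 3)).t()
adj_t, _ = gcn_norm(adj_t, add_self_loops=False)
```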
8 changes: 2 additions & 6 deletions torch_geometric/transforms/rooted_subgraph.py
@@ -7,7 +7,7 @@

from torch_geometric.data import Data
from torch_geometric.transforms import BaseTransform
-from torch_geometric.typing import SparseTensor
+from torch_geometric.utils import to_torch_csc_tensor


class RootedSubgraphData(Data):
@@ -116,11 +116,7 @@ def extract(
data: Data,
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:

-        adj_t = SparseTensor.from_edge_index(
-            data.edge_index,
-            sparse_sizes=(data.num_nodes, data.num_nodes),
-        ).t()
-
+        adj_t = to_torch_csc_tensor(data.edge_index, size=data.size()).t()
n_mask = torch.eye(data.num_nodes, device=data.edge_index.device)
for _ in range(self.num_hops):
n_mask += adj_t @ n_mask
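The k-hop extraction now rides on sparse-dense matmul alone. A sketch of the reachability accumulation on a toy graph (graph and hop count are illustrative):

```python
import torch

from torch_geometric.utils import to_torch_csc_tensor

edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
N, num_hops = 3, 2

adj_t = to_torch_csc_tensor(edge_index, size=(N, N)).t()

# n_mask[i, j] becomes nonzero once node i is reachable from root j
# within `num_hops` steps; column j is then the rooted subgraph of j.
n_mask = torch.eye(N)
for _ in range(num_hops):
    n_mask += adj_t @ n_mask
```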
21 changes: 13 additions & 8 deletions torch_geometric/transforms/sign.py
@@ -3,7 +3,7 @@
from torch_geometric.data import Data
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform
-from torch_geometric.typing import SparseTensor
+from torch_geometric.utils import scatter, to_torch_csc_tensor


@functional_transform('sign')
@@ -37,13 +37,18 @@ def __init__(self, K: int):
def __call__(self, data: Data) -> Data:
assert data.edge_index is not None
row, col = data.edge_index
-        adj_t = SparseTensor(row=col, col=row,
-                             sparse_sizes=(data.num_nodes, data.num_nodes))

-        deg = adj_t.sum(dim=1).to(torch.float)
-        deg_inv_sqrt = deg.pow(-0.5)
-        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
-        adj_t = deg_inv_sqrt.view(-1, 1) * adj_t * deg_inv_sqrt.view(1, -1)
+        N = data.num_nodes

+        edge_weight = data.edge_weight
+        if edge_weight is None:
+            edge_weight = torch.ones(data.num_edges, device=row.device)

+        deg = scatter(edge_weight, col, dim_size=N, reduce='sum')
+        deg_inv_sqrt = deg.pow_(-0.5)
+        deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)
+        edge_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
+        adj = to_torch_csc_tensor(data.edge_index, edge_weight, size=(N, N))
+        adj_t = adj.t()

assert data.x is not None
xs = [data.x]
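The scatter-based rewrite computes the symmetrically normalized operator D^{-1/2} A D^{-1/2} on edge weights directly, then precomputes the propagated feature matrices. A self-contained sketch (the feature matrix `x` and K=2 are illustrative, not from the commit):

```python
import torch

from torch_geometric.utils import scatter, to_torch_csc_tensor

edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
row, col = edge_index
N, K = 3, 2

# Normalize edge weights: w'_{ij} = w_{ij} / sqrt(deg(i) * deg(j)).
edge_weight = torch.ones(edge_index.size(1))
deg = scatter(edge_weight, col, dim_size=N, reduce='sum')
deg_inv_sqrt = deg.pow_(-0.5)
deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)
edge_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

adj_t = to_torch_csc_tensor(edge_index, edge_weight, size=(N, N)).t()

# SIGN precomputation: xs[k] = (D^{-1/2} A D^{-1/2})^k X.
x = torch.randn(N, 8)
xs = [x]
for _ in range(K):
    xs.append(adj_t @ xs[-1])
```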
15 changes: 8 additions & 7 deletions torch_geometric/transforms/two_hop.py
@@ -3,8 +3,12 @@
from torch_geometric.data import Data
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform
-from torch_geometric.typing import SparseTensor
-from torch_geometric.utils import coalesce, remove_self_loops
+from torch_geometric.utils import (
+    coalesce,
+    remove_self_loops,
+    to_edge_index,
+    to_torch_csr_tensor,
+)


@functional_transform('two_hop')
@@ -15,11 +19,8 @@ def __call__(self, data: Data) -> Data:
edge_index, edge_attr = data.edge_index, data.edge_attr
N = data.num_nodes

-        adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=(N, N))
-
-        adj = adj @ adj
-        row, col, _ = adj.coo()
-        edge_index2 = torch.stack([row, col], dim=0)
+        adj = to_torch_csr_tensor(edge_index, size=(N, N))
+        edge_index2, _ = to_edge_index(adj @ adj)
edge_index2, _ = remove_self_loops(edge_index2)

edge_index = torch.cat([edge_index, edge_index2], dim=1)
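And the core of the `TwoHop` rewrite in isolation — squaring the CSR adjacency enumerates all length-2 paths using pure `torch.sparse` matmul. A sketch on a toy path graph:

```python
import torch

from torch_geometric.utils import (
    remove_self_loops,
    to_edge_index,
    to_torch_csr_tensor,
)

edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
N = 3

adj = to_torch_csr_tensor(edge_index, size=(N, N))

# (A @ A)[i, j] != 0 iff j is reachable from i in exactly two hops.
edge_index2, _ = to_edge_index(adj @ adj)
edge_index2, _ = remove_self_loops(edge_index2)
# For the path 0-1-2 this adds the pairs (0, 2) and (2, 0).
```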
