Make torch-sparse dependency optional (part 3) (#6627)
rusty1s authored Feb 7, 2023
1 parent 447bfa1 commit ef2d7be
Showing 6 changed files with 27 additions and 33 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -113,7 +113,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Removed
 
-- `torch-sparse` is now an optional dependency ([#6625](https://github.com/pyg-team/pytorch_geometric/pull/6625), [#6626](https://github.com/pyg-team/pytorch_geometric/pull/6626))
+- `torch-sparse` is now an optional dependency ([#6625](https://github.com/pyg-team/pytorch_geometric/pull/6625), [#6626](https://github.com/pyg-team/pytorch_geometric/pull/6626), [#6627](https://github.com/pyg-team/pytorch_geometric/pull/6627))
 - Removed most of the `torch-scatter` dependencies ([#6394](https://github.com/pyg-team/pytorch_geometric/pull/6394), [#6395](https://github.com/pyg-team/pytorch_geometric/pull/6395), [#6399](https://github.com/pyg-team/pytorch_geometric/pull/6399), [#6400](https://github.com/pyg-team/pytorch_geometric/pull/6400), [#6615](https://github.com/pyg-team/pytorch_geometric/pull/6615), [#6617](https://github.com/pyg-team/pytorch_geometric/pull/6617))
 - Removed the deprecated classes `GNNExplainer` and `Explainer` from `nn.models` ([#6382](https://github.com/pyg-team/pytorch_geometric/pull/6382))
 - Removed `target_index` argument in the `Explainer` interface ([#6270](https://github.com/pyg-team/pytorch_geometric/pull/6270))
4 changes: 2 additions & 2 deletions docs/source/advanced/sparse_tensor.rst
@@ -113,7 +113,7 @@ With it, the :class:`~torch_geometric.nn.conv.GINConv` layer can now be implemen
 
 .. code-block:: python
 
-    from torch_sparse import matmul
+    import torch_sparse
 
     class GINConv(MessagePassing):
         def __init__(self):
@@ -127,7 +127,7 @@ With it, the :class:`~torch_geometric.nn.conv.GINConv` layer can now be implemen
             return x_j
 
         def message_and_aggregate(self, adj_t, x):
-            return matmul(adj_t, x, reduce=self.aggr)
+            return torch_sparse.matmul(adj_t, x, reduce=self.aggr)
 
 Playing around with the new :class:`SparseTensor` format is straightforward since all of our GNNs work with it out-of-the-box.
 To convert the :obj:`edge_index` format to the newly introduced :class:`SparseTensor` format, you can make use of the :class:`torch_geometric.transforms.ToSparseTensor` transform:
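For context, the ToSparseTensor transform referenced above converts a graph's edge_index into the SparseTensor layout at load time. A minimal sketch, not part of this commit (the Planetoid/Cora dataset and root path are only illustrative):

    import torch_geometric.transforms as T
    from torch_geometric.datasets import Planetoid

    # Store the adjacency as data.adj_t (a SparseTensor) instead of edge_index:
    dataset = Planetoid(root="/tmp/Cora", name="Cora",
                        pre_transform=T.ToSparseTensor())
    data = dataset[0]
    print(data.adj_t)  # SparseTensor of size [num_nodes, num_nodes]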
19 changes: 7 additions & 12 deletions torch_geometric/nn/models/graph_unet.py
@@ -3,15 +3,10 @@
 import torch
 import torch.nn.functional as F
 from torch import Tensor
-from torch_sparse import spspmm
 
 from torch_geometric.nn import GCNConv, TopKPooling
-from torch_geometric.typing import OptTensor, PairTensor
-from torch_geometric.utils import (
-    add_self_loops,
-    remove_self_loops,
-    sort_edge_index,
-)
+from torch_geometric.typing import OptTensor, PairTensor, SparseTensor
+from torch_geometric.utils import add_self_loops, remove_self_loops
 from torch_geometric.utils.repeat import repeat
@@ -131,11 +126,11 @@ def augment_adj(self, edge_index: Tensor, edge_weight: Tensor,
         edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
         edge_index, edge_weight = add_self_loops(edge_index, edge_weight,
                                                  num_nodes=num_nodes)
-        edge_index, edge_weight = sort_edge_index(edge_index, edge_weight,
-                                                  num_nodes)
-        edge_index, edge_weight = spspmm(edge_index, edge_weight, edge_index,
-                                         edge_weight, num_nodes, num_nodes,
-                                         num_nodes)
+        adj = SparseTensor.from_edge_index(edge_index, edge_weight,
+                                           sparse_sizes=(num_nodes, num_nodes))
+        adj = adj @ adj
+        row, col, edge_weight = adj.coo()
+        edge_index = torch.stack([row, col], dim=0)
         edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
         return edge_index, edge_weight
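The replacement pattern in augment_adj (build a SparseTensor from the edge list, square it with a sparse-sparse matmul, and read the two-hop result back as COO triplets) can be tried in isolation. A minimal sketch, assuming the torch-sparse backend for SparseTensor is installed:

    import torch
    from torch_geometric.typing import SparseTensor

    edge_index = torch.tensor([[0, 1, 1, 2],
                               [1, 0, 2, 1]])
    edge_weight = torch.ones(edge_index.size(1))

    adj = SparseTensor.from_edge_index(edge_index, edge_weight,
                                       sparse_sizes=(3, 3))
    adj2 = adj @ adj              # sparse-sparse matmul, replacing spspmm
    row, col, value = adj2.coo()  # back to COO triplets
    edge_index2 = torch.stack([row, col], dim=0)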
4 changes: 2 additions & 2 deletions torch_geometric/nn/models/signed_gcn.py
@@ -4,10 +4,10 @@
 import torch
 import torch.nn.functional as F
 from torch import Tensor
-from torch_sparse import coalesce
 
 from torch_geometric.nn import SignedConv
 from torch_geometric.utils import (
+    coalesce,
     negative_sampling,
     structured_negative_sampling,
 )
@@ -111,7 +111,7 @@ def create_spectral_features(
         edge_index = torch.cat([edge_index, torch.stack([col, row])], dim=1)
         val = torch.cat([val, val], dim=0)
 
-        edge_index, val = coalesce(edge_index, val, N, N)
+        edge_index, val = coalesce(edge_index, val, num_nodes=N)
         val = val - 1
 
         # Borrowed from:
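Note that torch_geometric.utils.coalesce takes a single num_nodes argument in place of torch_sparse's separate (m, n) sizes. A quick sketch of its behavior: duplicate edges are sorted and merged, with their values summed by default.

    import torch
    from torch_geometric.utils import coalesce

    # The edge 0 -> 1 appears twice and gets merged:
    edge_index = torch.tensor([[0, 0, 1],
                               [1, 1, 0]])
    val = torch.tensor([1.0, 2.0, 1.0])

    edge_index, val = coalesce(edge_index, val, num_nodes=2)
    # edge_index == [[0, 1], [1, 0]], val == [3., 1.]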
6 changes: 2 additions & 4 deletions torch_geometric/nn/pool/pool.py
@@ -1,18 +1,16 @@
 from typing import Optional
 
 import torch
-from torch_sparse import coalesce
 
-from torch_geometric.utils import remove_self_loops, scatter
+from torch_geometric.utils import coalesce, remove_self_loops, scatter
 
 
 def pool_edge(cluster, edge_index, edge_attr: Optional[torch.Tensor] = None):
     num_nodes = cluster.size(0)
     edge_index = cluster[edge_index.view(-1)].view(2, -1)
     edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)
     if edge_index.numel() > 0:
-        edge_index, edge_attr = coalesce(edge_index, edge_attr, num_nodes,
-                                         num_nodes)
+        edge_index, edge_attr = coalesce(edge_index, edge_attr, num_nodes)
     return edge_index, edge_attr
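pool_edge relabels both endpoints of each edge by their cluster assignment, drops the resulting self-loops, and coalesces duplicates. A hypothetical call for illustration (the cluster assignment below is made up):

    import torch
    from torch_geometric.nn.pool.pool import pool_edge

    # Nodes 0 and 1 collapse into cluster 0; nodes 2 and 3 into cluster 1:
    cluster = torch.tensor([0, 0, 1, 1])
    edge_index = torch.tensor([[0, 1, 1, 2],
                               [1, 0, 2, 3]])

    edge_index, edge_attr = pool_edge(cluster, edge_index)
    # Intra-cluster edges become self-loops and are removed; only the
    # coarsened edge 0 -> 1 remains.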
25 changes: 13 additions & 12 deletions torch_geometric/transforms/two_hop.py
@@ -3,6 +3,7 @@
 from torch_geometric.data import Data
 from torch_geometric.data.datapipes import functional_transform
 from torch_geometric.transforms import BaseTransform
+from torch_geometric.typing import SparseTensor
 from torch_geometric.utils import coalesce, remove_self_loops
@@ -11,25 +12,25 @@ class TwoHop(BaseTransform):
     r"""Adds the two hop edges to the edge indices
     (functional name: :obj:`two_hop`)."""
     def __call__(self, data: Data) -> Data:
-        from torch_sparse import spspmm
-
         edge_index, edge_attr = data.edge_index, data.edge_attr
         N = data.num_nodes
 
-        value = edge_index.new_ones((edge_index.size(1), ), dtype=torch.float)
+        adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=(N, N))
 
-        index, value = spspmm(edge_index, value, edge_index, value, N, N, N)
-        value.fill_(0)
-        index, value = remove_self_loops(index, value)
+        adj = adj @ adj
+        row, col, _ = adj.coo()
+        edge_index2 = torch.stack([row, col], dim=0)
+        edge_index2, _ = remove_self_loops(edge_index2)
 
-        edge_index = torch.cat([edge_index, index], dim=1)
+        edge_index = torch.cat([edge_index, edge_index2], dim=1)
         if edge_attr is None:
             data.edge_index = coalesce(edge_index, num_nodes=N)
         else:
-            value = value.view(-1, *[1 for _ in range(edge_attr.dim() - 1)])
-            value = value.expand(-1, *list(edge_attr.size())[1:])
-            edge_attr = torch.cat([edge_attr, value], dim=0)
-            data.edge_index, edge_attr = coalesce(edge_index, edge_attr, N)
-            data.edge_attr = edge_attr
+            # We treat newly added edge features as "zero-features":
+            edge_attr2 = edge_attr.new_zeros(edge_index2.size(1),
+                                             *edge_attr.size()[1:])
+            edge_attr = torch.cat([edge_attr, edge_attr2], dim=0)
+            edge_index, edge_attr = coalesce(edge_index, edge_attr, N)
+            data.edge_index, data.edge_attr = edge_index, edge_attr
 
         return data
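The rewritten transform can be exercised end to end; a small usage sketch:

    import torch
    from torch_geometric.data import Data
    from torch_geometric.transforms import TwoHop

    edge_index = torch.tensor([[0, 1, 1, 2],
                               [1, 0, 2, 1]])
    data = Data(edge_index=edge_index, num_nodes=3)

    data = TwoHop()(data)
    # data.edge_index now also contains the two-hop edges 0 <-> 2.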
