Skip to content

Commit

Permalink
Make torch-sparse dependency optional (part 4) (#6628)
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored Feb 7, 2023
1 parent ef2d7be commit bb827d8
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 15 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Removed

- `torch-sparse` is now an optional dependency ([#6625](https://github.com/pyg-team/pytorch_geometric/pull/6625), [#6626](https://github.com/pyg-team/pytorch_geometric/pull/6626), [#6627](https://github.com/pyg-team/pytorch_geometric/pull/6627))
- `torch-sparse` is now an optional dependency ([#6625](https://github.com/pyg-team/pytorch_geometric/pull/6625), [#6626](https://github.com/pyg-team/pytorch_geometric/pull/6626), [#6627](https://github.com/pyg-team/pytorch_geometric/pull/6627), [#6628](https://github.com/pyg-team/pytorch_geometric/pull/6628))
- Removed most of the `torch-scatter` dependencies ([#6394](https://github.com/pyg-team/pytorch_geometric/pull/6394), [#6395](https://github.com/pyg-team/pytorch_geometric/pull/6395), [#6399](https://github.com/pyg-team/pytorch_geometric/pull/6399), [#6400](https://github.com/pyg-team/pytorch_geometric/pull/6400), [#6615](https://github.com/pyg-team/pytorch_geometric/pull/6615), [#6617](https://github.com/pyg-team/pytorch_geometric/pull/6617))
- Removed the deprecated classes `GNNExplainer` and `Explainer` from `nn.models` ([#6382](https://github.com/pyg-team/pytorch_geometric/pull/6382))
- Removed `target_index` argument in the `Explainer` interface ([#6270](https://github.com/pyg-team/pytorch_geometric/pull/6270))
Expand Down
6 changes: 2 additions & 4 deletions torch_geometric/loader/shadow.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from torch import Tensor

from torch_geometric.data import Batch, Data
from torch_geometric.typing import SparseTensor
from torch_geometric.typing import WITH_TORCH_SPARSE, SparseTensor


class ShaDowKHopSampler(torch.utils.data.DataLoader):
Expand Down Expand Up @@ -40,9 +40,7 @@ def __init__(self, data: Data, depth: int, num_neighbors: int,
node_idx: Optional[Tensor] = None, replace: bool = False,
**kwargs):

try:
import torch_sparse # noqa
except ImportError:
if not WITH_TORCH_SPARSE:
raise ImportError(
f"'{self.__class__.__name__}' requires 'torch-sparse'")

Expand Down
3 changes: 1 addition & 2 deletions torch_geometric/nn/conv/message_passing.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@ from torch_geometric.typing import *

import torch
from torch import Tensor
import torch_sparse
from torch_geometric.typing import SparseTensor
from torch_geometric.typing import SparseTensor, torch_sparse
from torch_geometric.nn.conv.message_passing import *
from {{module}} import *

Expand Down
11 changes: 7 additions & 4 deletions torch_geometric/sampler/hgt_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@
NodeSamplerInput,
)
from torch_geometric.sampler.utils import remap_keys, to_hetero_csc
from torch_geometric.typing import EdgeType, NodeType, OptTensor
from torch_geometric.typing import (
WITH_TORCH_SPARSE,
EdgeType,
NodeType,
OptTensor,
)


class HGTSampler(BaseSampler):
Expand All @@ -22,9 +27,7 @@ def __init__(
is_sorted: bool = False,
share_memory: bool = False,
):
try:
import torch_sparse # noqa
except ImportError:
if not WITH_TORCH_SPARSE:
raise ImportError(
f"'{self.__class__.__name__}' requires 'torch-sparse'")

Expand Down
16 changes: 12 additions & 4 deletions torch_geometric/sampler/neighbor_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,13 +220,13 @@ def _sample(
batch = {k: v[0] for k, v in node.items()}
node = {k: v[1] for k, v in node.items()}

else:
elif torch_geometric.typing.WITH_TORCH_SPARSE:
if self.disjoint:
raise ValueError("'disjoint' sampling not supported for "
"neighbor sampling via 'torch-sparse'. "
"Please install 'pyg-lib' for improved "
"and optimized sampling routines.")
import torch_sparse # noqa

out = torch.ops.torch_sparse.hetero_neighbor_sample(
self.node_types,
self.edge_types,
Expand All @@ -240,6 +240,10 @@ def _sample(
)
node, row, col, edge, batch = out + (None, )

else:
raise ImportError(f"'{self.__class__.__name__}' requires "
f"either 'pyg-lib' or 'torch-sparse'")

return HeteroSamplerOutput(
node=node,
row=remap_keys(row, self.to_edge_type),
Expand Down Expand Up @@ -270,13 +274,13 @@ def _sample(
if self.disjoint:
batch, node = node.t().contiguous()

else:
elif torch_geometric.typing.WITH_TORCH_SPARSE:
if self.disjoint:
raise ValueError("'disjoint' sampling not supported for "
"neighbor sampling via 'torch-sparse'. "
"Please install 'pyg-lib' for improved "
"and optimized sampling routines.")
import torch_sparse # noqa

out = torch.ops.torch_sparse.neighbor_sample(
self.colptr,
self.row,
Expand All @@ -287,6 +291,10 @@ def _sample(
)
node, row, col, edge, batch = out + (None, )

else:
raise ImportError(f"'{self.__class__.__name__}' requires "
f"either 'pyg-lib' or 'torch-sparse'")

return SamplerOutput(
node=node,
row=row,
Expand Down

0 comments on commit bb827d8

Please sign in to comment.