From bb827d821288949f8e8971534592a90c954cdee2 Mon Sep 17 00:00:00 2001 From: Matthias Fey Date: Tue, 7 Feb 2023 10:39:49 +0100 Subject: [PATCH] Make `torch-sparse` dependency optional (part 4) (#6628) --- CHANGELOG.md | 2 +- torch_geometric/loader/shadow.py | 6 ++---- torch_geometric/nn/conv/message_passing.jinja | 3 +-- torch_geometric/sampler/hgt_sampler.py | 11 +++++++---- torch_geometric/sampler/neighbor_sampler.py | 16 ++++++++++++---- 5 files changed, 23 insertions(+), 15 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 45d299bf2c6d..5a733bb6dd0d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -113,7 +113,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Removed -- `torch-sparse` is now an optional dependency ([#6625](https://github.com/pyg-team/pytorch_geometric/pull/6625), [#6626](https://github.com/pyg-team/pytorch_geometric/pull/6626), [#6627](https://github.com/pyg-team/pytorch_geometric/pull/6627)) +- `torch-sparse` is now an optional dependency ([#6625](https://github.com/pyg-team/pytorch_geometric/pull/6625), [#6626](https://github.com/pyg-team/pytorch_geometric/pull/6626), [#6627](https://github.com/pyg-team/pytorch_geometric/pull/6627), [#6628](https://github.com/pyg-team/pytorch_geometric/pull/6628)) - Removed most of the `torch-scatter` dependencies ([#6394](https://github.com/pyg-team/pytorch_geometric/pull/6394), [#6395](https://github.com/pyg-team/pytorch_geometric/pull/6395), [#6399](https://github.com/pyg-team/pytorch_geometric/pull/6399), [#6400](https://github.com/pyg-team/pytorch_geometric/pull/6400), [#6615](https://github.com/pyg-team/pytorch_geometric/pull/6615), [#6617](https://github.com/pyg-team/pytorch_geometric/pull/6617)) - Removed the deprecated classes `GNNExplainer` and `Explainer` from `nn.models` ([#6382](https://github.com/pyg-team/pytorch_geometric/pull/6382)) - Removed `target_index` argument in the `Explainer` interface ([#6270](https://github.com/pyg-team/pytorch_geometric/pull/6270)) diff --git a/torch_geometric/loader/shadow.py b/torch_geometric/loader/shadow.py index da934247b8de..c8bfb9d975ec 100644 --- a/torch_geometric/loader/shadow.py +++ b/torch_geometric/loader/shadow.py @@ -5,7 +5,7 @@ from torch import Tensor from torch_geometric.data import Batch, Data -from torch_geometric.typing import SparseTensor +from torch_geometric.typing import WITH_TORCH_SPARSE, SparseTensor class ShaDowKHopSampler(torch.utils.data.DataLoader): @@ -40,9 +40,7 @@ def __init__(self, data: Data, depth: int, num_neighbors: int, node_idx: Optional[Tensor] = None, replace: bool = False, **kwargs): - try: - import torch_sparse # noqa - except ImportError: + if not WITH_TORCH_SPARSE: raise ImportError( f"'{self.__class__.__name__}' requires 'torch-sparse'") diff --git a/torch_geometric/nn/conv/message_passing.jinja b/torch_geometric/nn/conv/message_passing.jinja index 79b210e6e90b..afd5af9ee5a5 100644 --- a/torch_geometric/nn/conv/message_passing.jinja +++ b/torch_geometric/nn/conv/message_passing.jinja @@ -4,8 +4,7 @@ from torch_geometric.typing import * import torch from torch import Tensor -import torch_sparse -from torch_geometric.typing import SparseTensor +from torch_geometric.typing import SparseTensor, torch_sparse from torch_geometric.nn.conv.message_passing import * from {{module}} import * diff --git a/torch_geometric/sampler/hgt_sampler.py b/torch_geometric/sampler/hgt_sampler.py index 5f98d78d3183..e9048937d430 100644 --- a/torch_geometric/sampler/hgt_sampler.py +++ b/torch_geometric/sampler/hgt_sampler.py @@ -9,7 +9,12 @@ NodeSamplerInput, ) from torch_geometric.sampler.utils import remap_keys, to_hetero_csc -from torch_geometric.typing import EdgeType, NodeType, OptTensor +from torch_geometric.typing import ( + WITH_TORCH_SPARSE, + EdgeType, + NodeType, + OptTensor, +) class HGTSampler(BaseSampler): @@ -22,9 +27,7 @@ def __init__( is_sorted: bool = False, share_memory: bool = False, ): - try: - import torch_sparse # noqa - except ImportError: + if not WITH_TORCH_SPARSE: raise ImportError( f"'{self.__class__.__name__}' requires 'torch-sparse'") diff --git a/torch_geometric/sampler/neighbor_sampler.py b/torch_geometric/sampler/neighbor_sampler.py index e1de887f33d4..522ff2fb965a 100644 --- a/torch_geometric/sampler/neighbor_sampler.py +++ b/torch_geometric/sampler/neighbor_sampler.py @@ -220,13 +220,13 @@ def _sample( batch = {k: v[0] for k, v in node.items()} node = {k: v[1] for k, v in node.items()} - else: + elif torch_geometric.typing.WITH_TORCH_SPARSE: if self.disjoint: raise ValueError("'disjoint' sampling not supported for " "neighbor sampling via 'torch-sparse'. " "Please install 'pyg-lib' for improved " "and optimized sampling routines.") - import torch_sparse # noqa + out = torch.ops.torch_sparse.hetero_neighbor_sample( self.node_types, self.edge_types, @@ -240,6 +240,10 @@ def _sample( ) node, row, col, edge, batch = out + (None, ) + else: + raise ImportError(f"'{self.__class__.__name__}' requires " + f"either 'pyg-lib' or 'torch-sparse'") + return HeteroSamplerOutput( node=node, row=remap_keys(row, self.to_edge_type), @@ -270,13 +274,13 @@ def _sample( if self.disjoint: batch, node = node.t().contiguous() - else: + elif torch_geometric.typing.WITH_TORCH_SPARSE: if self.disjoint: raise ValueError("'disjoint' sampling not supported for " "neighbor sampling via 'torch-sparse'. " "Please install 'pyg-lib' for improved " "and optimized sampling routines.") - import torch_sparse # noqa + out = torch.ops.torch_sparse.neighbor_sample( self.colptr, self.row, @@ -287,6 +291,10 @@ def _sample( ) node, row, col, edge, batch = out + (None, ) + else: + raise ImportError(f"'{self.__class__.__name__}' requires " + f"either 'pyg-lib' or 'torch-sparse'") + return SamplerOutput( node=node, row=row,