Skip to content

Commit

Permalink
Merge branch 'master' into pchmiel/aggr
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored Jun 7, 2023
2 parents 0daaeae + 3e207dc commit dc65933
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 2 deletions.
4 changes: 2 additions & 2 deletions torch_geometric/io/planetoid.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ def read_planetoid_data(folder, prefix):
# as zero vectors to `tx` and `ty`.
len_test_indices = (test_index.max() - test_index.min()).item() + 1

tx_ext = torch.zeros(len_test_indices, tx.size(1))
tx_ext = torch.zeros(len_test_indices, tx.size(1), dtype=tx.dtype)
tx_ext[sorted_test_index - test_index.min(), :] = tx
ty_ext = torch.zeros(len_test_indices, ty.size(1))
ty_ext = torch.zeros(len_test_indices, ty.size(1), dtype=ty.dtype)
ty_ext[sorted_test_index - test_index.min(), :] = ty

tx, ty = tx_ext, ty_ext
Expand Down
11 changes: 11 additions & 0 deletions torch_geometric/nn/pool/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from typing import Optional
from torch import Tensor

Expand Down Expand Up @@ -161,6 +162,11 @@ def knn_graph(
:rtype: :class:`torch.Tensor`
"""
if batch is not None and x.device != batch.device:
warnings.warn("Input tensor 'x' and 'batch' are on different devices "
"in 'knn_graph'. Performing blocking device transfer")
batch = batch.to(x.device)

if not torch_geometric.typing.WITH_TORCH_CLUSTER_BATCH_SIZE:
return torch_cluster.knn_graph(x, k, batch, loop, flow, cosine,
num_workers)
Expand Down Expand Up @@ -264,6 +270,11 @@ def radius_graph(
:rtype: :class:`torch.Tensor`
"""
if batch is not None and x.device != batch.device:
warnings.warn("Input tensor 'x' and 'batch' are on different devices "
"in 'radius_graph'. Performing blocking device transfer")
batch = batch.to(x.device)

if not torch_geometric.typing.WITH_TORCH_CLUSTER_BATCH_SIZE:
return torch_cluster.radius_graph(x, r, batch, loop, max_num_neighbors,
flow, num_workers)
Expand Down

0 comments on commit dc65933

Please sign in to comment.