Skip to content

Commit

Permalink
[Code Coverage] loader/utils.py (#7857)
Browse files Browse the repository at this point in the history
Part of #6528.

---------

Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
  • Loading branch information
akihironitta and rusty1s authored Aug 10, 2023
1 parent b965d76 commit e95cdaf
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 4 deletions.
16 changes: 16 additions & 0 deletions test/loader/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import pytest
import torch

from torch_geometric.loader.utils import index_select


def test_index_select():
x = torch.randn(3, 5)
index = torch.tensor([0, 2])
assert torch.equal(index_select(x, index), x[index])
assert torch.equal(index_select(x, index, dim=-1), x[..., index])


def test_index_select_out_of_range():
with pytest.raises(IndexError, match="out of range"):
index_select(torch.randn(3, 5), torch.tensor([0, 2, 3]))
27 changes: 23 additions & 4 deletions torch_geometric/loader/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,29 @@
)


def index_select(value: FeatureTensorType, index: Tensor,
dim: int = 0) -> Tensor:

# PyTorch currently only supports indexing via `torch.int64` :(
def index_select(
value: FeatureTensorType,
index: Tensor,
dim: int = 0,
) -> Tensor:
r"""Indexes the :obj:`value` tensor along dimension :obj:`dim` using the
entries in :obj:`index`.
Args:
value (torch.Tensor or np.ndarray): The input tensor.
index (torch.Tensor): The 1-D tensor containing the indices to index.
dim (int, optional): The dimension in which to index.
(default: :obj:`0`)
.. warning::
:obj:`index` is casted to a :obj:`torch.int64` tensor internally, as
`PyTorch currently only supports indexing
<https://github.com/pytorch/pytorch/issues/61819>`_ via
:obj:`torch.int64`.
"""
# PyTorch currently only supports indexing via `torch.int64`:
# https://github.com/pytorch/pytorch/issues/61819
index = index.to(torch.int64)

if isinstance(value, Tensor):
Expand Down

0 comments on commit e95cdaf

Please sign in to comment.