Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Mar 27, 2023
1 parent 56b8c59 commit ff187d7
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 13 deletions.
16 changes: 3 additions & 13 deletions torch_geometric/data/separate.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from torch_geometric.data.data import BaseData
from torch_geometric.data.storage import BaseStorage
from torch_geometric.typing import SparseTensor
from torch_geometric.utils import narrow


def separate(cls, batch: BaseData, idx: int, slice_dict: Any,
Expand Down Expand Up @@ -56,29 +57,18 @@ def _separate(
decrement: bool,
) -> Any:

if isinstance(value, Tensor) and not value.is_sparse:
if isinstance(value, Tensor):
# Narrow a `torch.Tensor` based on `slices`.
# NOTE: We need to take care of decrementing elements appropriately.
key = str(key)
cat_dim = batch.__cat_dim__(key, value, store)
start, end = int(slices[idx]), int(slices[idx + 1])
value = value.narrow(cat_dim or 0, start, end - start)
value = narrow(value, cat_dim or 0, start, end - start)
value = value.squeeze(0) if cat_dim is None else value
if decrement and (incs.dim() > 1 or int(incs[idx]) != 0):
value = value - incs[idx].to(value.device)
return value

elif isinstance(value, Tensor) and value.is_sparse:
# Allows to unbatch a sparse tensors from pytorch, including `sparse_coo_tensor`
key = str(key)
cat_dim = batch.__cat_dim__(key, value, store)
start, end = int(slices[idx]), int(slices[idx + 1])
indices = arange(start, end, dtype=long)
value = value.index_select(cat_dim or 0, indices)
if decrement and (incs.dim() > 1 or int(incs[idx]) != 0):
value = value - incs[idx].to(value.device)
return value

elif isinstance(value, SparseTensor) and decrement:
# Narrow a `SparseTensor` based on `slices`.
# NOTE: `cat_dim` may return a tuple to allow for diagonal stacking.
Expand Down
6 changes: 6 additions & 0 deletions torch_geometric/utils/select.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from torch import Tensor

from torch_geometric.utils.mask import mask_select
from torch_geometric.utils.sparse import is_torch_sparse_tensor


def select(src: Union[Tensor, List[Any]], index_or_mask: Tensor,
Expand Down Expand Up @@ -41,6 +42,11 @@ def narrow(src: Union[Tensor, List[Any]], dim: int, start: int,
start (int): The starting dimension.
length (int): The distance to the ending dimension.
"""
if is_torch_sparse_tensor(src):
# TODO Sparse tensors in `torch.sparse` do not yet support `narrow`.
index = torch.arange(start, start + length, device=src.device)
return src.index_select(dim, index)

if isinstance(src, Tensor):
return src.narrow(dim, start, length)

Expand Down

0 comments on commit ff187d7

Please sign in to comment.