
Batch.get_example fails when using sparse_coo_tensor for the tensor values #7022

Closed
DomInvivo opened this issue Mar 23, 2023 · 5 comments · Fixed by #7037

@DomInvivo
Contributor

🐛 Describe the bug

I cannot call batch[0] or batch.get_example(0) when using torch.sparse_coo_tensor, despite the batching working as expected.

# Imports
import torch
import torch_geometric
from torch_geometric.data import Data, Batch

# Create some fake graphs with sparse_coo_tensors, and batch them
data1 = Data(x=torch.sparse_coo_tensor(torch.tensor([[0, 1, 1], [1, 0, 1]]), torch.tensor([1, 2, 3]), (2, 2)))
data2 = Data(x=torch.sparse_coo_tensor(torch.tensor([[0, 1, 2], [1, 0, 1]]), torch.tensor([4, 5, 6]), (3, 2)))
batch = Batch.from_data_list([data1, data2])

# Print the stuff
print("batch.x = ", batch.x) # WORKS
print(batch[0]) # FAILS
print(batch.get_example(0)) # FAILS

Output:
batch.x =  tensor(indices=tensor([[0, 1, 1, 2, 3, 4],
                       [1, 0, 1, 1, 0, 1]]),
       values=tensor([1, 2, 3, 4, 5, 6]),
       size=(5, 2), nnz=6, layout=torch.sparse_coo)

On ubuntu, I get this error:

File ~/.venv/goli_ipu/lib/python3.8/site-packages/torch_geometric/data/batch.py:154, in Batch.__getitem__(self, idx)
    150 def __getitem__(self, idx: Union[int, np.integer, str, IndexType]) -> Any:
    151     if (isinstance(idx, (int, np.integer))
    152             or (isinstance(idx, Tensor) and idx.dim() == 0)
    153             or (isinstance(idx, np.ndarray) and np.isscalar(idx))):
--> 154         return self.get_example(idx)
    155     elif isinstance(idx, str) or (isinstance(idx, tuple)
    156                                   and isinstance(idx[0], str)):
    157         # Accessing attributes or node/edge types:
    158         return super().__getitem__(idx)

File ~/.venv/goli_ipu/lib/python3.8/site-packages/torch_geometric/data/batch.py:103, in Batch.get_example(self, idx)
     98 if not hasattr(self, '_slice_dict'):
     99     raise RuntimeError(
    100         ("Cannot reconstruct 'Data' object from 'Batch' because "
    101          "'Batch' was not created via 'Batch.from_data_list()'"))
--> 103 data = separate(
    104     cls=self.__class__.__bases__[-1],
    105     batch=self,
    106     idx=idx,
    107     slice_dict=self._slice_dict,
    108     inc_dict=self._inc_dict,
    109     decrement=True,
    110 )
    112 return data

File ~/.venv/goli_ipu/lib/python3.8/site-packages/torch_geometric/data/separate.py:37, in separate(cls, batch, idx, slice_dict, inc_dict, decrement)
     35         slices = slice_dict[attr]
     36         incs = inc_dict[attr] if decrement else None
---> 37     data_store[attr] = _separate(attr, batch_store[attr], idx, slices,
     38                                  incs, batch, batch_store, decrement)
     40 # The `num_nodes` attribute needs special treatment, as we cannot infer
     41 # the real number of nodes from the total number of nodes alone:
     42 if hasattr(batch_store, '_num_nodes'):

File ~/.venv/goli_ipu/lib/python3.8/site-packages/torch_geometric/data/separate.py:65, in _separate(key, value, idx, slices, incs, batch, store, decrement)
     63 cat_dim = batch.__cat_dim__(key, value, store)
     64 start, end = int(slices[idx]), int(slices[idx + 1])
---> 65 value = value.narrow(cat_dim or 0, start, end - start)
     66 value = value.squeeze(0) if cat_dim is None else value
     67 if decrement and (incs.dim() > 1 or int(incs[idx]) != 0):

NotImplementedError: Could not run 'aten::as_strided' with arguments from the 'SparseCPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::as_strided' is only available for these backends: [CPU, Meta, QuantizedCPU, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradHIP, AutogradXLA, AutogradMPS, AutogradIPU, AutogradXPU, AutogradHPU, AutogradVE, AutogradLazy, AutogradMeta, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, AutogradNestedTensor, Tracer, AutocastCPU, AutocastCUDA, FuncTorchBatched, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PythonDispatcher].

CPU: registered at aten/src/ATen/RegisterCPU.cpp:30798 [kernel]
Meta: registered at aten/src/ATen/RegisterMeta.cpp:26815 [kernel]
QuantizedCPU: registered at aten/src/ATen/RegisterQuantizedCPU.cpp:929 [kernel]
BackendSelect: fallthrough registered at ../aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
Python: registered at ../aten/src/ATen/core/PythonFallbackKernel.cpp:140 [backend fallback]
FuncTorchDynamicLayerBackMode: registered at ../aten/src/ATen/functorch/DynamicLayer.cpp:488 [backend fallback]
Functionalize: registered at aten/src/ATen/RegisterFunctionalization_0.cpp:19962 [kernel]
Named: fallthrough registered at ../aten/src/ATen/core/NamedRegistrations.cpp:11 [kernel]
Conjugate: fallthrough registered at ../aten/src/ATen/ConjugateFallback.cpp:22 [kernel]
Negative: fallthrough registered at ../aten/src/ATen/native/NegateFallback.cpp:22 [kernel]
ZeroTensor: registered at aten/src/ATen/RegisterZeroTensor.cpp:161 [kernel]
ADInplaceOrView: registered at ../torch/csrc/autograd/generated/ADInplaceOrViewType_0.cpp:4822 [kernel]
AutogradOther: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:14904 [autograd kernel]
AutogradCPU: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:14904 [autograd kernel]
AutogradCUDA: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:14904 [autograd kernel]
AutogradHIP: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:14904 [autograd kernel]
AutogradXLA: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:14904 [autograd kernel]
AutogradMPS: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:14904 [autograd kernel]
AutogradIPU: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:14904 [autograd kernel]
AutogradXPU: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:14904 [autograd kernel]
AutogradHPU: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:14904 [autograd kernel]
AutogradVE: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:14904 [autograd kernel]
AutogradLazy: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:14904 [autograd kernel]
AutogradMeta: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:14904 [autograd kernel]
AutogradPrivateUse1: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:14904 [autograd kernel]
AutogradPrivateUse2: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:14904 [autograd kernel]
AutogradPrivateUse3: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:14904 [autograd kernel]
AutogradNestedTensor: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:14904 [autograd kernel]
Tracer: registered at ../torch/csrc/autograd/generated/TraceType_0.cpp:16458 [kernel]
AutocastCPU: fallthrough registered at ../aten/src/ATen/autocast_mode.cpp:482 [backend fallback]
AutocastCUDA: fallthrough registered at ../aten/src/ATen/autocast_mode.cpp:324 [backend fallback]
FuncTorchBatched: registered at ../aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp:747 [kernel]
FuncTorchVmapMode: fallthrough registered at ../aten/src/ATen/functorch/VmapModeRegistrations.cpp:28 [backend fallback]
Batched: registered at ../aten/src/ATen/BatchingRegistrations.cpp:1068 [kernel]
VmapMode: fallthrough registered at ../aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]
FuncTorchGradWrapper: registered at ../aten/src/ATen/functorch/TensorWrapper.cpp:189 [backend fallback]
PythonTLSSnapshot: registered at ../aten/src/ATen/core/PythonFallbackKernel.cpp:148 [backend fallback]
FuncTorchDynamicLayerFrontMode: registered at ../aten/src/ATen/functorch/DynamicLayer.cpp:484 [backend fallback]
PythonDispatcher: registered at ../aten/src/ATen/core/PythonFallbackKernel.cpp:144 [backend fallback]

On windows, I get this error:

    150 def __getitem__(self, idx: Union[int, np.integer, str, IndexType]) -> Any:
    151     if (isinstance(idx, (int, np.integer))
    152             or (isinstance(idx, Tensor) and idx.dim() == 0)
    153             or (isinstance(idx, np.ndarray) and np.isscalar(idx))):
--> 154         return self.get_example(idx)
    155     elif isinstance(idx, str) or (isinstance(idx, tuple)
    156                                   and isinstance(idx[0], str)):
    157         # Accessing attributes or node/edge types:
    158         return super().__getitem__(idx)

File ~\miniconda3\envs\goli\lib\site-packages\torch_geometric\data\batch.py:103, in Batch.get_example(self, idx)
     98 if not hasattr(self, '_slice_dict'):
     99     raise RuntimeError(
    100         ("Cannot reconstruct 'Data' object from 'Batch' because "
    101          "'Batch' was not created via 'Batch.from_data_list()'"))
--> 103 data = separate(
    104     cls=self.__class__.__bases__[-1],
    105     batch=self,
    106     idx=idx,
    107     slice_dict=self._slice_dict,
    108     inc_dict=self._inc_dict,
    109     decrement=True,
    110 )
    112 return data

File ~\miniconda3\envs\goli\lib\site-packages\torch_geometric\data\separate.py:37, in separate(cls, batch, idx, slice_dict, inc_dict, decrement)
     35         slices = slice_dict[attr]
     36         incs = inc_dict[attr] if decrement else None
---> 37     data_store[attr] = _separate(attr, batch_store[attr], idx, slices,
     38                                  incs, batch, batch_store, decrement)
     40 # The `num_nodes` attribute needs special treatment, as we cannot infer
     41 # the real number of nodes from the total number of nodes alone:
     42 if hasattr(batch_store, '_num_nodes'):

File ~\miniconda3\envs\goli\lib\site-packages\torch_geometric\data\separate.py:65, in _separate(key, value, idx, slices, incs, batch, store, decrement)
     63 cat_dim = batch.__cat_dim__(key, value, store)
     64 start, end = int(slices[idx]), int(slices[idx + 1])
---> 65 value = value.narrow(cat_dim or 0, start, end - start)
     66 value = value.squeeze(0) if cat_dim is None else value
     67 if decrement and (incs.dim() > 1 or int(incs[idx]) != 0):

RuntimeError: Tensors of type SparseTensorImpl do not have strides

### Environment

* PyG version: 2.3.0.dev20230306, 2.2.0 on Windows
* PyTorch version: 1.13.0+cpu, 1.12.1 on Windows
* OS: Ubuntu 20.04.5 LTS, and Windows 11
* Python version: 3.8.10 on Ubuntu, 3.10.9 on Windows
* CUDA/cuDNN version: Fails on CPU. I haven't tried on GPU.
* How you installed PyTorch and PyG (`conda`, `pip`, source): `pip` on Ubuntu, `conda` on Windows
* Any other relevant information (*e.g.*, version of `torch-scatter`):
@rusty1s
Member

rusty1s commented Mar 24, 2023

Interesting. It looks like torch.sparse_coo_tensor does not support slicing/narrowing yet, so we would either need to implement our own logic for this or wait for PyTorch to fix it on their end. It might be good to open a corresponding issue on the PyTorch side.

@DomInvivo
Contributor Author

Yeah you are right. The code can be simplified to this:

import torch
sp = torch.sparse_coo_tensor(torch.tensor([[0, 1, 1], [1, 0, 1]]), torch.tensor([1, 2, 3]), (2, 2))
sp[:1] # CRASHES
sp[[0, 1]] # CRASHES
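As a sanity check (a minimal sketch, not part of PyG), the same tensor that crashes under slicing can be subset with `index_select`, which does have a sparse COO kernel:

```python
import torch

# Same reduced example as above
sp = torch.sparse_coo_tensor(
    torch.tensor([[0, 1, 1], [1, 0, 1]]),  # indices
    torch.tensor([1, 2, 3]),               # values
    (2, 2),
)

# index_select is implemented for the sparse COO backend,
# so it can stand in for the unsupported narrow()/slicing path
row1 = sp.index_select(0, torch.tensor([1]))
print(row1.to_dense())  # tensor([[2, 3]])
```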

@DomInvivo
Contributor Author

It seems that PyTorch supports `index_select` for `sparse_coo_tensor`, so here's a fix:

Change this line:

if isinstance(value, Tensor):

to:

if isinstance(value, Tensor) and not value.is_sparse:

Add this other option to the same if/elif statements:

elif isinstance(value, Tensor) and value.is_sparse and (value.layout == sparse_coo):
    key = str(key)
    cat_dim = batch.__cat_dim__(key, value, store)
    start, end = int(slices[idx]), int(slices[idx + 1])
    indices = arange(start, end, dtype=long) # Converting the slice to indices to use `index_select`
    value = value.index_select(cat_dim or 0, indices)
    if decrement and (incs.dim() > 1 or int(incs[idx]) != 0):
        value = value - incs[idx].to(value.device)
    return value
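Outside of the PyG internals, the idea behind this fix can be sketched as a standalone helper (illustrative only; `slice_sparse` is a hypothetical name, not PyG code): convert the `[start, end)` slice boundaries that `_slice_dict` stores into an index tensor and use `index_select`, which sparse COO tensors do support.

```python
import torch

def slice_sparse(value: torch.Tensor, start: int, end: int, dim: int = 0) -> torch.Tensor:
    """Narrow a sparse COO tensor along `dim` via index_select,
    mirroring the proposed fix (sketch, not the PyG implementation)."""
    indices = torch.arange(start, end, dtype=torch.long)
    return value.index_select(dim, indices)

# Rebuild the batched tensor from the original report:
# data1 occupies rows 0-1, data2 occupies rows 2-4
batched = torch.sparse_coo_tensor(
    torch.tensor([[0, 1, 1, 2, 3, 4], [1, 0, 1, 1, 0, 1]]),
    torch.tensor([1, 2, 3, 4, 5, 6]),
    (5, 2),
)

# Recover the first example (slice boundaries 0 and 2, as `_slice_dict` would hold)
ex0 = slice_sparse(batched, 0, 2)
print(ex0.to_dense())  # matches data1.x from the reproduction above
```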

And of course, extend the import accordingly, from:

from torch import Tensor

to:

from torch import Tensor, sparse_coo, arange, long

@rusty1s
Member

rusty1s commented Mar 25, 2023

Cool. Are you interested in sending a PR in for this? :)

@DomInvivo
Contributor Author

DomInvivo commented Mar 25, 2023

Done in PR #7037 :)

@rusty1s rusty1s linked a pull request Mar 27, 2023 that will close this issue
rusty1s added a commit that referenced this issue Mar 27, 2023
This PR fixes the issues with `Batch.from_data_list`, `Batch.get_example`, and `Batch.__getitem__` when using PyTorch's sparse tensors, such as `sparse_coo_tensor`. See issue #7022

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>