
[Roadmap] PyTorch SparseTensor Support 🚀 #5867

Open
51 of 68 tasks
EdisonLeeeee opened this issue Nov 1, 2022 · 16 comments


EdisonLeeeee commented Nov 1, 2022

🚀 The feature, motivation and pitch

PyG's MessagePassing currently accepts either a torch.LongTensor edge_index or a torch_sparse.SparseTensor adj, which limits flexibility for users who work with native PyTorch SparseTensor. As such, an additional conversion step is required to move between the two formats.

The goal of this roadmap is to track the integration of native PyTorch SparseTensor support into PyG. After integration, the edge_index argument of MessagePassing will also accept a PyTorch SparseTensor while still allowing backpropagation, just like torch_sparse.SparseTensor.
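
For illustration, the conversion step that is currently required looks roughly like the following (a minimal sketch with made-up shapes; the exact conversion and helper names depend on the layer and the torch_sparse version):

import torch
from torch_sparse import SparseTensor

# A native PyTorch sparse COO adjacency matrix ...
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])
adj = torch.sparse_coo_tensor(edge_index, torch.ones(4), size=(3, 3))

# ... currently has to be converted before being passed to MessagePassing,
# either back to (edge_index, edge_weight) ...
adj = adj.coalesce()
edge_index, edge_weight = adj.indices(), adj.values()

# ... or to a torch_sparse.SparseTensor.
adj_t = SparseTensor.from_torch_sparse_coo_tensor(adj)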

General Roadmap

  • Add PyTorch SparseTensor support to torch_geometric.nn.MessagePassing and the layers that inherit from it
  • Add PyTorch SparseTensor support to torch_geometric.transforms
  • Add PyTorch SparseTensor support to torch_geometric.loader
  • Replace torch_sparse.SparseTensor-based implementations with native PyTorch functionality
  • ...

Implementations

Utility functions

torch_geometric.nn.*

torch_geometric.transforms.*

@EdisonLeeeee
Copy link
Contributor Author

EdisonLeeeee commented Nov 2, 2022

So far, the least-effort way I can think of to support PyTorch SparseTensor in MessagePassing is to register a propagate forward pre-hook:

def hook(module, inputs):
    edge_index, size, kwargs = inputs
    # Transparently unpack a native PyTorch sparse COO tensor into
    # (edge_index, edge_weight) before message passing runs.
    if is_torch_coo_tensor(edge_index):  # new utility function to be added
        adj = edge_index
        edge_index = adj._indices()
        kwargs['edge_weight'] = adj._values()
    return edge_index, size, kwargs

register_propagate_forward_pre_hook(hook)

We can register such a hook in __init__ of MessagePassing. WDYT? @rusty1s

(Update) This is not a perfect solution, as we cannot enjoy the benefits of PyTorch SparseTensor this way.


rusty1s commented Nov 3, 2022

I think this would only be a temporary solution. Ultimately, we want direct support for it, similar to SparseTensor. Wondering what the main challenges are in supporting this?

@EdisonLeeeee

IMO, there are some challenges to supporting this:

  1. Only sum aggregation is supported when multiplying with a PyTorch SparseTensor. For other cases, including advanced aggregations, we still need to convert it back to the (edge_index, edge_weight) form.
  2. PyTorch SparseTensor does not support spspmm or mul with broadcasting, which makes it challenging to implement operations such as gcn_norm (e.g., $D^{-\frac{1}{2}} A D^{-\frac{1}{2}}$) efficiently.
  3. Since PyTorch SparseTensor is also an instance of torch.Tensor, we would need to change conditions like isinstance(edge_index, Tensor) if we want to support it. There should be some auxiliary functions such as is_edge_index or is_torch_coo_tensor (a minimal layout-based check is sketched right after this list).
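
As a rough illustration of point 3, a layout-based check could look like the following (a minimal sketch; the final helper name and its location in torch_geometric.utils are open for discussion):

import torch

def is_torch_coo_tensor(x) -> bool:
    # A sparse COO tensor is still an instance of torch.Tensor, so
    # isinstance(x, torch.Tensor) alone cannot tell it apart from a dense
    # edge_index; the layout has to be inspected instead.
    return isinstance(x, torch.Tensor) and x.layout == torch.sparse_coo

edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])
adj = torch.sparse_coo_tensor(edge_index, torch.ones(4), size=(3, 3))

assert not is_torch_coo_tensor(edge_index)
assert is_torch_coo_tensor(adj)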

Please correct me if something is missing.


rusty1s commented Nov 4, 2022

These are good thoughts. Appreciate it. Let me think.

  1. The only way I can think of to support this is to provide an spmm function in PyG that supports both torch_sparse and PyTorch sparse tensors. For the PyTorch path, we can error out in case aggr is neither "sum" nor "mean" (a rough dispatch sketch follows this list).
  2. Yes, this is a real problem. We would need to provide our own implementation for this by working on the indices and values directly. Happy to move this to follow-up PRs and work on the general MessagePassing integration first.
  3. Yes, this is expected, but IMO not a blocker, right?
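
For illustration, a rough sketch of that dispatch idea (hypothetical code, not the final PyG implementation; it starts out sum-only on the PyTorch path and errors out otherwise):

import torch
from torch import Tensor
from torch_sparse import SparseTensor, matmul

def spmm(src, other: Tensor, reduce: str = "sum") -> Tensor:
    # torch_sparse.SparseTensor path: all reductions are supported.
    if isinstance(src, SparseTensor):
        return matmul(src, other, reduce=reduce)
    # Native PyTorch sparse COO path: torch.sparse.mm only implements a sum
    # reduction, so anything else raises for now.
    if isinstance(src, Tensor) and src.layout == torch.sparse_coo:
        if reduce != "sum":
            raise NotImplementedError(
                f"`reduce='{reduce}'` is not yet supported for PyTorch "
                f"sparse tensors")
        return torch.sparse.mm(src, other)
    raise ValueError(f"Unsupported sparse matrix type: {type(src)}")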

@EdisonLeeeee

You are right. Thanks for making it clear. Will update the roadmap correspondingly and make some PRs soon :)

@EdisonLeeeee

Just got back from vacation :) I've updated the roadmap and the corresponding TODO list. I plan to add support over the coming weeks.


rusty1s commented Jan 26, 2023

Nice to have you back :)


sandeep-189 commented Mar 6, 2023

Hi, I was trying to use a PyTorch sparse tensor for an RGCN-like model I am working on. I noticed that mean aggregation for the PyTorch sparse tensor implementation in the spmm function raises a NotImplementedError. Should the implementation for this be torch.sparse.mm(src, other) / src.shape[0]? Or am I misunderstanding mean aggregation?


rusty1s commented Mar 7, 2023

Mean aggregation refers to a row-wise mean that only normalizes across non-zero values. I don't think this is super hard to integrate though, @EdisonLeeeee.
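
For illustration, a minimal sketch of such a row-wise mean (assuming a sparse COO input; not the exact implementation that later landed):

import torch
from torch import Tensor

def sparse_mean_mm(src: Tensor, other: Tensor) -> Tensor:
    # Sum-aggregate first ...
    out = torch.sparse.mm(src, other)
    # ... then normalize each row by its number of non-zero entries,
    # rather than by src.shape[0].
    row = src.coalesce().indices()[0]
    deg = torch.bincount(row, minlength=src.size(0)).clamp_(min=1)
    return out / deg.view(-1, 1).to(out.dtype)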

@EdisonLeeeee

Yeah. Will take a look and make a PR for it.

@sandeep-189

Thank you very much for the clarification. I think with that PR, RGCNConv and FastRGCNConv would also work.

@EdisonLeeeee

I've made a PR for it: #6868


romankouz commented Sep 26, 2023

[Screenshots: the message method of GMMConv and an out-of-memory error]

Apologies if I should have opened a new issue for this, but I opened #6889 right around the time the influx of SparseTensor support started and have some follow-up questions.

Is the support for SparseTensor simply so that these layers can accept a SparseTensor, or so that they use something like spmm for the matrix multiplication? As far as I can see, GMMConv still does not support sparse multiplication even if you pass in a sparse tensor. The commit you have for GMMConv SparseTensor support adds tests that allow you to pass a SparseTensor, but I don't think it does anything different between sparse and dense tensors. Is this a correct assessment?

I ask not because I'm judging, but because I don't know if there is something I'm missing and just wanted to check :).

For further clarity, I have included a screenshot of the message method of GMMConv, along with evidence that it still gives me out-of-memory issues and is not using sparse multiplication. Happy to provide anything else that could be useful :)


rusty1s commented Sep 28, 2023

Hi @romankouz,

Note that this issue is more about bringing support for torch.sparse tensors to PyG, not the SparseTensor class introduced by torch-sparse a while ago.

Overall, we aim for as many layers as possible to accept sparse tensors as input, while only a fraction of them can actually make use of the spmm code path. As a general rule of thumb, we can only leverage spmm if messages do not depend on destination node features and no edge features are involved (which is not the case for GMMConv). The other main benefit you get from using sparse tensors is deterministic behavior, since the aggregation kernels no longer rely on atomic operations.
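
As a toy illustration of that rule of thumb (made-up shapes, not PyG code): when the message is just the source node feature, gathering messages and sum-aggregating is exactly one sparse-dense matrix multiplication.

import torch

num_nodes, num_edges, channels = 4, 6, 8
edge_index = torch.randint(num_nodes, (2, num_edges))
x = torch.randn(num_nodes, channels)

# Explicit "gather messages, then sum-aggregate at the destination" ...
x_j = x[edge_index[0]]  # one message per edge
out_scatter = torch.zeros(num_nodes, channels).index_add_(0, edge_index[1], x_j)

# ... versus a single spmm with the transposed adjacency matrix.
adj_t = torch.sparse_coo_tensor(edge_index.flip(0), torch.ones(num_edges),
                                size=(num_nodes, num_nodes)).coalesce()
out_spmm = torch.sparse.mm(adj_t, x)

assert torch.allclose(out_scatter, out_spmm, atol=1e-6)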

@romankouz

Hi @rusty1s,

Thank you for the follow-up! I'm not entirely sure how aggregation kernels rely on atomic operations for dense matrices but not for sparse tensors. However, I do realize my misunderstanding regarding SparseTensor integration. Thank you for the response!


rusty1s commented Sep 29, 2023

The main difference between the two representations is that for SparseTensor it is guaranteed that neighborhoods are grouped contiguously in memory, which lets us utilize segment_reduce rather than scatter_reduce. As such, this lets us leverage an alternative aggregation in which we no longer need to rely on atomic operations.
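
For illustration, the two reduction styles side by side (a toy example using torch_scatter; assumes the edges are already sorted by destination node):

import torch
from torch_scatter import scatter, segment_csr

x = torch.randn(4, 8)             # messages for 4 edges
row = torch.tensor([0, 0, 1, 2])  # destination of each edge, sorted

# Scatter-style reduction: rows may arrive in any order, so a parallel
# backend typically resolves collisions with atomic adds, making the
# floating-point summation order non-deterministic.
out_scatter = scatter(x, row, dim=0, dim_size=3, reduce="sum")

# Segment-style (CSR) reduction: since each neighborhood is contiguous,
# output row i is reduced over the slice [rowptr[i], rowptr[i + 1])
# without atomics, which gives deterministic results.
rowptr = torch.tensor([0, 2, 3, 4])
out_segment = segment_csr(x, rowptr, reduce="sum")

assert torch.allclose(out_scatter, out_segment)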
