[Roadmap] PyTorch SparseTensor Support 🚀 #5867
Comments
So far I can think of a way to support torch SparseTensor with the least effort:

```python
def hook(model, input):
    edge_index, size, kwargs = input
    if is_torch_coo_tensor(edge_index):
        adj = edge_index
        edge_index = adj._indices()
        kwargs['edge_weight'] = adj._values()
    return edge_index, size, kwargs
```

We can register such a hook via `register_propagate_forward_pre_hook(hook)`.

(Update) This is not a perfect solution, as we cannot enjoy the benefits of PyTorch SparseTensor this way.
I think this would only be a temporary solution. Ultimately, we want direct support for it, similar to `torch_sparse.SparseTensor`.
IMO, there are some challenges to supporting this:
Please correct me if something is missing.
These are good thoughts. Appreciate it. Let me think.
You are right. Thanks for making it clear. I will update the roadmap accordingly and make some PRs soon :)
Just got back from vacation :) I've updated the roadmap and the corresponding TODO list. I plan to support them in the following weeks.
Nice to have you back :)
Hi, I was trying to use a PyTorch sparse tensor for an RGCN-like model I am working on. I noticed that mean aggregation in the PyTorch sparse tensor path of the `spmm` function raises a `NotImplementedError`. Should the implementation be `torch.sparse.mm(src, other) / src.shape[0]`? Am I misunderstanding mean aggregation?
Mean aggregation would refer to a row-wise mean that only normalizes across non-zero values. I don't think this is super hard to integrate though, @EdisonLeeeee.
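For illustration, here is a minimal sketch of such a row-wise mean for a native PyTorch COO sparse tensor (this is not PyG's actual `spmm` code; the helper name and shapes are made up):

```python
import torch

def sparse_mean_mm(src: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
    # `src`: [N, M] sparse COO adjacency, `other`: [M, F] dense feature matrix.
    src = src.coalesce()
    out = torch.sparse.mm(src, other)  # row-wise (weighted) sum over neighbors
    # Count the non-zero entries per row and normalize by that count,
    # not by the total number of rows.
    row = src.indices()[0]
    count = torch.zeros(src.size(0), device=out.device, dtype=out.dtype)
    count.scatter_add_(0, row, torch.ones_like(row, dtype=out.dtype))
    return out / count.clamp(min=1).unsqueeze(-1)
```

Dividing by `src.shape[0]` instead would average over all rows, which is not the intended per-node mean.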
Yeah. Will take a look and make a PR for it.
Thank you very much for the clarification. I think with that PR, `RGCNConv` and `FastRGCNConv` would also work.
I've made a PR for it: #6868
Apologies if I should have opened a new issue for this, but I opened #6889 right around the time the influx of SparseTensor-support PRs came in. Is the support meant only to let these layers accept sparse tensors as input? I ask not because I'm judging, but because I don't know if there is something I'm missing and just wanted to check :). For further clarity, I have included a screenshot.
Hi @romankouz, note that this issue is more about bringing support for native PyTorch sparse tensors into PyG layers. Overall, we aim to support as many layers as possible to accept sparse tensors as input, while only a fraction of them can actually make use of the sparse representation during aggregation (e.g., avoiding atomic operations in the underlying kernels).
Hi @rusty1s, thank you for the follow-up! I'm not entirely sure how aggregation kernels rely on atomic operations for dense matrices but not for sparse tensors. However, I do realize my misunderstanding regarding SparseTensor integration. Thank you for the response!
The main difference between the two representations is that for `edge_index`, messages are aggregated via scatter operations (which rely on atomic operations), while for sparse tensors the aggregation is performed as a sparse-dense matrix multiplication.
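To make the distinction concrete, here is a toy comparison of the two aggregation paths for a sum aggregation (the graph and names are illustrative, not PyG internals):

```python
import torch

edge_index = torch.tensor([[0, 1, 2],    # source nodes
                           [1, 2, 1]])   # target nodes
x = torch.randn(3, 4)

# (1) edge_index path: gather source features and scatter-add them onto targets.
out_scatter = torch.zeros_like(x)
out_scatter.index_add_(0, edge_index[1], x[edge_index[0]])

# (2) sparse path: the same aggregation expressed as one sparse-dense matmul,
#     with adj[target, source] = 1.
adj = torch.sparse_coo_tensor(edge_index.flip(0), torch.ones(3), size=(3, 3))
out_spmm = torch.sparse.mm(adj, x)

assert torch.allclose(out_scatter, out_spmm)
```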
🚀 The feature, motivation and pitch
PyG currently accepts `torch.LongTensor` (`edge_index`) and `torch_sparse.SparseTensor` (`adj`) in `MessagePassing`, which limits its flexibility for users who work with native PyTorch SparseTensor: an additional step is required to convert one format to the other.

The goal of this roadmap is to track the integration of native PyTorch SparseTensor support into PyG. After integration, the `edge_index` argument in `MessagePassing` can also accept a PyTorch SparseTensor while allowing backpropagation, just like `torch_sparse.SparseTensor`.
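To make the extra conversion step concrete, here is a minimal sketch (the toy graph, feature sizes, and the choice of `GCNConv` are illustrative):

```python
import torch
from torch_geometric.nn import GCNConv

# A small symmetric graph stored as a native PyTorch sparse COO adjacency matrix.
indices = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
adj = torch.sparse_coo_tensor(indices, torch.ones(4), size=(3, 3))

x = torch.randn(3, 16)
conv = GCNConv(16, 32)

# Without native support, the sparse tensor has to be unpacked back into
# `edge_index` / `edge_weight` before calling the layer.
adj = adj.coalesce()
edge_index, edge_weight = adj.indices(), adj.values()
out = conv(x, edge_index, edge_weight)
```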
General Roadmap

- `torch_geometric.nn.MessagePassing` and corresponding layers inherited from it
- `torch_geometric.transforms`
- `torch_geometric.loader`
- Replace `torch_sparse.SparseTensor`-based implementations with PyTorch functionality

Implementations
Utility functions

- `is_torch_sparse_tensor` function. Since both `strided` (dense) and `sparse_coo` tensors are instances of `torch.Tensor`, there should be a function to distinguish between the two kinds of input, i.e., the LongTensor `edge_index` and the SparseTensor `adj` (Add `spmm` and `is_torch_sparse_tensor` #5906).
- `spmm` function: sparse-dense matrix multiplication supporting both `torch_sparse` and PyTorch SparseTensor. For PyTorch SparseTensor, currently only `sum` and `mean` aggregations are allowed (Add `spmm` and `is_torch_sparse_tensor` #5906, Added `mean` reduction to `utils.spmm` for PyTorch Sparse Tensor #6868, PyTorch Sparse Tensor support: `HANConv`, `GATv2Conv`, `HGTConv`, `GMMConv`, `GPSConv`, and `RGATConv` #6932).
- `is_sparse_tensor`: check for either `torch.sparse.Tensor` or `torch_sparse.SparseTensor` (Add `is_sparse` and `to_torch_coo_tensor` #6003).
- `to_torch_coo_tensor`: convert `edge_index` to `torch.sparse.Tensor` (COO format) (Add `is_sparse` and `to_torch_coo_tensor` #6003).
- `torch.sparse.Tensor` input for `add_self_loops`, `remove_self_loops`, and `maybe_num_nodes` (Added PyTorch Sparse Tensor support for `remove_self_loops`, `add_self_loops`, and `maybe_num_nodes` #6847).
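A short usage sketch of these utilities (argument names follow the referenced PRs; the toy graph is made up):

```python
import torch
from torch_geometric.utils import is_torch_sparse_tensor, spmm, to_torch_coo_tensor

edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
x = torch.randn(3, 8)

# Convert a LongTensor `edge_index` into a native PyTorch COO sparse tensor.
adj = to_torch_coo_tensor(edge_index, size=(3, 3))

assert is_torch_sparse_tensor(adj)
assert not is_torch_sparse_tensor(edge_index)

# Sparse-dense matrix multiplication; for native PyTorch sparse tensors only
# `sum` and `mean` reductions are supported at this point.
out = spmm(adj, x, reduce='sum')
```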
torch_geometric.nn.*

- `MessagePassing` (Add PyTorch SparseTensor support for `MessagePassing` #5944, Pytorch Sparse tensor support: `AntiSymmetricConv`, `CGConv`, and `TransformerConv` #6633)
- `GCNConv` (Add PyTorch SparseTensor support for `GCNConv` and `gcn_norm` #6033)
- `AGNNConv` (Pytorch Sparse tensor support: `ClusterGCN`, `SAGEConv`, `AGNNConv`, `APPNP`, and `FeaStConv` #6874)
- `APPNP` (Pytorch Sparse tensor support: `ClusterGCN`, `SAGEConv`, `AGNNConv`, `APPNP`, and `FeaStConv` #6874)
- `AntiSymmetricConv` (Pytorch Sparse tensor support: `AntiSymmetricConv`, `CGConv`, and `TransformerConv` #6633)
- `ARMAConv`
- `CGConv` (Pytorch Sparse tensor support: `AntiSymmetricConv`, `CGConv`, and `TransformerConv` #6633)
- `ChebConv`
- `ClusterGCNConv` (Pytorch Sparse tensor support: `ClusterGCN`, `SAGEConv`, `AGNNConv`, `APPNP`, and `FeaStConv` #6874)
- `DNAConv` (Pytorch Sparse tensor support: `DNAConv`, `EdgeConv`, `EGConv`, and `FAConv` #6748)
- `EdgeConv` (Pytorch Sparse tensor support: `DNAConv`, `EdgeConv`, `EGConv`, and `FAConv` #6748)
- `EGConv` (Pytorch Sparse tensor support: `DNAConv`, `EdgeConv`, `EGConv`, and `FAConv` #6748)
- `FAConv` (Pytorch Sparse tensor support: `DNAConv`, `EdgeConv`, `EGConv`, and `FAConv` #6748)
- `FeaStConv` (Pytorch Sparse tensor support: `ClusterGCN`, `SAGEConv`, `AGNNConv`, `APPNP`, and `FeaStConv` #6874)
- `FiLMConv`
- `FusedGATConv`
- `GATConv` (Pytorch Sparse tensor support: `GATConv`, `GatedGraphConv`, `GCN2Conv`, and `GENConv` #6897)
- `GatedGraphConv` (Pytorch Sparse tensor support: `GATConv`, `GatedGraphConv`, `GCN2Conv`, and `GENConv` #6897)
- `GATv2Conv` (PyTorch Sparse Tensor support: `HANConv`, `GATv2Conv`, `HGTConv`, `GMMConv`, `GPSConv`, and `RGATConv` #6932)
- `GCN2Conv` (Pytorch Sparse tensor support: `GATConv`, `GatedGraphConv`, `GCN2Conv`, and `GENConv` #6897)
- `GENConv` (Pytorch Sparse tensor support: `GATConv`, `GatedGraphConv`, `GCN2Conv`, and `GENConv` #6897)
- `GeneralConv`
- `GINConv` (Add PyTorch SparseTensor support for `GINConv`, `SAGEConv`, and `GraphConv` #6532)
- `GMMConv` (PyTorch Sparse Tensor support: `HANConv`, `GATv2Conv`, `HGTConv`, `GMMConv`, `GPSConv`, and `RGATConv` #6932)
- `GPSConv` (PyTorch Sparse Tensor support: `HANConv`, `GATv2Conv`, `HGTConv`, `GMMConv`, `GPSConv`, and `RGATConv` #6932)
- `GraphConv` (Add PyTorch SparseTensor support for `GINConv`, `SAGEConv`, and `GraphConv` #6532)
- `GravNetConv`
- `HANConv` (PyTorch Sparse Tensor support: `HANConv`, `GATv2Conv`, `HGTConv`, `GMMConv`, `GPSConv`, and `RGATConv` #6932)
- `HEATConv`
- `HeteroConv`
- `HGTConv` (PyTorch Sparse Tensor support: `HANConv`, `GATv2Conv`, `HGTConv`, `GMMConv`, `GPSConv`, and `RGATConv` #6932)
- `HypergraphConv`
- `LEConv` (PyTorch Sparse Tensor support: `LEConv`, `LGConv`, `NNConv`, `PANConv`, `SignedConv`, and `WLConv` #6936)
- `LGConv` (PyTorch Sparse Tensor support: `LEConv`, `LGConv`, `NNConv`, `PANConv`, `SignedConv`, and `WLConv` #6936)
- `MFConv`
- `NNConv` (PyTorch Sparse Tensor support: `HANConv`, `GATv2Conv`, `HGTConv`, `GMMConv`, `GPSConv`, and `RGATConv` #6932)
- `PANConv` (PyTorch Sparse Tensor support: `LEConv`, `LGConv`, `NNConv`, `PANConv`, `SignedConv`, and `WLConv` #6936)
- `PDNConv`
- `PNAConv`
- `PointNetConv` (Pytorch Sparse Tensor support: `PointConv`, `PointGNNConv`, `PointTransformerConv`, `PPFConv`, and `ResGatedGraphConv` #6937)
- `PointGNNConv` (Pytorch Sparse Tensor support: `PointConv`, `PointGNNConv`, `PointTransformerConv`, `PPFConv`, and `ResGatedGraphConv` #6937)
- `PointTransformerConv` (Pytorch Sparse Tensor support: `PointConv`, `PointGNNConv`, `PointTransformerConv`, `PPFConv`, and `ResGatedGraphConv` #6937)
- `PPFConv` (Pytorch Sparse Tensor support: `PointConv`, `PointGNNConv`, `PointTransformerConv`, `PPFConv`, and `ResGatedGraphConv` #6937)
- `ResGatedGraphConv` (Pytorch Sparse Tensor support: `PointConv`, `PointGNNConv`, `PointTransformerConv`, `PPFConv`, and `ResGatedGraphConv` #6937)
- `RGATConv` (PyTorch Sparse Tensor support: `HANConv`, `GATv2Conv`, `HGTConv`, `GMMConv`, `GPSConv`, and `RGATConv` #6932)
- `RGCNConv` & `FastRGCNConv`
- `SAGEConv` (Add PyTorch SparseTensor support for `GINConv`, `SAGEConv`, and `GraphConv` #6532, Pytorch Sparse tensor support: `ClusterGCN`, `SAGEConv`, `AGNNConv`, `APPNP`, and `FeaStConv` #6874)
- `SGConv` (Add PyTorch SparseTensor support for `SGConv`, `SSGConv` and `TAGConv` #6514)
- `SignedConv` (PyTorch Sparse Tensor support: `LEConv`, `LGConv`, `NNConv`, `PANConv`, `SignedConv`, and `WLConv` #6936)
- `SplineConv`
- `SSGConv` (Add PyTorch SparseTensor support for `SGConv`, `SSGConv` and `TAGConv` #6514)
- `SuperGATConv`
- `TAGConv` (Add PyTorch SparseTensor support for `SGConv`, `SSGConv` and `TAGConv` #6514)
- `TransformerConv` (Pytorch Sparse tensor support: `AntiSymmetricConv`, `CGConv`, and `TransformerConv` #6633)
- `WLConvContinuous`
- `WLConv` (PyTorch Sparse Tensor support: `LEConv`, `LGConv`, `NNConv`, `PANConv`, `SignedConv`, and `WLConv` #6936)
- `XConv`
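As a concrete end state for the layers above, here is a minimal sketch of calling a supported layer such as `GCNConv` with a native sparse adjacency in place of `edge_index` (toy graph and sizes are made up; the graph is symmetric, so adjacency orientation does not matter here):

```python
import torch
from torch_geometric.nn import GCNConv
from torch_geometric.utils import to_torch_coo_tensor

edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
x = torch.randn(3, 16)
conv = GCNConv(16, 32)

# With native support, a PyTorch sparse tensor can be passed where
# `edge_index` is expected.
adj = to_torch_coo_tensor(edge_index, size=(3, 3))
out = conv(x, adj)
assert out.size() == (3, 32)
```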
torch_geometric.transforms.*

- `ToSparseTensor` (Added PyTorch Sparse Tensor support for `ToSparseTensor` #6930)
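A hedged usage sketch of the transform (the `layout` argument is assumed based on the referenced PR; the graph is made up):

```python
import torch
import torch_geometric.transforms as T
from torch_geometric.data import Data

edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
data = Data(edge_index=edge_index, num_nodes=3)

# Convert `edge_index` into a sparse adjacency; passing a native PyTorch sparse
# layout (assumed `layout` argument) yields a `torch.sparse` tensor instead of
# a `torch_sparse.SparseTensor`.
data = T.ToSparseTensor(layout=torch.sparse_coo)(data)
print(data.adj_t)
```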