Skip to content

Commit

Permalink
[EdgeIndex] Do not use SparseTensor in CPU matmul code path (py…
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored Nov 30, 2023
1 parent 5fc014e commit b0053ce
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions torch_geometric/data/edge_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,7 +749,7 @@ def forward(
if input_value is not None:
input_value = input_value.detach()

if torch_geometric.typing.WITH_TORCH_SPARSE:
if other.is_cuda and torch_geometric.typing.WITH_TORCH_SPARSE:
# If `torch-sparse` is available, it still provides a faster
# sparse-dense matmul code path (after all these years...):
rowptr, col = input.get_rowptr(), input[1]
Expand Down Expand Up @@ -781,7 +781,7 @@ def backward(
if input_value is not None:
input_value = input_value.detach()[perm]

if torch_geometric.typing.WITH_TORCH_CLUSTER:
if out_grad.is_cuda and torch_geometric.typing.WITH_TORCH_SPARSE:
other_grad = torch.ops.torch_sparse.spmm_sum( #
None, colptr, row, input_value, None, None, out_grad)

Expand Down

0 comments on commit b0053ce

Please sign in to comment.