Discrepancy Between Sparse and Dense Convolution Outputs for Density=1.0 #705

Open

linYDTHU opened this issue Jun 21, 2024 · 0 comments

linYDTHU commented Jun 21, 2024

Hi @traveller59,

First, I'd like to thank you for your excellent work on spconv. I am currently running a sanity check on sparse tensors with a density of 1.0 (every spatial location active), where I expect sparse and dense convolutions to produce identical outputs. However, I'm observing significant discrepancies between the two.

Environment Details

  • spconv version: spconv-cu118
  • PyTorch version: 2.3.0+cu118
  • CUDA version: 11.8
  • GPU model: A100 80GB PCIe

Reproducible Example

Here's the code snippet I'm using for the test:

import torch
import torch.nn as nn
import spconv.pytorch as spconv

# Initialize convolution layers
sparse_conv = spconv.SubMConv2d(64, 128, 3, padding=1).cuda().train()
dense_conv = nn.Conv2d(64, 128, 3, padding=1).cuda().train()

# Ensure identical weights and biases
# (permute OIHW -> OHWI to match spconv's KRSC weight layout)
sparse_conv.weight.data = dense_conv.weight.data.permute(0, 2, 3, 1).contiguous().clone().detach().requires_grad_()
sparse_conv.bias.data = dense_conv.bias.data.clone().detach().requires_grad_()

# Prepare inputs
x = torch.randn(10, 64, 64, 64).cuda()
x1 = x.clone().detach().requires_grad_()
x2 = x.clone().detach().requires_grad_()
sparse_input = spconv.SparseConvTensor.from_dense(x1.permute(0, 2, 3, 1).contiguous())
dense_input = x2

# Perform convolutions
sparse_output = sparse_conv(sparse_input)
dense_output = dense_conv(dense_input)

# Check outputs and gradients
print(f"Two Conv share the same outputs: {torch.allclose(sparse_output.dense(), dense_output)}")
print(f"Max relative error: {((sparse_output.dense() - dense_output) / dense_output).abs().max()}")
sparse_output.features.sum().backward()
dense_output.sum().backward()
print(f"Two Conv share the same gradients: {torch.allclose(x1.grad, x2.grad)}")
print(f"Max gradient relative error: {((x1.grad - x2.grad) / x2.grad).abs().max()}")

Observed Results

Two Conv share the same outputs: False
Max relative error: 2254.46240234375
Two Conv share the same gradients: False
Max gradient relative error: 1.746766448020935
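
One caveat on the metric itself: the relative error above divides by dense_output, whose entries can be arbitrarily close to zero for randn inputs, so the headline numbers may be inflated by tiny denominators. A tolerance-based check (a sketch continuing from the snippet above, not something I have run) would separate genuine mismatches from float noise:

# Compare with absolute error and explicit tolerances instead of
# dividing by near-zero entries of dense_output.
abs_err = (sparse_output.dense() - dense_output).abs()
print(f"Max absolute error: {abs_err.max()}")
print(f"Close within fp32 tolerance: {torch.allclose(sparse_output.dense(), dense_output, rtol=1e-4, atol=1e-5)}")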

Issue

As illustrated, the outputs and gradients of the sparse convolution differ significantly from those of the dense convolution even at a density of 1.0, where the two operations should theoretically be identical (up to floating-point tolerance).
Could you please help identify any potential issues, or suggest modifications that would bring the outputs into agreement? Thank you for your assistance!
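
If it helps localize the problem, here is a minimal follow-up check (a sketch using the same weight-copy convention as above; I have no results for it yet): with kernel_size=1 and no padding, submanifold and dense convolution reduce to the same per-position linear map, so they should agree up to float tolerance. If the 1x1 case matches but 3x3 does not, the discrepancy would point at spatial kernel indexing rather than channel layout.

import torch
import torch.nn as nn
import spconv.pytorch as spconv

# 1x1 kernels remove any spatial gather/scatter, isolating channel layout.
sparse_conv1 = spconv.SubMConv2d(64, 128, 1).cuda()
dense_conv1 = nn.Conv2d(64, 128, 1).cuda()
sparse_conv1.weight.data = dense_conv1.weight.data.permute(0, 2, 3, 1).contiguous().clone()
sparse_conv1.bias.data = dense_conv1.bias.data.clone()

x = torch.randn(10, 64, 64, 64).cuda()
sparse_in = spconv.SparseConvTensor.from_dense(x.permute(0, 2, 3, 1).contiguous())
print(torch.allclose(sparse_conv1(sparse_in).dense(), dense_conv1(x), rtol=1e-4, atol=1e-5))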
