
[Feature] Add fast spmm (cpu) #312

Merged · 1 commit · Nov 25, 2021
13 changes: 13 additions & 0 deletions cogdl/operators/spmm.py
@@ -31,6 +31,19 @@ def csrspmm(rowptr, colind, x, csr_data, sym=False, actnn=False):
    csrspmm = None


try:
    spmm_cpu = load(
        name="spmm_cpu",
        extra_cflags=["-fopenmp"],
        sources=[os.path.join(path, "spmm/spmm_cpu.cpp")],
        verbose=False,
    )
    spmm_cpu = spmm_cpu.csr_spmm_cpu
except Exception as e:
    print(e)
    spmm_cpu = None


class SPMMFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, rowptr, colind, feat, edge_weight_csr=None, sym=False):
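For context, a minimal sketch of how a caller might exercise this loader, preferring the JIT-compiled kernel and falling back when compilation failed; the small CSR matrix here is illustrative, not from the PR:

import torch

rowptr = torch.tensor([0, 2, 3], dtype=torch.int32)  # 2x3 CSR matrix A
colind = torch.tensor([0, 2, 1], dtype=torch.int32)
values = torch.tensor([1.0, 2.0, 3.0])
dense = torch.randn(3, 4)

if spmm_cpu is not None:
    out = spmm_cpu(rowptr, colind, values, dense)
else:
    # Fallback for this sketch: densify A and use a plain matmul.
    A = torch.zeros(2, 3)
    A[0, 0], A[0, 2], A[1, 1] = 1.0, 2.0, 3.0
    out = A @ dense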
64 changes: 64 additions & 0 deletions cogdl/operators/spmm/spmm_cpu.cpp
@@ -0,0 +1,64 @@
#include <torch/extension.h>
#include <iostream>
#include <vector>
#include <pybind11/pybind11.h>

torch::Tensor spmm_cpu(
    torch::Tensor rowptr,
    torch::Tensor colind,
    torch::Tensor values,
    torch::Tensor dense)
{
    const auto m = rowptr.size(0) - 1;
    const auto k = dense.size(1);
    auto devid = dense.device().index();
    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCPU, devid);
    // Use zeros, not empty: the loop below accumulates with +=, so the
    // output buffer must start cleared.
    auto out = torch::zeros({m, k}, options);

    int *rowptr_ptr = rowptr.data_ptr<int>();
    int *colind_ptr = colind.data_ptr<int>();
    float *values_ptr = values.data_ptr<float>();
    float *dense_ptr = dense.data_ptr<float>();
    float *out_ptr = out.data_ptr<float>();

    // Rows are independent, so parallelize over rows; dynamic scheduling
    // balances work when row lengths vary.
    #pragma omp parallel for schedule(dynamic)
    for (int i = 0; i < m; ++i) {
        int row_start = rowptr_ptr[i], row_end = rowptr_ptr[i + 1];
        int ik = i * k;
        for (int key = row_start; key < row_end; ++key) {
            int j = colind_ptr[key] * k;
            float val = values_ptr[key];
            for (int t = 0; t < k; ++t) {
                out_ptr[ik + t] += val * dense_ptr[j + t];
            }
        }
    }
    return out;
}

torch::Tensor csr_spmm_cpu(
    torch::Tensor A_rowptr,
    torch::Tensor A_colind,
    torch::Tensor A_csrVal,
    torch::Tensor B)
{
    // The kernel only handles contiguous int32/float32 CPU tensors.
    assert(A_rowptr.device().type() == torch::kCPU);
    assert(A_colind.device().type() == torch::kCPU);
    assert(A_csrVal.device().type() == torch::kCPU);
    assert(B.device().type() == torch::kCPU);
    assert(A_rowptr.is_contiguous());
    assert(A_colind.is_contiguous());
    assert(A_csrVal.is_contiguous());
    assert(B.is_contiguous());
    assert(A_rowptr.dtype() == torch::kInt32);
    assert(A_colind.dtype() == torch::kInt32);
    assert(A_csrVal.dtype() == torch::kFloat32);
    assert(B.dtype() == torch::kFloat32);
    return spmm_cpu(A_rowptr, A_colind, A_csrVal, B);
}

PYBIND11_MODULE(spmm_cpu, m)
{
    m.doc() = "spmm_cpu in CSR format.";
    m.def("csr_spmm_cpu", &csr_spmm_cpu, "CSR SPMM (CPU)");
}
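To make the kernel's traversal concrete, here is a pure-Python mirror of the same CSR loop (a verification sketch only, not part of the PR); on small inputs it should agree with csr_spmm_cpu up to float rounding:

import torch

def csr_spmm_reference(rowptr, colind, values, dense):
    # Same traversal as the C++ kernel: for each row i, accumulate
    # values[e] * dense[colind[e]] over that row's nonzeros.
    m = rowptr.numel() - 1
    out = torch.zeros(m, dense.size(1), dtype=dense.dtype)
    for i in range(m):
        for e in range(int(rowptr[i]), int(rowptr[i + 1])):
            out[i] += values[e] * dense[colind[e]]
    return out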
42 changes: 42 additions & 0 deletions cogdl/utils/spmm_utils.py
@@ -6,9 +6,11 @@
"csrmhspmm": None,
"csr_edge_softmax": None,
"fused_gat_func": None,
"fast_spmm_cpu": None,
"spmm_flag": False,
"mh_spmm_flag": False,
"fused_gat_flag": False,
"spmm_cpu_flag": False,
}


@@ -28,6 +30,16 @@ def initialize_spmm():
        # print("Failed to load fast version of SpMM, use torch.scatter_add instead.")


def initialize_spmm_cpu():
    if CONFIGS["spmm_cpu_flag"]:
        return
    CONFIGS["spmm_cpu_flag"] = True

    from cogdl.operators.spmm import spmm_cpu

    CONFIGS["fast_spmm_cpu"] = spmm_cpu
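A small sketch of what this lazy initializer caches (illustrative only):

# After the first call, CONFIGS holds either the compiled kernel or None,
# and the guard flag prevents re-importing on every subsequent spmm call.
initialize_spmm_cpu()
kernel = CONFIGS["fast_spmm_cpu"]  # None if the extension failed to build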


def spmm_scatter(row, col, values, b):
    r"""
    Args:
@@ -40,6 +52,36 @@ def spmm_scatter(row, col, values, b):
    return output


def spmm_cpu(graph, x, fast_spmm_cpu=None):
    if fast_spmm_cpu is None:
        initialize_spmm_cpu()
        fast_spmm_cpu = CONFIGS["fast_spmm_cpu"]
    if fast_spmm_cpu is not None and str(x.device) == "cpu":
        # Scale by the graph's out-norm before and in-norm after the sparse
        # matmul, so normalized-adjacency propagation is preserved.
        if graph.out_norm is not None:
            x = graph.out_norm * x

        row_ptr, col_indices = graph.row_indptr, graph.col_indices
        csr_data = graph.raw_edge_weight
        x = fast_spmm_cpu(row_ptr.int(), col_indices.int(), csr_data, x)

        if graph.in_norm is not None:
            x = graph.in_norm * x
    else:
        # Fall back to the scatter-based path for non-CPU tensors or when
        # the compiled kernel is unavailable.
        row, col = graph.edge_index
        x = spmm_scatter(row, col, graph.edge_weight, x)
    return x


class SpMM_CPU(torch.nn.Module):
    def __init__(self):
        super().__init__()
        initialize_spmm_cpu()
        self.fast_spmm_cpu = CONFIGS["fast_spmm_cpu"]

    def forward(self, graph, x):
        return spmm_cpu(graph, x, self.fast_spmm_cpu)
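Hypothetical usage of the new module (a sketch; graph is assumed to be a cogdl Graph on CPU and x a node-feature matrix):

# Resolve the kernel once at construction, then reuse it per forward pass;
# the call transparently falls back to spmm_scatter if the kernel is absent.
layer = SpMM_CPU()
h = layer(graph, x)  # equivalent to spmm_cpu(graph, x)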


def spmm(graph, x, actnn=False, fast_spmm=None):
    if fast_spmm is None:
        initialize_spmm()