Skip to content

Commit

Permalink
[BugFix] Fix bugs in calling cusparse API (#259)
Browse files Browse the repository at this point in the history
Co-authored-by: fishmingyu <fishmingyu@github.com>
  • Loading branch information
fishmingyu and fishmingyu authored Jul 17, 2021
1 parent 26a2433 commit 6b8494d
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 3 deletions.
6 changes: 4 additions & 2 deletions cogdl/operators/mhspmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,16 @@
)
mhtranspose = load(
name="mhtranspose",
extra_cflags=['-lcusparse'],
sources=[os.path.join(path, "spmm/mhTranspose.cpp"), os.path.join(path, "spmm/mhTranspose.cu")],
verbose=False,
verbose=True,
)

spmm = load(
name="spmm",
extra_cflags=['-lcusparse'],
sources=[os.path.join(path, "spmm/spmm.cpp"), os.path.join(path, "spmm/spmm_kernel.cu")],
verbose=False,
verbose=True,
)

def csrmhspmm(rowptr, colind, feat, attention):
Expand Down
3 changes: 2 additions & 1 deletion cogdl/operators/spmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
try:
spmm = load(
name="spmm",
extra_cflags=['-lcusparse'],
sources=[os.path.join(path, "spmm/spmm.cpp"), os.path.join(path, "spmm/spmm_kernel.cu")],
verbose=False,
verbose=True,
)
sddmm = load(
name="sddmm",
Expand Down
1 change: 1 addition & 0 deletions cogdl/operators/spmm/mhTranspose.cu
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ void csr2cscKernel(int m, int n, int nnz,
)
{
cusparseHandle_t handle;
checkCuSparseError(cusparseCreate(&handle));
size_t bufferSize = 0;
void* buffer = NULL;
checkCuSparseError(cusparseCsr2cscEx2_bufferSize(handle,
Expand Down
1 change: 1 addition & 0 deletions cogdl/operators/spmm/spmm_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,7 @@ void csr2cscKernel(int m, int n, int nnz,
)
{
cusparseHandle_t handle;
checkCuSparseError(cusparseCreate(&handle));
size_t bufferSize = 0;
void* buffer = NULL;
checkCuSparseError(cusparseCsr2cscEx2_bufferSize(handle,
Expand Down

0 comments on commit 6b8494d

Please sign in to comment.