[Bugfix] Fix set device #316

Merged: 1 commit, Dec 9, 2021
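Summary of the change: this PR adds a c10 `OptionalCUDAGuard` to every entry point of the spmm operator extensions, so each CUDA kernel launches on the device of its input tensors rather than on whatever device happens to be current (cuda:0 by default). Without the guards, calling these operators with tensors resident on a non-default GPU would dispatch the kernels to the wrong device.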
8 changes: 7 additions & 1 deletion cogdl/operators/spmm/mhTranspose.cpp
@@ -2,6 +2,7 @@
 #include <iostream>
 #include <vector>
 #include <pybind11/pybind11.h>
+#include <c10/cuda/CUDAGuard.h>

 torch::Tensor mhtranspose_cuda(
     torch::Tensor permute,
@@ -18,6 +19,8 @@ torch::Tensor mhtranspose(
     assert(attention.is_contiguous());
     assert(permute.dtype() == torch::kInt32);
     assert(attention.dtype() == torch::kFloat32);
+    const at::cuda::OptionalCUDAGuard device_guard1(device_of(permute));
+    const at::cuda::OptionalCUDAGuard device_guard2(device_of(attention));
     return mhtranspose_cuda(permute, attention);
 }

@@ -40,6 +43,9 @@ std::vector<torch::Tensor> csr2csc(
     assert(rowptr.dtype() == torch::kInt32);
     assert(colind.dtype() == torch::kInt32);
     assert(csr_data.dtype() == torch::kInt32);
+    const at::cuda::OptionalCUDAGuard device_guard1(device_of(rowptr));
+    const at::cuda::OptionalCUDAGuard device_guard2(device_of(colind));
+    const at::cuda::OptionalCUDAGuard device_guard3(device_of(csr_data));
     return csr2csc_cuda(rowptr, colind, csr_data);
 }

@@ -48,4 +54,4 @@ PYBIND11_MODULE(mhtranspose, m)
     m.doc() = "mhtranspose in CSR format. ";
     m.def("mhtranspose", &mhtranspose, "CSR mhsddmm");
     m.def("csr2csc", &csr2csc, "csr2csc");
-}
\ No newline at end of file
+}
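For context, a minimal sketch (not part of the PR) of what the added guards do: at::cuda::OptionalCUDAGuard is an RAII helper that, given an optional device, makes that device current for the guard's lifetime and restores the previous device on destruction; when the optional is empty (as device_of returns for a CPU tensor), the guard is a no-op. The function example_op below is a hypothetical name for illustration.

// Assumed sketch, not code from this PR: pinning kernel launches to the
// device of the input tensor with OptionalCUDAGuard.
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>

torch::Tensor example_op(torch::Tensor input) {
    // device_of(input) yields an optional<Device>; it is empty for CPU
    // tensors, in which case the guard does nothing.
    const at::cuda::OptionalCUDAGuard guard(device_of(input));
    // Any CUDA kernel launched in this scope runs on input's device
    // (e.g. cuda:1), not on the previously current device.
    return input + input;
}   // guard's destructor restores the previously current device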
7 changes: 6 additions & 1 deletion cogdl/operators/spmm/multiheadSddmm.cpp
@@ -2,6 +2,7 @@
 #include <iostream>
 #include <vector>
 #include <pybind11/pybind11.h>
+#include <c10/cuda/CUDAGuard.h>

 torch::Tensor mhsddmm_cuda(
     torch::Tensor rowptr,
@@ -29,11 +30,15 @@ torch::Tensor mhsddmm(
     assert(colind.dtype() == torch::kInt32);
     assert(grad.dtype() == torch::kFloat32);
     assert(feature.dtype() == torch::kFloat32);
+    const at::cuda::OptionalCUDAGuard device_guard1(device_of(rowptr));
+    const at::cuda::OptionalCUDAGuard device_guard2(device_of(colind));
+    const at::cuda::OptionalCUDAGuard device_guard3(device_of(grad));
+    const at::cuda::OptionalCUDAGuard device_guard4(device_of(feature));
     return mhsddmm_cuda(rowptr, colind, grad, feature);
 }

 PYBIND11_MODULE(mhsddmm, m)
 {
     m.doc() = "mhsddmm in CSR format. ";
     m.def("mhsddmm", &mhsddmm, "CSR mhsddmm");
-}
\ No newline at end of file
+}
7 changes: 6 additions & 1 deletion cogdl/operators/spmm/multiheadSpmm.cpp
@@ -2,6 +2,7 @@
 #include <iostream>
 #include <vector>
 #include <pybind11/pybind11.h>
+#include <c10/cuda/CUDAGuard.h>

 torch::Tensor mhspmm_cuda(
     torch::Tensor rowptr,
@@ -27,11 +28,15 @@ torch::Tensor mhspmm(
     assert(colind.dtype() == torch::kInt32);
     assert(attention.dtype() == torch::kFloat32);
     assert(infeat.dtype() == torch::kFloat32);
+    const at::cuda::OptionalCUDAGuard device_guard1(device_of(rowptr));
+    const at::cuda::OptionalCUDAGuard device_guard2(device_of(colind));
+    const at::cuda::OptionalCUDAGuard device_guard3(device_of(attention));
+    const at::cuda::OptionalCUDAGuard device_guard4(device_of(infeat));
     return mhspmm_cuda(rowptr, colind, attention, infeat);
 }

 PYBIND11_MODULE(mhspmm, m)
 {
     m.doc() = "mhtranspose in CSR format. ";
     m.def("mhspmm", &mhspmm, "CSR mhsddmm");
-}
\ No newline at end of file
+}
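A side note on the pattern: each kernel requires all of its inputs to live on the same device, so a single guard keyed to the first CUDA input would suffice; the per-tensor guards above are harmless but redundant (nested guards simply switch in order and restore on scope exit). A hedged sketch of that variant follows; mhspmm_single_guard is a hypothetical name, not code from this PR, and mhspmm_cuda is the kernel entry point declared in multiheadSpmm.cpp above.

// Assumed sketch, not the PR's code: one guard instead of four.
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>

torch::Tensor mhspmm_cuda(torch::Tensor rowptr, torch::Tensor colind,
                          torch::Tensor attention, torch::Tensor infeat);

torch::Tensor mhspmm_single_guard(torch::Tensor rowptr, torch::Tensor colind,
                                  torch::Tensor attention, torch::Tensor infeat) {
    // Make the same-device requirement explicit instead of implicit.
    TORCH_CHECK(rowptr.device() == colind.device() &&
                rowptr.device() == attention.device() &&
                rowptr.device() == infeat.device(),
                "all inputs must be on the same device");
    // Guarding on the first input pins the launch for every tensor.
    const at::cuda::OptionalCUDAGuard guard(device_of(rowptr));
    return mhspmm_cuda(rowptr, colind, attention, infeat);
}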
11 changes: 10 additions & 1 deletion cogdl/operators/spmm/sddmm.cpp
@@ -2,6 +2,7 @@
 #include <iostream>
 #include <vector>
 #include <pybind11/pybind11.h>
+#include <c10/cuda/CUDAGuard.h>

 torch::Tensor sddmm_cuda_coo(
     torch::Tensor rowind,
@@ -36,6 +37,10 @@ torch::Tensor coo_sddmm(
     assert(colind.dtype()==torch::kInt32);
     assert(D1.dtype()==torch::kFloat32);
     assert(D2.dtype()==torch::kFloat32);
+    const at::cuda::OptionalCUDAGuard device_guard1(device_of(rowind));
+    const at::cuda::OptionalCUDAGuard device_guard2(device_of(colind));
+    const at::cuda::OptionalCUDAGuard device_guard3(device_of(D1));
+    const at::cuda::OptionalCUDAGuard device_guard4(device_of(D2));
     return sddmm_cuda_coo(rowind, colind, D1, D2);
 }

@@ -57,6 +62,10 @@ torch::Tensor csr_sddmm(
     assert(colind.dtype()==torch::kInt32);
     assert(D1.dtype()==torch::kFloat32);
     assert(D2.dtype()==torch::kFloat32);
+    const at::cuda::OptionalCUDAGuard device_guard1(device_of(rowptr));
+    const at::cuda::OptionalCUDAGuard device_guard2(device_of(colind));
+    const at::cuda::OptionalCUDAGuard device_guard3(device_of(D1));
+    const at::cuda::OptionalCUDAGuard device_guard4(device_of(D2));
     return sddmm_cuda_csr(rowptr, colind, D1, D2);
 }

@@ -65,4 +74,4 @@ PYBIND11_MODULE(sddmm, m)
     m.doc() = "SDDMM kernel. Format of COO and CSR are provided.";
     m.def("coo_sddmm", &coo_sddmm, "COO SDDMM");
     m.def("csr_sddmm", &csr_sddmm, "CSR SDDMM");
-}
\ No newline at end of file
+}
13 changes: 12 additions & 1 deletion cogdl/operators/spmm/spmm.cpp
@@ -2,6 +2,7 @@
 #include <iostream>
 #include <vector>
 #include <pybind11/pybind11.h>
+#include <c10/cuda/CUDAGuard.h>

 torch::Tensor spmm_cuda(
     torch::Tensor rowptr,
@@ -36,6 +37,10 @@ torch::Tensor csr_spmm(
     assert(A_colind.dtype() == torch::kInt32);
     assert(A_csrVal.dtype() == torch::kFloat32);
     assert(B.dtype() == torch::kFloat32);
+    const at::cuda::OptionalCUDAGuard device_guard1(device_of(A_rowptr));
+    const at::cuda::OptionalCUDAGuard device_guard2(device_of(A_colind));
+    const at::cuda::OptionalCUDAGuard device_guard3(device_of(A_csrVal));
+    const at::cuda::OptionalCUDAGuard device_guard4(device_of(B));
     return spmm_cuda(A_rowptr, A_colind, A_csrVal, B);
 }

@@ -53,6 +58,9 @@ torch::Tensor csr_spmm_no_edge_value(
     assert(A_rowptr.dtype() == torch::kInt32);
     assert(A_colind.dtype() == torch::kInt32);
     assert(B.dtype() == torch::kFloat32);
+    const at::cuda::OptionalCUDAGuard device_guard1(device_of(A_rowptr));
+    const at::cuda::OptionalCUDAGuard device_guard2(device_of(A_colind));
+    const at::cuda::OptionalCUDAGuard device_guard3(device_of(B));
     return spmm_cuda_no_edge_value(A_rowptr, A_colind, B);
 }

@@ -75,6 +83,9 @@ std::vector<torch::Tensor> csr2csc(
     assert(rowptr.dtype() == torch::kInt32);
     assert(colind.dtype() == torch::kInt32);
     assert(csr_data.dtype() == torch::kFloat32);
+    const at::cuda::OptionalCUDAGuard device_guard1(device_of(rowptr));
+    const at::cuda::OptionalCUDAGuard device_guard2(device_of(colind));
+    const at::cuda::OptionalCUDAGuard device_guard3(device_of(csr_data));
     return csr2csc_cuda(rowptr, colind, csr_data);
 }

@@ -84,4 +95,4 @@ PYBIND11_MODULE(spmm, m)
     m.def("csr_spmm", &csr_spmm, "CSR SPMM");
     m.def("csr_spmm_no_edge_value", &csr_spmm_no_edge_value, "CSR SPMM NO EDGE VALUE");
     m.def("csr2csc", &csr2csc, "csr2csc");
-}
\ No newline at end of file
+}
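To see the effect end to end, here is a hedged libtorch-side sketch (an assumed example; the shapes, values, and failure description are illustrative, not taken from the PR) of preparing inputs for one of these operators on a non-default GPU:

// Assumed example: all csr_spmm inputs placed on cuda:1.
// Before the guards, the kernel launched on the current device (typically
// cuda:0) while these buffers lived on cuda:1; with the guards, the launch
// follows the inputs.
#include <torch/torch.h>

int main() {
    torch::Device dev(torch::kCUDA, 1);  // a non-default GPU
    auto opts_i = torch::dtype(torch::kInt32).device(dev);
    auto opts_f = torch::dtype(torch::kFloat32).device(dev);
    // A 2x2 CSR matrix [[0,1],[1,0]] and a dense 2x4 operand.
    auto rowptr = torch::tensor({0, 1, 2}, opts_i);
    auto colind = torch::tensor({1, 0}, opts_i);
    auto vals   = torch::ones({2}, opts_f);
    auto dense  = torch::randn({2, 4}, opts_f);
    // csr_spmm(rowptr, colind, vals, dense) now runs correctly on cuda:1.
    return 0;
}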