From 9bee6f4cffc7b9878f48f2c7b8cc69ae2034b5df Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sat, 26 Oct 2024 14:44:32 -0400 Subject: [PATCH] use dp api Signed-off-by: Jinzhe Zeng --- source/api_cc/src/DeepPotPT.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/source/api_cc/src/DeepPotPT.cc b/source/api_cc/src/DeepPotPT.cc index 5d43515e2d..780a8007f3 100644 --- a/source/api_cc/src/DeepPotPT.cc +++ b/source/api_cc/src/DeepPotPT.cc @@ -2,7 +2,6 @@ #ifdef BUILD_PYTORCH #include "DeepPotPT.h" -#include #include #include @@ -81,7 +80,9 @@ void DeepPotPT::init(const std::string& model, device = torch::Device(torch::kCPU); std::cout << "load model from: " << model << " to cpu " << std::endl; } else { - c10::cuda::CUDAGuard guard_(device); +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + DPErrcheck(DPSetDevice(gpu_id)); +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM std::cout << "load model from: " << model << " to gpu " << gpu_id << std::endl; }