Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix c++ interface bug #3613

Merged
merged 12 commits into from
Mar 28, 2024
54 changes: 47 additions & 7 deletions source/api_cc/src/DeepPotPT.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "DeepPotPT.h"

#include "common.h"
#include "device.h"
using namespace deepmd;
torch::Tensor createNlistTensor(const std::vector<std::vector<int>>& data) {
std::vector<torch::Tensor> row_tensors;
Expand Down Expand Up @@ -36,15 +37,19 @@
<< std::endl;
return;
}
gpu_id = gpu_rank;
torch::Device device(torch::kCUDA, gpu_rank);
int gpu_num = torch::cuda::device_count();
if (gpu_num > 0) {
gpu_id = gpu_rank % gpu_num;

Check warning on line 42 in source/api_cc/src/DeepPotPT.cc

View check run for this annotation

Codecov / codecov/patch

source/api_cc/src/DeepPotPT.cc#L42

Added line #L42 was not covered by tests
} else {
gpu_id = 0;
}
torch::Device device(torch::kCUDA, gpu_id);
gpu_enabled = torch::cuda::is_available();
if (!gpu_enabled) {
device = torch::Device(torch::kCPU);
std::cout << "load model from: " << model << " to cpu " << gpu_rank
<< std::endl;
std::cout << "load model from: " << model << " to cpu " << std::endl;
} else {
std::cout << "load model from: " << model << " to gpu " << gpu_rank
std::cout << "load model from: " << model << " to gpu " << gpu_id

Check warning on line 52 in source/api_cc/src/DeepPotPT.cc

View check run for this annotation

Codecov / codecov/patch

source/api_cc/src/DeepPotPT.cc#L52

Added line #L52 was not covered by tests
<< std::endl;
}
module = torch::jit::load(model, device);
Expand Down Expand Up @@ -107,7 +112,6 @@
}
auto int_options = torch::TensorOptions().dtype(torch::kInt64);
auto int32_options = torch::TensorOptions().dtype(torch::kInt32);

// select real atoms
std::vector<VALUETYPE> dcoord, dforce, aparam_, datom_energy, datom_virial;
std::vector<int> datype, fwd_map, bkw_map;
Expand All @@ -116,6 +120,25 @@
select_real_atoms_coord(dcoord, datype, aparam_, nghost_real, fwd_map,
bkw_map, nall_real, nloc_real, coord, atype, aparam,
nghost, ntypes, 1, daparam, nall, aparam_nall);
int nloc = nall_real - nghost_real;
int nframes = 1;
if (nloc == 0) {
// no backward map needed
ener.resize(nframes);
// dforce of size nall * 3
force.resize(static_cast<size_t>(nframes) * fwd_map.size() * 3);
fill(force.begin(), force.end(), (VALUETYPE)0.0);
// dvirial of size 9
virial.resize(static_cast<size_t>(nframes) * 9);
fill(virial.begin(), virial.end(), (VALUETYPE)0.0);
// datom_energy_ of size nall
atom_energy.resize(static_cast<size_t>(nframes) * fwd_map.size());
fill(atom_energy.begin(), atom_energy.end(), (VALUETYPE)0.0);
// datom_virial_ of size nall * 9
atom_virial.resize(static_cast<size_t>(nframes) * fwd_map.size() * 9);
fill(atom_virial.begin(), atom_virial.end(), (VALUETYPE)0.0);
return;
}
std::vector<VALUETYPE> coord_wrapped = dcoord;
at::Tensor coord_wrapped_Tensor =
torch::from_blob(coord_wrapped.data(), {1, nall_real, 3}, options)
Expand Down Expand Up @@ -185,7 +208,6 @@
datom_virial.assign(
cpu_atom_virial_.data_ptr<VALUETYPE>(),
cpu_atom_virial_.data_ptr<VALUETYPE>() + cpu_atom_virial_.numel());
int nframes = 1;
// bkw map
force.resize(static_cast<size_t>(nframes) * fwd_map.size() * 3);
atom_energy.resize(static_cast<size_t>(nframes) * fwd_map.size());
Expand Down Expand Up @@ -249,6 +271,24 @@
floatType = torch::kFloat32;
}
auto int_options = torch::TensorOptions().dtype(torch::kInt64);
int nframes = 1;
if (natoms == 0) {
// no backward map needed
ener.resize(nframes);

Check warning on line 277 in source/api_cc/src/DeepPotPT.cc

View check run for this annotation

Codecov / codecov/patch

source/api_cc/src/DeepPotPT.cc#L277

Added line #L277 was not covered by tests
// dforce of size nall * 3
force.resize(static_cast<size_t>(nframes) * natoms * 3);
fill(force.begin(), force.end(), (VALUETYPE)0.0);

Check warning on line 280 in source/api_cc/src/DeepPotPT.cc

View check run for this annotation

Codecov / codecov/patch

source/api_cc/src/DeepPotPT.cc#L279-L280

Added lines #L279 - L280 were not covered by tests
// dvirial of size 9
virial.resize(static_cast<size_t>(nframes) * 9);
fill(virial.begin(), virial.end(), (VALUETYPE)0.0);

Check warning on line 283 in source/api_cc/src/DeepPotPT.cc

View check run for this annotation

Codecov / codecov/patch

source/api_cc/src/DeepPotPT.cc#L282-L283

Added lines #L282 - L283 were not covered by tests
// datom_energy_ of size nall
atom_energy.resize(static_cast<size_t>(nframes) * natoms);
fill(atom_energy.begin(), atom_energy.end(), (VALUETYPE)0.0);

Check warning on line 286 in source/api_cc/src/DeepPotPT.cc

View check run for this annotation

Codecov / codecov/patch

source/api_cc/src/DeepPotPT.cc#L285-L286

Added lines #L285 - L286 were not covered by tests
// datom_virial_ of size nall * 9
atom_virial.resize(static_cast<size_t>(nframes) * natoms * 9);
fill(atom_virial.begin(), atom_virial.end(), (VALUETYPE)0.0);
return;

Check warning on line 290 in source/api_cc/src/DeepPotPT.cc

View check run for this annotation

Codecov / codecov/patch

source/api_cc/src/DeepPotPT.cc#L288-L290

Added lines #L288 - L290 were not covered by tests
}
std::vector<torch::jit::IValue> inputs;
at::Tensor coord_wrapped_Tensor =
torch::from_blob(coord_wrapped.data(), {1, natoms, 3}, options)
Expand Down
Loading