From a735bed9470f0855a9c1e8e3f57a454c3e00fe5d Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 19 Sep 2023 01:18:06 -0400 Subject: [PATCH] make the pairwise DPRc model 2x faster (#2833) This PR does a trick to speed up the pairwise DPRc model. Considering #2618 is not ready and is quite difficult to implement, in this PR, multiple frames are merged into one frame before feed to `prod_env_mat` OP, and the mesh is faked to make it perform the same behavior as the multiple frames. A new `mesh` shape is proposed. The first element stores `nloc`, and the following 15 elements store nothing to distinguish it from other mesh. The `(16 : 16 + nloc)` elements store `ilist`, `(16 + nloc : 16 + nloc * 2)` store `numneigh`, and the rest elements (in the shape of `sum(numneigh)`) store neighbors. The `nei_mode` is 4 for this situation. `prod_env_mat` OP is not a bottleneck anymore, as shown below. ![image](https://github.com/deepmodeling/deepmd-kit/assets/9496702/eea64b99-d630-4ea1-99f4-e7d49c126c33) --------- Signed-off-by: Jinzhe Zeng --- deepmd/model/pairwise_dprc.py | 55 ++++--- source/lib/include/gpu_rocm.h | 2 +- source/op/pairwise.cc | 200 ++++++++++++++++++++++++- source/op/prod_env_mat_multi_device.cc | 128 +++++++++++++++- source/tests/test_pairwise_dprc.py | 168 ++++++++++++++++++++- 5 files changed, 526 insertions(+), 27 deletions(-) diff --git a/deepmd/model/pairwise_dprc.py b/deepmd/model/pairwise_dprc.py index a9e154096a..8f46ec239d 100644 --- a/deepmd/model/pairwise_dprc.py +++ b/deepmd/model/pairwise_dprc.py @@ -59,6 +59,10 @@ def __init__( compress: Optional[dict] = None, **kwargs, ) -> None: + # internal variable to compare old and new behavior + # expect they give the same results + self.merge_frames = True + super().__init__( type_embedding=type_embedding, type_map=type_map, @@ -151,16 +155,27 @@ def build( atype = tf.reshape(atype_, [nframes, natoms[1], 1]) nframes_qmmm = tf.shape(qmmm_frame_idx)[0] + if self.merge_frames: + ( + forward_qmmm_map, + backward_qmmm_map, + natoms_qmmm, + mesh_qmmm, + ) = op_module.convert_forward_map(forward_qmmm_map, natoms_qmmm, natoms) + coord_qmmm = tf.reshape(coord, [1, -1, 3]) + atype_qmmm = tf.reshape(atype, [1, -1, 1]) + box_qmmm = tf.reshape(box[0], [1, 9]) + else: + mesh_qmmm = make_default_mesh(False, True) + coord_qmmm = tf.gather(coord, qmmm_frame_idx) + atype_qmmm = tf.gather(atype, qmmm_frame_idx) + box_qmmm = tf.gather(box, qmmm_frame_idx) + coord_qm = gather_placeholder(coord, forward_qm_map) atype_qm = gather_placeholder(atype, forward_qm_map, placeholder=-1) - coord_qmmm = gather_placeholder( - tf.gather(coord, qmmm_frame_idx), forward_qmmm_map - ) - atype_qmmm = gather_placeholder( - tf.gather(atype, qmmm_frame_idx), forward_qmmm_map, placeholder=-1 - ) + coord_qmmm = gather_placeholder(coord_qmmm, forward_qmmm_map) + atype_qmmm = gather_placeholder(atype_qmmm, forward_qmmm_map, placeholder=-1) box_qm = box - box_qmmm = tf.gather(box, qmmm_frame_idx) type_embedding = self.typeebd.build( self.ntypes, @@ -189,7 +204,7 @@ def build( atype_qmmm, natoms_qmmm, box_qmmm, - mesh_mixed_type, + mesh_qmmm, input_dict_qmmm, frz_model=frz_model, ckpt_meta=ckpt_meta, @@ -197,10 +212,14 @@ def build( reuse=reuse, ) - energy_qm = qm_dict["energy"] - energy_qmmm = tf.math.segment_sum(qmmm_dict["energy"], qmmm_frame_idx) - energy = energy_qm + energy_qmmm - energy = tf.identity(energy, name="o_energy" + suffix) + if self.merge_frames: + qmmm_dict = qmmm_dict.copy() + sub_nframes = tf.shape(backward_qmmm_map)[0] + qmmm_dict["force"] = tf.tile(qmmm_dict["force"], [sub_nframes, 1]) + qmmm_dict["atom_ener"] = tf.tile(qmmm_dict["atom_ener"], [sub_nframes, 1]) + qmmm_dict["atom_virial"] = tf.tile( + qmmm_dict["atom_virial"], [sub_nframes, 1] + ) force_qm = gather_placeholder( tf.reshape(qm_dict["force"], (nframes, natoms_qm[1], 3)), @@ -218,11 +237,6 @@ def build( force = force_qm + force_qmmm force = tf.reshape(force, (nframes, 3 * natoms[1]), name="o_force" + suffix) - virial_qm = qm_dict["virial"] - virial_qmmm = tf.math.segment_sum(qmmm_dict["virial"], qmmm_frame_idx) - virial = virial_qm + virial_qmmm - virial = tf.identity(virial, name="o_virial" + suffix) - backward_qm_map_nloc = tf.slice(backward_qm_map, [0, 0], [-1, natoms[0]]) backward_qmmm_map_nloc = tf.slice(backward_qmmm_map, [0, 0], [-1, natoms[0]]) atom_ener_qm = gather_placeholder( @@ -255,6 +269,13 @@ def build( atom_virial, (nframes, 9 * natoms[1]), name="o_atom_virial" + suffix ) + energy = tf.reduce_sum(atom_ener, axis=1, name="o_energy" + suffix) + virial = tf.reduce_sum( + tf.reshape(atom_virial, (nframes, natoms[1], 9)), + axis=1, + name="o_virial" + suffix, + ) + model_dict = {} model_dict["energy"] = energy model_dict["force"] = force diff --git a/source/lib/include/gpu_rocm.h b/source/lib/include/gpu_rocm.h index e628d109d7..4c3c1b41a9 100644 --- a/source/lib/include/gpu_rocm.h +++ b/source/lib/include/gpu_rocm.h @@ -59,7 +59,7 @@ void memcpy_host_to_device(FPTYPE *device, const FPTYPE *host, const int size) { } template -void memcpy_device_to_host(FPTYPE *device, std::vector &host) { +void memcpy_device_to_host(const FPTYPE *device, std::vector &host) { DPErrcheck(hipMemcpy(&host[0], device, sizeof(FPTYPE) * host.size(), hipMemcpyDeviceToHost)); } diff --git a/source/op/pairwise.cc b/source/op/pairwise.cc index ee55c3dff3..d60bc3bccc 100644 --- a/source/op/pairwise.cc +++ b/source/op/pairwise.cc @@ -14,6 +14,15 @@ REGISTER_OP("DprcPairwiseIdx") .Output("natoms_qmmm: int32") .Output("qmmm_frame_idx: int32"); +REGISTER_OP("ConvertForwardMap") + .Input("sub_forward_map: int32") + .Input("sub_natoms: int32") + .Input("natoms: int32") + .Output("forward_map: int32") + .Output("backward_map: int32") + .Output("new_natoms: int32") + .Output("mesh: int32"); + using namespace tensorflow; using CPUDevice = Eigen::ThreadPoolDevice; @@ -208,8 +217,193 @@ class PairwiseIdxOp : public OpKernel { } }; +template +class ConvertForwardMapOp : public OpKernel { + public: + explicit ConvertForwardMapOp(OpKernelConstruction* context) + : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + deepmd::safe_compute( + context, [this](OpKernelContext* context) { this->_Compute(context); }); + } + + void _Compute(OpKernelContext* context) { + // Grab the input tensor + int tmp_idx = 0; + const Tensor& sub_forward_map_tensor = context->input(tmp_idx++); + const Tensor& sub_natoms_tensor = context->input(tmp_idx++); + const Tensor& natoms_tensor = context->input(tmp_idx++); + + // set size of the sample + OP_REQUIRES(context, (sub_forward_map_tensor.shape().dims() == 2), + errors::InvalidArgument("Dim of idxs should be 2")); + OP_REQUIRES(context, (natoms_tensor.shape().dims() == 1), + errors::InvalidArgument("Dim of natoms should be 1")); + + auto sub_forward_map = sub_forward_map_tensor.matrix(); + int sub_nframes = sub_forward_map_tensor.shape().dim_size(0); + auto sub_natoms = sub_natoms_tensor.vec(); + auto natoms = natoms_tensor.vec(); + int sub_nloc = sub_natoms(0); + int sub_nall = sub_natoms(1); + int nloc = natoms(0); + int nall = natoms(1); + + // merge multiple sub-frames into one frame + // firstly, we need to get the nloc and nghost size to allocate + int new_nloc = 0, new_nghost = 0; + + for (int ii = 0; ii < sub_nframes; ++ii) { + for (int jj = 0; jj < sub_nloc; ++jj) { + if (sub_forward_map(ii, jj) != -1) { + new_nloc++; + } + } + for (int jj = sub_nloc; jj < sub_nall; ++jj) { + if (sub_forward_map(ii, jj) != -1) { + new_nghost++; + } + } + } + if (new_nloc == 0) { + new_nloc = 1; + } + int new_nall = new_nloc + new_nghost; + + // Create an output tensor + TensorShape forward_map_shape; + forward_map_shape.AddDim(1); + forward_map_shape.AddDim(new_nall); + TensorShape backward_map_shape; + // since the atom index can not be repeated, we still need + // to split to multiple frames + backward_map_shape.AddDim(sub_nframes); + backward_map_shape.AddDim(nall); + TensorShape new_natoms_shape; + new_natoms_shape.AddDim(natoms_tensor.shape().dim_size(0)); + + Tensor* forward_map_tensor = NULL; + Tensor* backward_map_tensor = NULL; + Tensor* new_natoms_tensor = NULL; + tmp_idx = 0; + OP_REQUIRES_OK(context, + context->allocate_output(tmp_idx++, forward_map_shape, + &forward_map_tensor)); + OP_REQUIRES_OK(context, + context->allocate_output(tmp_idx++, backward_map_shape, + &backward_map_tensor)); + OP_REQUIRES_OK(context, + context->allocate_output(tmp_idx++, new_natoms_shape, + &new_natoms_tensor)); + + auto forward_map = forward_map_tensor->matrix(); + auto backward_map = backward_map_tensor->matrix(); + auto new_natoms = new_natoms_tensor->vec(); + + // fill -1 in backward_map_tensor + for (int ii = 0; ii < sub_nframes; ++ii) { + for (int jj = 0; jj < nall; ++jj) { + backward_map(ii, jj) = -1; + } + } + + std::vector start_kk(sub_nframes), + end_kk(sub_nframes); // current forward map index + int kk = 0; + // assume nlist to contain all atoms; it should not be a problem for small + // residues + std::vector> jlist(new_nloc); + for (int ii = 0; ii < sub_nframes; ++ii) { + start_kk[ii] = kk; + for (int jj = 0; jj < sub_nloc; ++jj) { + if (sub_forward_map(ii, jj) != -1) { + forward_map(0, kk) = sub_forward_map(ii, jj); + backward_map(ii, sub_forward_map(ii, jj)) = kk; + kk++; + } + } + end_kk[ii] = kk; + // add neighbors to each other + for (int mm = start_kk[ii]; mm < end_kk[ii]; ++mm) { + for (int nn = start_kk[ii]; nn < end_kk[ii]; ++nn) { + if (mm != nn) { + jlist[mm].push_back(nn); + } + } + } + } + for (int ii = 0; ii < sub_nframes; ++ii) { + int start_ghost_kk = kk; + for (int jj = sub_nloc; jj < sub_nall; ++jj) { + if (sub_forward_map(ii, jj) != -1) { + forward_map(0, kk) = sub_forward_map(ii, jj); + backward_map(ii, sub_forward_map(ii, jj)) = kk; + kk++; + } + } + int end_ghost_kk = kk; + // add ghost neighbors to real atoms + for (int mm = start_kk[ii]; mm < end_kk[ii]; ++mm) { + for (int nn = start_ghost_kk; nn < end_ghost_kk; ++nn) { + jlist[mm].push_back(nn); + } + } + } + + // natoms + new_natoms(0) = new_nloc; + new_natoms(1) = new_nall; + new_natoms(2) = new_nloc; + for (int ii = 3; ii < new_natoms.size(); ++ii) { + new_natoms(ii) = 0; + } + + // mesh: + // first element: nloc (a number) + // 2~16: empty (to distinguish from other mesh) + // ilist: nloc + // numneigh: nloc + // jlist: sum(numneigh) + + // calculate numneigh + std::vector numneigh(new_nloc); + for (int ii = 0; ii < new_nloc; ++ii) { + numneigh[ii] = jlist[ii].size(); + } + int size_mesh = + std::accumulate(numneigh.begin(), numneigh.end(), 2 * new_nloc + 16); + + TensorShape mesh_shape; + mesh_shape.AddDim(size_mesh); + Tensor* mesh_tensor = NULL; + OP_REQUIRES_OK( + context, context->allocate_output(tmp_idx++, mesh_shape, &mesh_tensor)); + auto mesh = mesh_tensor->vec(); + mesh(0) = new_nloc; + for (int ii = 1; ii < 16; ++ii) { + mesh(ii) = 0; + } + for (int ii = 0; ii < new_nloc; ++ii) { + mesh(ii + 16) = ii; + } + for (int ii = 0; ii < new_nloc; ++ii) { + mesh(ii + 16 + new_nloc) = numneigh[ii]; + } + kk = 0; + for (int ii = 0; ii < new_nloc; ++ii) { + for (int jj = 0; jj < numneigh[ii]; ++jj) { + mesh(16 + 2 * new_nloc + kk) = jlist[ii][jj]; + kk++; + } + } + } +}; + // Register the CPU kernels. -#define REGISTER_CPU(T) \ - REGISTER_KERNEL_BUILDER(Name("DprcPairwiseIdx").Device(DEVICE_CPU), \ - PairwiseIdxOp); +#define REGISTER_CPU(T) \ + REGISTER_KERNEL_BUILDER(Name("DprcPairwiseIdx").Device(DEVICE_CPU), \ + PairwiseIdxOp); \ + REGISTER_KERNEL_BUILDER(Name("ConvertForwardMap").Device(DEVICE_CPU), \ + ConvertForwardMapOp); REGISTER_CPU(); diff --git a/source/op/prod_env_mat_multi_device.cc b/source/op/prod_env_mat_multi_device.cc index a8882fb5f4..73a0d3c4c1 100644 --- a/source/op/prod_env_mat_multi_device.cc +++ b/source/op/prod_env_mat_multi_device.cc @@ -507,6 +507,9 @@ class ProdEnvMatAOp : public OpKernel { // no pbc assert(nloc == nall); nei_mode = -1; + } else if (mesh_tensor.shape().dim_size(0) > 16) { + // pass neighbor list inside the tensor + nei_mode = 4; } else if (mesh_tensor.shape().dim_size(0) == 7 || mesh_tensor.shape().dim_size(0) == 1) { throw deepmd::deepmd_exception( @@ -799,6 +802,9 @@ class ProdEnvMatROp : public OpKernel { // no pbc assert(nloc == nall); nei_mode = -1; + } else if (mesh_tensor.shape().dim_size(0) > 16) { + // pass neighbor list inside the tensor + nei_mode = 4; } else if (mesh_tensor.shape().dim_size(0) == 7 || mesh_tensor.shape().dim_size(0) == 1) { throw deepmd::deepmd_exception( @@ -1101,14 +1107,15 @@ class ProdEnvMatAMixOp : public OpKernel { } else if (mesh_tensor.shape().dim_size(0) == 6 || mesh_tensor.shape().dim_size(0) == 7) { // manual copied pbc - assert(nloc == nall); nei_mode = 1; b_nlist_map = true; } else if (mesh_tensor.shape().dim_size(0) == 0 || mesh_tensor.shape().dim_size(0) == 1) { // no pbc - assert(nloc == nall); nei_mode = -1; + } else if (mesh_tensor.shape().dim_size(0) > 16) { + // pass neighbor list inside the tensor + nei_mode = 4; } else { throw deepmd::deepmd_exception("invalid mesh tensor"); } @@ -1429,6 +1436,24 @@ static void _map_nei_info_cpu(int* nlist, ntypes, b_nlist_map); } +/** + * @param[in] nei_mode -1, 1, 3, or 4. + * - -1: Build neighbor list without PBC. The size of mesh should + * be 0 (no mixed) or 1 (mixed). + * - 1: Build neighbor list with PBC. The size of mesh should + * be 6 (no mixed) or 7 (mixed). + * - 3:Use neighbor list from given pointers. The size of mesh should be 16. + * The first element is ago (whether update the internal neighbour list). + * The second element is the number of local atoms. The 5th-8th, 9th-12th, + * and 13th-16th elements are the pointer (int*, 4x size of int) to + * ilist, numneigh, firstneigh. The pointer should be valid during the + * execution of this op, so it may be created and given by an external + * program calling the TensorFlow session. + * - 4: Use neighbor list stored in the tensor. The size of mesh should be + * 16 + 2 * nloc + sum(numneigh). Starting from the 17th element, the + * elements are ilist (size of nloc), numneigh (size of nloc), and neighbors + * (size of numneigh[i] for each i). + */ template static void _prepare_coord_nlist_cpu(OpKernelContext* context, FPTYPE const** coord, @@ -1453,7 +1478,7 @@ static void _prepare_coord_nlist_cpu(OpKernelContext* context, const int& max_cpy_trial, const int& max_nnei_trial) { inlist.inum = nloc; - if (nei_mode != 3) { + if (nei_mode != 3 && nei_mode != 4) { // build nlist by myself // normalize and copy coord if (nei_mode == 1) { @@ -1474,6 +1499,19 @@ static void _prepare_coord_nlist_cpu(OpKernelContext* context, inlist.ilist = &ilist[0]; inlist.numneigh = &numneigh[0]; inlist.firstneigh = &firstneigh[0]; + } else if (nei_mode == 4) { + std::memcpy(&ilist[0], 16 + mesh_tensor_data, sizeof(int) * nloc); + std::memcpy(&numneigh[0], 16 + nloc + mesh_tensor_data, sizeof(int) * nloc); + for (int ii = 0, kk = 0; ii < nloc; ++ii) { + jlist[ii].resize(numneigh[ii]); + std::memcpy(&jlist[ii][0], 16 + 2 * nloc + kk + mesh_tensor_data, + sizeof(int) * numneigh[ii]); + firstneigh[ii] = &jlist[ii][0]; + kk += numneigh[ii]; + } + inlist.ilist = &ilist[0]; + inlist.numneigh = &numneigh[0]; + inlist.firstneigh = &firstneigh[0]; } else { // copy pointers to nlist data memcpy(&inlist.ilist, 4 + mesh_tensor_data, sizeof(int*)); @@ -1675,7 +1713,7 @@ static void _prepare_coord_nlist_gpu(OpKernelContext* context, const float& rcut_r, const int& max_cpy_trial, const int& max_nnei_trial) { - if (nei_mode != 3) { + if (nei_mode != 3 && nei_mode != 4) { inlist.inum = nloc; // build nlist by myself // normalize and copy coord @@ -1705,6 +1743,46 @@ static void _prepare_coord_nlist_gpu(OpKernelContext* context, inlist.ilist = ilist; inlist.numneigh = numneigh; inlist.firstneigh = firstneigh; + } else if (nei_mode == 4) { + // TODO: in theory, it will be faster to put everything on GPUs... + std::vector mesh_tensor_data_host(mesh_tensor_size); + std::vector ilist_host(nloc); + std::vector numneigh_host(nloc); + std::vector firstneigh_host(nloc); + std::vector fake_mesh(16); + + // copy from gpu to cpu + deepmd::memcpy_device_to_host(mesh_tensor_data, mesh_tensor_data_host); + std::memcpy(&ilist_host[0], &mesh_tensor_data_host[16], sizeof(int) * nloc); + std::memcpy(&numneigh_host[0], &mesh_tensor_data_host[16 + nloc], + sizeof(int) * nloc); + for (int ii = 0, kk = 0; ii < nloc; ++ii) { + firstneigh_host[ii] = &mesh_tensor_data_host[16 + 2 * nloc + kk]; + kk += numneigh_host[ii]; + } + // make a fake mesh + fake_mesh[0] = 0; + fake_mesh[1] = nloc; + std::memcpy(&fake_mesh[4], &ilist_host, sizeof(int*)); + std::memcpy(&fake_mesh[8], &numneigh_host, sizeof(int*)); + std::memcpy(&fake_mesh[12], &firstneigh_host, sizeof(int**)); + // copy from cpu to gpu + int* fake_mesh_dev = NULL; + deepmd::malloc_device_memory(fake_mesh_dev, 16); + deepmd::memcpy_host_to_device(fake_mesh_dev, fake_mesh); + + deepmd::InputNlist inlist_temp; + inlist_temp.inum = nloc; + // everything should be copied to GPU... + deepmd::env_mat_nbor_update(inlist_temp, inlist, max_nbor_size, + nbor_list_dev, fake_mesh_dev, 16); + OP_REQUIRES(context, (max_numneigh(inlist_temp) <= max_nbor_size), + errors::InvalidArgument( + "Assert failed, max neighbor size of atom(lammps) " + + std::to_string(max_numneigh(inlist_temp)) + + " is larger than " + std::to_string(max_nbor_size) + + ", which currently is not supported by deepmd-kit.")); + deepmd::delete_device_memory(fake_mesh_dev); } else { // update nbor list deepmd::InputNlist inlist_temp; @@ -1908,7 +1986,7 @@ static void _prepare_coord_nlist_gpu_rocm(OpKernelContext* context, const float& rcut_r, const int& max_cpy_trial, const int& max_nnei_trial) { - if (nei_mode != 3) { + if (nei_mode != 3 && nei_mode != 4) { inlist.inum = nloc; // build nlist by myself // normalize and copy coord @@ -1938,6 +2016,46 @@ static void _prepare_coord_nlist_gpu_rocm(OpKernelContext* context, inlist.ilist = ilist; inlist.numneigh = numneigh; inlist.firstneigh = firstneigh; + } else if (nei_mode == 4) { + // TODO: in theory, it will be faster to put everything on GPUs... + std::vector mesh_tensor_data_host(mesh_tensor_size); + std::vector ilist_host(nloc); + std::vector numneigh_host(nloc); + std::vector firstneigh_host(nloc); + std::vector fake_mesh(16); + + // copy from gpu to cpu + deepmd::memcpy_device_to_host(mesh_tensor_data, mesh_tensor_data_host); + std::memcpy(&ilist_host[0], &mesh_tensor_data_host[16], sizeof(int) * nloc); + std::memcpy(&numneigh_host[0], &mesh_tensor_data_host[16 + nloc], + sizeof(int) * nloc); + for (int ii = 0, kk = 0; ii < nloc; ++ii) { + firstneigh_host[ii] = &mesh_tensor_data_host[16 + 2 * nloc + kk]; + kk += numneigh_host[ii]; + } + // make a fake mesh + fake_mesh[0] = 0; + fake_mesh[1] = nloc; + std::memcpy(&fake_mesh[4], &ilist_host, sizeof(int*)); + std::memcpy(&fake_mesh[8], &numneigh_host, sizeof(int*)); + std::memcpy(&fake_mesh[12], &firstneigh_host, sizeof(int**)); + // copy from cpu to gpu + int* fake_mesh_dev = NULL; + deepmd::malloc_device_memory(fake_mesh_dev, 16); + deepmd::memcpy_host_to_device(fake_mesh_dev, fake_mesh); + + deepmd::InputNlist inlist_temp; + inlist_temp.inum = nloc; + // everything should be copied to GPU... + deepmd::env_mat_nbor_update(inlist_temp, inlist, max_nbor_size, + nbor_list_dev, fake_mesh_dev, 16); + OP_REQUIRES(context, (max_numneigh(inlist_temp) <= max_nbor_size), + errors::InvalidArgument( + "Assert failed, max neighbor size of atom(lammps) " + + std::to_string(max_numneigh(inlist_temp)) + + " is larger than " + std::to_string(max_nbor_size) + + ", which currently is not supported by deepmd-kit.")); + deepmd::delete_device_memory(fake_mesh_dev); } else { // update nbor list deepmd::InputNlist inlist_temp; diff --git a/source/tests/test_pairwise_dprc.py b/source/tests/test_pairwise_dprc.py index 0f3f9fad50..e95b66c7a0 100644 --- a/source/tests/test_pairwise_dprc.py +++ b/source/tests/test_pairwise_dprc.py @@ -97,6 +97,169 @@ def test_op_single_frame(self): np.testing.assert_array_equal(qmmm_frame_idx, np.array([0, 0, 0], dtype=int)) +class TestConvertForwardMapOP(tf.test.TestCase): + """Test convert_forward_map OP.""" + + def test_convert_forward_map(self): + forward_qmmm_map = np.array( + [ + [3, 4, 0, 1, 2, 10, 11], + [3, 4, 5, 6, 7, 10, -1], + [3, 4, 8, 9, -1, 10, -1], + ], + dtype=int, + ) + natoms_qmmm = np.array([5, 7, 5], dtype=int) + natoms = np.array([10, 12, 10], dtype=int) + with self.cached_session() as sess: + ( + forward_qmmm_map, + backward_qmmm_map, + natoms_qmmm, + mesh_qmmm, + ) = run_sess( + sess, + op_module.convert_forward_map(forward_qmmm_map, natoms_qmmm, natoms), + ) + np.testing.assert_array_equal( + forward_qmmm_map, + np.array([[3, 4, 0, 1, 2, 3, 4, 5, 6, 7, 3, 4, 8, 9, 10, 11, 10, 10]]), + ) + np.testing.assert_array_equal( + backward_qmmm_map, + np.array( + [ + [2, 3, 4, 0, 1, -1, -1, -1, -1, -1, 14, 15], + [-1, -1, -1, 5, 6, 7, 8, 9, -1, -1, 16, -1], + [-1, -1, -1, 10, 11, -1, -1, -1, 12, 13, 17, -1], + ] + ), + ) + np.testing.assert_array_equal(natoms_qmmm, np.array([14, 18, 14], dtype=int)) + np.testing.assert_array_equal( + mesh_qmmm, + np.array( + [ + 14, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 6, + 6, + 6, + 6, + 6, + 5, + 5, + 5, + 5, + 5, + 4, + 4, + 4, + 4, + 1, + 2, + 3, + 4, + 14, + 15, + 0, + 2, + 3, + 4, + 14, + 15, + 0, + 1, + 3, + 4, + 14, + 15, + 0, + 1, + 2, + 4, + 14, + 15, + 0, + 1, + 2, + 3, + 14, + 15, + 6, + 7, + 8, + 9, + 16, + 5, + 7, + 8, + 9, + 16, + 5, + 6, + 8, + 9, + 16, + 5, + 6, + 7, + 9, + 16, + 5, + 6, + 7, + 8, + 16, + 11, + 12, + 13, + 17, + 10, + 12, + 13, + 17, + 10, + 11, + 13, + 17, + 10, + 11, + 12, + 17, + ] + ), + ) + + @unittest.skipIf( parse_version(tf.__version__) < parse_version("1.15"), f"The current tf version {tf.__version__} is too low to run the new testing model.", @@ -291,6 +454,7 @@ def test_model_ener(self): input_dict["aparam"] = t_aparam model.data_stat(data) + # model.merge_frames = False model_pred = model.build( t_coord, t_type, @@ -298,7 +462,7 @@ def test_model_ener(self): t_box, t_mesh, input_dict, - suffix="se_a_atom_ener_0", + suffix="pairwise_dprc_0", reuse=False, ) energy = model_pred["energy"] @@ -354,6 +518,8 @@ def test_model_ener(self): # the model is pairwise! self.assertAllClose(e[1] + e[2] + e[3] - 3 * e[0], e[4] - e[0]) self.assertAllClose(f[1] + f[2] + f[3] - 3 * f[0], f[4] - f[0]) + self.assertAllClose(e[0], 0.189075, 1e-6) + self.assertAllClose(f[0, 0], 0.060047, 1e-6) def test_nloc(self): jfile = tests_path / "pairwise_dprc.json"