Skip to content

Commit

Permalink
prod_force: support multiple frames in parallel (#2600)
Browse files Browse the repository at this point in the history
The previous `prod_force` did not support multiple frames in parallel,
which was slow when the batch size was large.

This PR adds support so that prod_force can be parallelized in the
dimension of the samples.

When the batch size is about 70, the `prod_force` op is 10x faster than
before on GPU cards.

---------

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz authored Jun 12, 2023
1 parent 046a5a4 commit bb0d02b
Show file tree
Hide file tree
Showing 7 changed files with 272 additions and 128 deletions.
61 changes: 55 additions & 6 deletions source/lib/include/prod_force.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,19 @@

namespace deepmd {

/**
* @brief Produce force from net_deriv and in_deriv.
*
* @tparam FPTYPE float or double
* @param[out] force Atomic forces.
* @param[in] net_deriv Net derivative.
* @param[in] in_deriv Environmental derivative.
* @param[in] nlist Neighbor list.
* @param[in] nloc The number of local atoms.
* @param[in] nall The number of all atoms, including ghost atoms.
* @param[in] nnei The number of neighbors.
* @param[in] nframes The number of frames.
*/
template <typename FPTYPE>
void prod_force_a_cpu(FPTYPE* force,
const FPTYPE* net_deriv,
Expand All @@ -10,7 +23,38 @@ void prod_force_a_cpu(FPTYPE* force,
const int nloc,
const int nall,
const int nnei,
const int start_index = 0);
const int nframes);

/**
* @brief Produce force from net_deriv and in_deriv.
* @details This function is used for multi-threading. Only part of atoms
* are computed in this thread. They will be comptued in parallel.
*
* @tparam FPTYPE float or double
* @param[out] force Atomic forces.
* @param[in] net_deriv Net derivative.
* @param[in] in_deriv Environmental derivative.
* @param[in] nlist Neighbor list.
* @param[in] nloc The number of local atoms.
* @param[in] nall The number of all atoms, including ghost atoms.
* @param[in] nnei The number of neighbors.
* @param[in] nframes The number of frames.
* @param[in] thread_nloc The number of local atoms to be computed in this
* thread.
* @param[in] thread_start_index The start index of local atoms to be computed
* in this thread. The index should be in [0, nloc).
*/
template <typename FPTYPE>
void prod_force_a_cpu(FPTYPE* force,
const FPTYPE* net_deriv,
const FPTYPE* in_deriv,
const int* nlist,
const int nloc,
const int nall,
const int nnei,
const int nframes,
const int thread_nloc,
const int thread_start_index);

template <typename FPTYPE>
void prod_force_r_cpu(FPTYPE* force,
Expand All @@ -19,7 +63,8 @@ void prod_force_r_cpu(FPTYPE* force,
const int* nlist,
const int nloc,
const int nall,
const int nnei);
const int nnei,
const int nframes);

#if GOOGLE_CUDA
template <typename FPTYPE>
Expand All @@ -29,7 +74,8 @@ void prod_force_a_gpu_cuda(FPTYPE* force,
const int* nlist,
const int nloc,
const int nall,
const int nnei);
const int nnei,
const int nframes);

template <typename FPTYPE>
void prod_force_r_gpu_cuda(FPTYPE* force,
Expand All @@ -38,7 +84,8 @@ void prod_force_r_gpu_cuda(FPTYPE* force,
const int* nlist,
const int nloc,
const int nall,
const int nnei);
const int nnei,
const int nframes);
#endif // GOOGLE_CUDA

#if TENSORFLOW_USE_ROCM
Expand All @@ -49,7 +96,8 @@ void prod_force_a_gpu_rocm(FPTYPE* force,
const int* nlist,
const int nloc,
const int nall,
const int nnei);
const int nnei,
const int nframes);

template <typename FPTYPE>
void prod_force_r_gpu_rocm(FPTYPE* force,
Expand All @@ -58,7 +106,8 @@ void prod_force_r_gpu_rocm(FPTYPE* force,
const int* nlist,
const int nloc,
const int nall,
const int nnei);
const int nnei,
const int nframes);
#endif // TENSORFLOW_USE_ROCM

} // namespace deepmd
59 changes: 37 additions & 22 deletions source/lib/src/cuda/prod_force.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ template <typename FPTYPE, int THREADS_PER_BLOCK>
__global__ void force_deriv_wrt_center_atom(FPTYPE* force,
const FPTYPE* net_deriv,
const FPTYPE* in_deriv,
const int ndescrpt) {
const int ndescrpt,
const int nloc,
const int nall) {
__shared__ FPTYPE data[THREADS_PER_BLOCK * 3];
int_64 bid = blockIdx.x;
unsigned int tid = threadIdx.x;
Expand All @@ -31,10 +33,13 @@ __global__ void force_deriv_wrt_center_atom(FPTYPE* force,
__syncthreads();
}
// write result for this block to global memory
const int_64 kk = bid / nloc; // frame index
const int_64 ll = bid % nloc; // atom index
const int_64 i_idx_nall = kk * nall + ll;
if (tid == 0) {
force[bid * 3 + 0] -= data[THREADS_PER_BLOCK * 0];
force[bid * 3 + 1] -= data[THREADS_PER_BLOCK * 1];
force[bid * 3 + 2] -= data[THREADS_PER_BLOCK * 2];
force[i_idx_nall * 3 + 0] -= data[THREADS_PER_BLOCK * 0];
force[i_idx_nall * 3 + 1] -= data[THREADS_PER_BLOCK * 1];
force[i_idx_nall * 3 + 2] -= data[THREADS_PER_BLOCK * 2];
}
}

Expand All @@ -44,6 +49,7 @@ __global__ void force_deriv_wrt_neighbors_a(FPTYPE* force,
const FPTYPE* in_deriv,
const int* nlist,
const int nloc,
const int nall,
const int nnei) {
// idy -> nnei
const int_64 idx = blockIdx.x;
Expand All @@ -63,7 +69,8 @@ __global__ void force_deriv_wrt_neighbors_a(FPTYPE* force,
force_tmp += net_deriv[idx * ndescrpt + idy * 4 + idw] *
in_deriv[idx * ndescrpt * 3 + (idy * 4 + idw) * 3 + idz];
}
atomicAdd(force + j_idx * 3 + idz, force_tmp);
const int_64 kk = idx / nloc; // frame index
atomicAdd(force + kk * nall * 3 + j_idx * 3 + idz, force_tmp);
}

template <typename FPTYPE>
Expand All @@ -72,6 +79,7 @@ __global__ void force_deriv_wrt_neighbors_r(FPTYPE* force,
const FPTYPE* in_deriv,
const int* nlist,
const int nloc,
const int nall,
const int nnei) {
// idy -> nnei
const int_64 idx = blockIdx.x;
Expand All @@ -86,7 +94,8 @@ __global__ void force_deriv_wrt_neighbors_r(FPTYPE* force,
if (j_idx < 0) {
return;
}
atomicAdd(force + j_idx * 3 + idz,
const int_64 kk = idx / nloc; // frame index
atomicAdd(force + kk * nall * 3 + j_idx * 3 + idz,
net_deriv[idx * ndescrpt + idy] *
in_deriv[idx * ndescrpt * 3 + idy * 3 + idz]);
}
Expand All @@ -99,21 +108,22 @@ void prod_force_a_gpu_cuda(FPTYPE* force,
const int* nlist,
const int nloc,
const int nall,
const int nnei) {
const int nnei,
const int nframes) {
const int ndescrpt = nnei * 4;
DPErrcheck(cudaMemset(force, 0, sizeof(FPTYPE) * nall * 3));
DPErrcheck(cudaMemset(force, 0, sizeof(FPTYPE) * nframes * nall * 3));

force_deriv_wrt_center_atom<FPTYPE, TPB>
<<<nloc, TPB>>>(force, net_deriv, in_deriv, ndescrpt);
force_deriv_wrt_center_atom<FPTYPE, TPB><<<nframes * nloc, TPB>>>(
force, net_deriv, in_deriv, ndescrpt, nloc, nall);
DPErrcheck(cudaGetLastError());
DPErrcheck(cudaDeviceSynchronize());

const int LEN = 64;
const int nblock = (nnei + LEN - 1) / LEN;
dim3 block_grid(nloc, nblock);
dim3 block_grid(nframes * nloc, nblock);
dim3 thread_grid(LEN, 3);
force_deriv_wrt_neighbors_a<<<block_grid, thread_grid>>>(
force, net_deriv, in_deriv, nlist, nloc, nnei);
force, net_deriv, in_deriv, nlist, nloc, nall, nnei);
DPErrcheck(cudaGetLastError());
DPErrcheck(cudaDeviceSynchronize());
}
Expand All @@ -125,21 +135,22 @@ void prod_force_r_gpu_cuda(FPTYPE* force,
const int* nlist,
const int nloc,
const int nall,
const int nnei) {
const int nnei,
const int nframes) {
const int ndescrpt = nnei * 1;
DPErrcheck(cudaMemset(force, 0, sizeof(FPTYPE) * nall * 3));
DPErrcheck(cudaMemset(force, 0, sizeof(FPTYPE) * nframes * nall * 3));

force_deriv_wrt_center_atom<FPTYPE, TPB>
<<<nloc, TPB>>>(force, net_deriv, in_deriv, ndescrpt);
force_deriv_wrt_center_atom<FPTYPE, TPB><<<nframes * nloc, TPB>>>(
force, net_deriv, in_deriv, ndescrpt, nloc, nall);
DPErrcheck(cudaGetLastError());
DPErrcheck(cudaDeviceSynchronize());

const int LEN = 64;
const int nblock = (nnei + LEN - 1) / LEN;
dim3 block_grid(nloc, nblock);
dim3 block_grid(nframes * nloc, nblock);
dim3 thread_grid(LEN, 3);
force_deriv_wrt_neighbors_r<<<block_grid, thread_grid>>>(
force, net_deriv, in_deriv, nlist, nloc, nnei);
force, net_deriv, in_deriv, nlist, nloc, nall, nnei);
DPErrcheck(cudaGetLastError());
DPErrcheck(cudaDeviceSynchronize());
}
Expand All @@ -150,26 +161,30 @@ template void prod_force_a_gpu_cuda<float>(float* force,
const int* nlist,
const int nloc,
const int nall,
const int nnei);
const int nnei,
const int nframes);
template void prod_force_a_gpu_cuda<double>(double* force,
const double* net_deriv,
const double* in_deriv,
const int* nlist,
const int nloc,
const int nall,
const int nnei);
const int nnei,
const int nframes);
template void prod_force_r_gpu_cuda<float>(float* force,
const float* net_deriv,
const float* in_deriv,
const int* nlist,
const int nloc,
const int nall,
const int nnei);
const int nnei,
const int nframes);
template void prod_force_r_gpu_cuda<double>(double* force,
const double* net_deriv,
const double* in_deriv,
const int* nlist,
const int nloc,
const int nall,
const int nnei);
const int nnei,
const int nframes);
} // namespace deepmd
Loading

0 comments on commit bb0d02b

Please sign in to comment.