Skip to content

Commit

Permalink
support neighbor stat on GPUs (#2897)
Browse files Browse the repository at this point in the history
Fix #2619.

The GPU implementation in this PR is usually faster than the CPU in one
thread (i.e., not using the feature implemented in #1624). Still, it
needs parallelism in the batch dimension, which is blocked by #2618,
regarding building the neighbor list. The GPU utilization is less than
10% for the water system. It should be improved when #2618 makes
progress.

---------

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz authored Oct 7, 2023
1 parent d0edb3a commit da100dc
Show file tree
Hide file tree
Showing 6 changed files with 389 additions and 144 deletions.
12 changes: 6 additions & 6 deletions deepmd/utils/neighbor_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ def builder():
rcut=self.rcut,
)
place_holders["dir"] = tf.placeholder(tf.string)
_min_nbor_dist = tf.reduce_min(_min_nbor_dist)
_max_nbor_size = tf.reduce_max(_max_nbor_size, axis=0)
return place_holders, (_max_nbor_size, _min_nbor_dist, place_holders["dir"])

with sub_graph.as_default():
Expand Down Expand Up @@ -128,10 +130,7 @@ def feed():
}

for mn, dt, jj in self.p.generate(self.sub_sess, feed()):
if dt.size != 0:
dt = np.min(dt)
else:
dt = self.rcut
if np.isinf(dt):
log.warning(
"Atoms with no neighbors found in %s. Please make sure it's what you expected."
% jj
Expand All @@ -145,9 +144,10 @@ def feed():
" training data to remove duplicated atoms." % jj
)
self.min_nbor_dist = dt
var = np.max(mn, axis=0)
self.max_nbor_size = np.maximum(var, self.max_nbor_size)
self.max_nbor_size = np.maximum(mn, self.max_nbor_size)

# do sqrt in the final
self.min_nbor_dist = math.sqrt(self.min_nbor_dist)
log.info("training data with min nbor dist: " + str(self.min_nbor_dist))
log.info("training data with max nbor size: " + str(self.max_nbor_size))
return self.min_nbor_dist, self.max_nbor_size
18 changes: 18 additions & 0 deletions source/lib/include/neighbor_stat.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// SPDX-License-Identifier: LGPL-3.0-or-later
#include "neighbor_list.h"

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace deepmd {
/**
 * Compute per-pair neighbor statistics on the GPU.
 *
 * For every (local atom ii, neighbor slot jj) pair the implementation writes
 * the SQUARED distance between the two atoms to
 * min_nbor_dist[ii * MAX_NNEI + jj]; slots beyond the atom's neighbor count,
 * and pairs that involve an atom with a negative type, are set to INFINITY.
 * The number of neighbors of each type is accumulated into
 * max_nbor_size[ii * ntypes + type]. The caller is expected to reduce both
 * buffers (min / max) and to take the square root afterwards.
 *
 * @param coord          atom coordinates, 3 values per atom (x, y, z)
 * @param type           atom types; a negative type marks an atom to skip
 * @param nloc           number of local atoms (rows of the neighbor list)
 * @param gpu_nlist      neighbor list (ilist / numneigh / firstneigh are
 *                       dereferenced inside the kernel)
 * @param max_nbor_size  output, nloc * ntypes ints; zeroed by this function
 * @param min_nbor_dist  output, nloc * MAX_NNEI values
 * @param ntypes         number of atom types
 * @param MAX_NNEI       number of neighbor slots reserved per local atom
 *
 * NOTE(review): all pointers are presumably device pointers (max_nbor_size is
 * passed to gpuMemset and the others are read by the kernel) - confirm at the
 * call site. The implementation synchronizes the device before returning.
 */
template <typename FPTYPE>
void neighbor_stat_gpu(const FPTYPE* coord,
const int* type,
const int nloc,
const deepmd::InputNlist& gpu_nlist,
int* max_nbor_size,
FPTYPE* min_nbor_dist,
const int ntypes,
const int MAX_NNEI);
} // namespace deepmd

#endif
103 changes: 103 additions & 0 deletions source/lib/src/gpu/neighbor_stat.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
#include <cmath>

#include "device.h"
#include "neighbor_list.h"

/**
 * Per-pair neighbor statistics kernel.
 *
 * Launch layout: 1-D grid of 1-D blocks with at least nloc * MAX_NNEI threads
 * in total. Flat thread id t handles the pair
 * (ii, jj) = (t / MAX_NNEI, t % MAX_NNEI), i.e. neighbor slot jj of the ii-th
 * row of the neighbor list. Static shared memory: TPB ints.
 *
 * Outputs:
 *  - min_nbor_dist[ii * MAX_NNEI + jj]: SQUARED distance of the pair, or
 *    INFINITY when the slot is unused (jj >= numneigh[ii]) or either atom has
 *    a negative type (virtual atom). The sqrt is deliberately left to the
 *    caller so it is done once on the reduced minimum.
 *  - max_nbor_size[ii * ntypes + t]: incremented by the number of valid
 *    neighbors of type t. The buffer must be zeroed before the launch.
 *
 * Preconditions: MAX_NNEI > 0; every non-negative type is < ntypes.
 */
template <typename FPTYPE>
__global__ void neighbor_stat_g(const FPTYPE* coord,
                                const int* type,
                                const int nloc,
                                const int* ilist,
                                int** firstneigh,
                                const int* numneigh,
                                int* max_nbor_size,
                                FPTYPE* min_nbor_dist,
                                const int ntypes,
                                const int MAX_NNEI) {
  // Per-block histogram of neighbor types; only the first ntypes entries are
  // used, and only when the whole block works on a single atom (see below).
  __shared__ int cache[TPB];

  // 64-bit flat index: nloc * MAX_NNEI may exceed the range of int.
  const long long block_begin = static_cast<long long>(blockIdx.x) * blockDim.x;
  const long long ithread = block_begin + threadIdx.x;
  const int ii = static_cast<int>(ithread / MAX_NNEI);
  const int jj = static_cast<int>(ithread % MAX_NNEI);

  // Type of this thread's neighbor if the pair is valid, otherwise -1.
  // No thread returns early: every thread must reach the barriers below.
  int type_j = -1;
  if (ii < nloc) {
    FPTYPE dist2 = INFINITY;  // default for padding slots and virtual atoms
    const int idx_i = ilist[ii];
    if (type[idx_i] >= 0 && jj < numneigh[ii]) {
      const int idx_j = firstneigh[ii][jj];
      const int tj = type[idx_j];
      if (tj >= 0) {
        const FPTYPE rij[3] = {coord[idx_j * 3 + 0] - coord[idx_i * 3 + 0],
                               coord[idx_j * 3 + 1] - coord[idx_i * 3 + 1],
                               coord[idx_j * 3 + 2] - coord[idx_i * 3 + 2]};
        // squared distance; the slow sqrt is done once on the final minimum
        dist2 = rij[0] * rij[0] + rij[1] * rij[1] + rij[2] * rij[2];
        type_j = tj;
      }
    }
    // ii * MAX_NNEI + jj == ithread by construction
    min_nbor_dist[ithread] = dist2;
  }

  // The shared-memory histogram is only valid when all threads of the block
  // belong to the same atom and the histogram fits into the cache. Both
  // conditions depend only on block-wide values, so the branch is uniform
  // across the block and the barriers inside it are safe.
  const int first_ii = static_cast<int>(block_begin / MAX_NNEI);
  const int last_ii =
      static_cast<int>((block_begin + blockDim.x - 1) / MAX_NNEI);
  const bool use_cache = (first_ii == last_ii) && (ntypes <= TPB);

  if (use_cache) {
    for (int t = threadIdx.x; t < ntypes; t += blockDim.x) {
      cache[t] = 0;
    }
    __syncthreads();  // histogram is zeroed before anyone counts into it
    // Shared-memory atomics are much cheaper than global ones.
    // See https://www.cnblogs.com/neopenx/p/4705320.html
    if (type_j >= 0) {
      atomicAdd(&cache[type_j], 1);
    }
    __syncthreads();  // all counts are in before the histogram is flushed
    if (first_ii < nloc) {
      // Flush is independent of whether the flushing thread itself owns a
      // valid neighbor, so no per-type count can be dropped.
      for (int t = threadIdx.x; t < ntypes; t += blockDim.x) {
        if (cache[t] != 0) {
          atomicAdd(
              &max_nbor_size[static_cast<long long>(first_ii) * ntypes + t],
              cache[t]);
        }
      }
    }
  } else if (type_j >= 0) {
    // Block spans several atoms (or ntypes > TPB): count directly in global
    // memory so every neighbor is attributed to its own atom.
    atomicAdd(&max_nbor_size[static_cast<long long>(ii) * ntypes + type_j], 1);
  }
}

namespace deepmd {

/**
 * Host entry point: compute per-pair neighbor statistics on the GPU.
 *
 * Zeroes max_nbor_size (nloc * ntypes ints), launches neighbor_stat_g with
 * one thread per (local atom, neighbor slot) pair and blocks until the device
 * has finished, so both output buffers are complete on return.
 *
 * @param coord          atom coordinates, 3 values per atom
 * @param type           atom types; negative marks a virtual atom
 * @param nloc           number of local atoms
 * @param gpu_nlist      neighbor list whose arrays are read by the kernel
 * @param max_nbor_size  output, nloc * ntypes per-type neighbor counts
 * @param min_nbor_dist  output, nloc * MAX_NNEI squared distances (INFINITY
 *                       for unused slots / virtual atoms)
 * @param ntypes         number of atom types
 * @param MAX_NNEI       neighbor slots reserved per local atom
 */
template <typename FPTYPE>
void neighbor_stat_gpu(const FPTYPE* coord,
                       const int* type,
                       const int nloc,
                       const deepmd::InputNlist& gpu_nlist,
                       int* max_nbor_size,
                       FPTYPE* min_nbor_dist,
                       const int ntypes,
                       const int MAX_NNEI) {
  // Surface any error left behind by earlier asynchronous work.
  DPErrcheck(gpuGetLastError());
  DPErrcheck(gpuDeviceSynchronize());

  if (nloc <= 0) {
    // Nothing to compute; also avoids a negative memset size.
    return;
  }

  DPErrcheck(gpuMemset(max_nbor_size, 0, sizeof(int) * int_64(nloc) * ntypes));

  if (MAX_NNEI > 0) {
    // 64-bit arithmetic: nloc * MAX_NNEI may not fit into an int.
    const int_64 nthreads = int_64(nloc) * MAX_NNEI;
    const int nblock_loc = static_cast<int>((nthreads + TPB - 1) / TPB);
    neighbor_stat_g<<<nblock_loc, TPB>>>(
        coord, type, nloc, gpu_nlist.ilist, gpu_nlist.firstneigh,
        gpu_nlist.numneigh, max_nbor_size, min_nbor_dist, ntypes, MAX_NNEI);
    DPErrcheck(gpuGetLastError());
  }
  // A launch with zero blocks is an invalid configuration, hence the guard
  // above; with no neighbor slots the zeroed counts are already the answer.
  DPErrcheck(gpuDeviceSynchronize());
}

template void neighbor_stat_gpu<float>(const float* coord,
                                       const int* type,
                                       const int nloc,
                                       const deepmd::InputNlist& gpu_nlist,
                                       int* max_nbor_size,
                                       float* min_nbor_dist,
                                       const int ntypes,
                                       const int MAX_NNEI);

template void neighbor_stat_gpu<double>(const double* coord,
                                        const int* type,
                                        const int nloc,
                                        const deepmd::InputNlist& gpu_nlist,
                                        int* max_nbor_size,
                                        double* min_nbor_dist,
                                        const int ntypes,
                                        const int MAX_NNEI);
}  // namespace deepmd
28 changes: 28 additions & 0 deletions source/op/custom_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <vector>

#include "device.h"
#include "neighbor_list.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
Expand All @@ -25,3 +26,30 @@ namespace deepmd {
void safe_compute(OpKernelContext* context,
std::function<void(OpKernelContext*)> ff);
};

// Declaration of the GPU coordinate / neighbor-list preparation helper; the
// definition is not part of this header.
// NOTE(review): judging by the signature only, the reference-to-pointer
// parameters (coord_cpy, type_cpy, idx_mapping, ilist, numneigh, firstneigh,
// jlist, nbor_list_dev) and the int& parameters (new_nall, mem_cpy, mem_nnei,
// max_nbor_size) look like outputs filled in by the callee, while box,
// mesh_tensor_data and the trailing const-reference arguments look like
// inputs - confirm against the definition before relying on this.
template <typename FPTYPE>
void _prepare_coord_nlist_gpu(OpKernelContext* context,
Tensor* tensor_list,
FPTYPE const** coord,
FPTYPE*& coord_cpy,
int const** type,
int*& type_cpy,
int*& idx_mapping,
deepmd::InputNlist& inlist,
int*& ilist,
int*& numneigh,
int**& firstneigh,
int*& jlist,
int*& nbor_list_dev,
int& new_nall,
int& mem_cpy,
int& mem_nnei,
int& max_nbor_size,
const FPTYPE* box,
const int* mesh_tensor_data,
const int mesh_tensor_size,
const int& nloc,
const int& nei_mode,
const float& rcut_r,
const int& max_cpy_trial,
const int& max_nnei_trial);
Loading

0 comments on commit da100dc

Please sign in to comment.