Skip to content

Commit

Permalink
make the pairwise DPRc model 2x faster (#2833)
Browse files Browse the repository at this point in the history
This PR does a trick to speed up the pairwise DPRc model. Considering
#2618 is not ready and is quite difficult to implement, in this PR,
multiple frames are merged into one frame before feed to `prod_env_mat`
OP, and the mesh is faked to make it perform the same behavior as the
multiple frames.
A new `mesh` shape is proposed. The first element stores `nloc`, and the
following 15 elements store nothing to distinguish it from other mesh.
The `(16 : 16 + nloc)` elements store `ilist`, `(16 + nloc : 16 + nloc *
2)` store `numneigh`, and the rest elements (in the shape of
`sum(numneigh)`) store neighbors. The `nei_mode` is 4 for this
situation.

`prod_env_mat` OP is not a bottleneck anymore, as shown below.

![image](https://github.com/deepmodeling/deepmd-kit/assets/9496702/eea64b99-d630-4ea1-99f4-e7d49c126c33)

---------

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz authored Sep 19, 2023
1 parent ce7532a commit a735bed
Show file tree
Hide file tree
Showing 5 changed files with 526 additions and 27 deletions.
55 changes: 38 additions & 17 deletions deepmd/model/pairwise_dprc.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ def __init__(
compress: Optional[dict] = None,
**kwargs,
) -> None:
# internal variable to compare old and new behavior
# expect they give the same results
self.merge_frames = True

super().__init__(
type_embedding=type_embedding,
type_map=type_map,
Expand Down Expand Up @@ -151,16 +155,27 @@ def build(
atype = tf.reshape(atype_, [nframes, natoms[1], 1])
nframes_qmmm = tf.shape(qmmm_frame_idx)[0]

if self.merge_frames:
(
forward_qmmm_map,
backward_qmmm_map,
natoms_qmmm,
mesh_qmmm,
) = op_module.convert_forward_map(forward_qmmm_map, natoms_qmmm, natoms)
coord_qmmm = tf.reshape(coord, [1, -1, 3])
atype_qmmm = tf.reshape(atype, [1, -1, 1])
box_qmmm = tf.reshape(box[0], [1, 9])
else:
mesh_qmmm = make_default_mesh(False, True)
coord_qmmm = tf.gather(coord, qmmm_frame_idx)
atype_qmmm = tf.gather(atype, qmmm_frame_idx)
box_qmmm = tf.gather(box, qmmm_frame_idx)

coord_qm = gather_placeholder(coord, forward_qm_map)
atype_qm = gather_placeholder(atype, forward_qm_map, placeholder=-1)
coord_qmmm = gather_placeholder(
tf.gather(coord, qmmm_frame_idx), forward_qmmm_map
)
atype_qmmm = gather_placeholder(
tf.gather(atype, qmmm_frame_idx), forward_qmmm_map, placeholder=-1
)
coord_qmmm = gather_placeholder(coord_qmmm, forward_qmmm_map)
atype_qmmm = gather_placeholder(atype_qmmm, forward_qmmm_map, placeholder=-1)
box_qm = box
box_qmmm = tf.gather(box, qmmm_frame_idx)

type_embedding = self.typeebd.build(
self.ntypes,
Expand Down Expand Up @@ -189,18 +204,22 @@ def build(
atype_qmmm,
natoms_qmmm,
box_qmmm,
mesh_mixed_type,
mesh_qmmm,
input_dict_qmmm,
frz_model=frz_model,
ckpt_meta=ckpt_meta,
suffix="_qmmm" + suffix,
reuse=reuse,
)

energy_qm = qm_dict["energy"]
energy_qmmm = tf.math.segment_sum(qmmm_dict["energy"], qmmm_frame_idx)
energy = energy_qm + energy_qmmm
energy = tf.identity(energy, name="o_energy" + suffix)
if self.merge_frames:
qmmm_dict = qmmm_dict.copy()
sub_nframes = tf.shape(backward_qmmm_map)[0]
qmmm_dict["force"] = tf.tile(qmmm_dict["force"], [sub_nframes, 1])
qmmm_dict["atom_ener"] = tf.tile(qmmm_dict["atom_ener"], [sub_nframes, 1])
qmmm_dict["atom_virial"] = tf.tile(
qmmm_dict["atom_virial"], [sub_nframes, 1]
)

force_qm = gather_placeholder(
tf.reshape(qm_dict["force"], (nframes, natoms_qm[1], 3)),
Expand All @@ -218,11 +237,6 @@ def build(
force = force_qm + force_qmmm
force = tf.reshape(force, (nframes, 3 * natoms[1]), name="o_force" + suffix)

virial_qm = qm_dict["virial"]
virial_qmmm = tf.math.segment_sum(qmmm_dict["virial"], qmmm_frame_idx)
virial = virial_qm + virial_qmmm
virial = tf.identity(virial, name="o_virial" + suffix)

backward_qm_map_nloc = tf.slice(backward_qm_map, [0, 0], [-1, natoms[0]])
backward_qmmm_map_nloc = tf.slice(backward_qmmm_map, [0, 0], [-1, natoms[0]])
atom_ener_qm = gather_placeholder(
Expand Down Expand Up @@ -255,6 +269,13 @@ def build(
atom_virial, (nframes, 9 * natoms[1]), name="o_atom_virial" + suffix
)

energy = tf.reduce_sum(atom_ener, axis=1, name="o_energy" + suffix)
virial = tf.reduce_sum(
tf.reshape(atom_virial, (nframes, natoms[1], 9)),
axis=1,
name="o_virial" + suffix,
)

model_dict = {}
model_dict["energy"] = energy
model_dict["force"] = force
Expand Down
2 changes: 1 addition & 1 deletion source/lib/include/gpu_rocm.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ void memcpy_host_to_device(FPTYPE *device, const FPTYPE *host, const int size) {
}

template <typename FPTYPE>
void memcpy_device_to_host(FPTYPE *device, std::vector<FPTYPE> &host) {
void memcpy_device_to_host(const FPTYPE *device, std::vector<FPTYPE> &host) {
DPErrcheck(hipMemcpy(&host[0], device, sizeof(FPTYPE) * host.size(),
hipMemcpyDeviceToHost));
}
Expand Down
200 changes: 197 additions & 3 deletions source/op/pairwise.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,15 @@ REGISTER_OP("DprcPairwiseIdx")
.Output("natoms_qmmm: int32")
.Output("qmmm_frame_idx: int32");

REGISTER_OP("ConvertForwardMap")
.Input("sub_forward_map: int32")
.Input("sub_natoms: int32")
.Input("natoms: int32")
.Output("forward_map: int32")
.Output("backward_map: int32")
.Output("new_natoms: int32")
.Output("mesh: int32");

using namespace tensorflow;

using CPUDevice = Eigen::ThreadPoolDevice;
Expand Down Expand Up @@ -208,8 +217,193 @@ class PairwiseIdxOp : public OpKernel {
}
};

template <typename Device>
class ConvertForwardMapOp : public OpKernel {
public:
explicit ConvertForwardMapOp(OpKernelConstruction* context)
: OpKernel(context) {}

void Compute(OpKernelContext* context) override {
deepmd::safe_compute(
context, [this](OpKernelContext* context) { this->_Compute(context); });
}

void _Compute(OpKernelContext* context) {
// Grab the input tensor
int tmp_idx = 0;
const Tensor& sub_forward_map_tensor = context->input(tmp_idx++);
const Tensor& sub_natoms_tensor = context->input(tmp_idx++);
const Tensor& natoms_tensor = context->input(tmp_idx++);

// set size of the sample
OP_REQUIRES(context, (sub_forward_map_tensor.shape().dims() == 2),
errors::InvalidArgument("Dim of idxs should be 2"));
OP_REQUIRES(context, (natoms_tensor.shape().dims() == 1),
errors::InvalidArgument("Dim of natoms should be 1"));

auto sub_forward_map = sub_forward_map_tensor.matrix<int>();
int sub_nframes = sub_forward_map_tensor.shape().dim_size(0);
auto sub_natoms = sub_natoms_tensor.vec<int>();
auto natoms = natoms_tensor.vec<int>();
int sub_nloc = sub_natoms(0);
int sub_nall = sub_natoms(1);
int nloc = natoms(0);
int nall = natoms(1);

// merge multiple sub-frames into one frame
// firstly, we need to get the nloc and nghost size to allocate
int new_nloc = 0, new_nghost = 0;

for (int ii = 0; ii < sub_nframes; ++ii) {
for (int jj = 0; jj < sub_nloc; ++jj) {
if (sub_forward_map(ii, jj) != -1) {
new_nloc++;
}
}
for (int jj = sub_nloc; jj < sub_nall; ++jj) {
if (sub_forward_map(ii, jj) != -1) {
new_nghost++;
}
}
}
if (new_nloc == 0) {
new_nloc = 1;
}
int new_nall = new_nloc + new_nghost;

// Create an output tensor
TensorShape forward_map_shape;
forward_map_shape.AddDim(1);
forward_map_shape.AddDim(new_nall);
TensorShape backward_map_shape;
// since the atom index can not be repeated, we still need
// to split to multiple frames
backward_map_shape.AddDim(sub_nframes);
backward_map_shape.AddDim(nall);
TensorShape new_natoms_shape;
new_natoms_shape.AddDim(natoms_tensor.shape().dim_size(0));

Tensor* forward_map_tensor = NULL;
Tensor* backward_map_tensor = NULL;
Tensor* new_natoms_tensor = NULL;
tmp_idx = 0;
OP_REQUIRES_OK(context,
context->allocate_output(tmp_idx++, forward_map_shape,
&forward_map_tensor));
OP_REQUIRES_OK(context,
context->allocate_output(tmp_idx++, backward_map_shape,
&backward_map_tensor));
OP_REQUIRES_OK(context,
context->allocate_output(tmp_idx++, new_natoms_shape,
&new_natoms_tensor));

auto forward_map = forward_map_tensor->matrix<int>();
auto backward_map = backward_map_tensor->matrix<int>();
auto new_natoms = new_natoms_tensor->vec<int>();

// fill -1 in backward_map_tensor
for (int ii = 0; ii < sub_nframes; ++ii) {
for (int jj = 0; jj < nall; ++jj) {
backward_map(ii, jj) = -1;
}
}

std::vector<int> start_kk(sub_nframes),
end_kk(sub_nframes); // current forward map index
int kk = 0;
// assume nlist to contain all atoms; it should not be a problem for small
// residues
std::vector<std::vector<int>> jlist(new_nloc);
for (int ii = 0; ii < sub_nframes; ++ii) {
start_kk[ii] = kk;
for (int jj = 0; jj < sub_nloc; ++jj) {
if (sub_forward_map(ii, jj) != -1) {
forward_map(0, kk) = sub_forward_map(ii, jj);
backward_map(ii, sub_forward_map(ii, jj)) = kk;
kk++;
}
}
end_kk[ii] = kk;
// add neighbors to each other
for (int mm = start_kk[ii]; mm < end_kk[ii]; ++mm) {
for (int nn = start_kk[ii]; nn < end_kk[ii]; ++nn) {
if (mm != nn) {
jlist[mm].push_back(nn);
}
}
}
}
for (int ii = 0; ii < sub_nframes; ++ii) {
int start_ghost_kk = kk;
for (int jj = sub_nloc; jj < sub_nall; ++jj) {
if (sub_forward_map(ii, jj) != -1) {
forward_map(0, kk) = sub_forward_map(ii, jj);
backward_map(ii, sub_forward_map(ii, jj)) = kk;
kk++;
}
}
int end_ghost_kk = kk;
// add ghost neighbors to real atoms
for (int mm = start_kk[ii]; mm < end_kk[ii]; ++mm) {
for (int nn = start_ghost_kk; nn < end_ghost_kk; ++nn) {
jlist[mm].push_back(nn);
}
}
}

// natoms
new_natoms(0) = new_nloc;
new_natoms(1) = new_nall;
new_natoms(2) = new_nloc;
for (int ii = 3; ii < new_natoms.size(); ++ii) {
new_natoms(ii) = 0;
}

// mesh:
// first element: nloc (a number)
// 2~16: empty (to distinguish from other mesh)
// ilist: nloc
// numneigh: nloc
// jlist: sum(numneigh)

// calculate numneigh
std::vector<int> numneigh(new_nloc);
for (int ii = 0; ii < new_nloc; ++ii) {
numneigh[ii] = jlist[ii].size();
}
int size_mesh =
std::accumulate(numneigh.begin(), numneigh.end(), 2 * new_nloc + 16);

TensorShape mesh_shape;
mesh_shape.AddDim(size_mesh);
Tensor* mesh_tensor = NULL;
OP_REQUIRES_OK(
context, context->allocate_output(tmp_idx++, mesh_shape, &mesh_tensor));
auto mesh = mesh_tensor->vec<int>();
mesh(0) = new_nloc;
for (int ii = 1; ii < 16; ++ii) {
mesh(ii) = 0;
}
for (int ii = 0; ii < new_nloc; ++ii) {
mesh(ii + 16) = ii;
}
for (int ii = 0; ii < new_nloc; ++ii) {
mesh(ii + 16 + new_nloc) = numneigh[ii];
}
kk = 0;
for (int ii = 0; ii < new_nloc; ++ii) {
for (int jj = 0; jj < numneigh[ii]; ++jj) {
mesh(16 + 2 * new_nloc + kk) = jlist[ii][jj];
kk++;
}
}
}
};

// Register the CPU kernels.
#define REGISTER_CPU(T) \
REGISTER_KERNEL_BUILDER(Name("DprcPairwiseIdx").Device(DEVICE_CPU), \
PairwiseIdxOp<CPUDevice>);
#define REGISTER_CPU(T) \
REGISTER_KERNEL_BUILDER(Name("DprcPairwiseIdx").Device(DEVICE_CPU), \
PairwiseIdxOp<CPUDevice>); \
REGISTER_KERNEL_BUILDER(Name("ConvertForwardMap").Device(DEVICE_CPU), \
ConvertForwardMapOp<CPUDevice>);
REGISTER_CPU();
Loading

0 comments on commit a735bed

Please sign in to comment.