Skip to content

Commit

Permalink
tf: refactor neighbor stat (deepmodeling#3275)
Browse files Browse the repository at this point in the history
Fix deepmodeling#3272. Apply implementation of deepmodeling#3271 into TF. Confirm consistent
results on `examples/water`, `examples/nopbc`, and ANI-1x (deepmodeling#1624).

80x speed up:

![image](https://github.com/deepmodeling/deepmd-kit/assets/9496702/85aa1fed-e3c0-4cb6-9082-db45c9a03f9d)

---------

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz authored Feb 15, 2024
1 parent aae4850 commit 02080db
Show file tree
Hide file tree
Showing 4 changed files with 489 additions and 62 deletions.
276 changes: 214 additions & 62 deletions deepmd/tf/utils/neighbor_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,161 @@
import logging
from typing import (
Iterator,
Optional,
Tuple,
)

import numpy as np

from deepmd.tf.env import (
GLOBAL_NP_FLOAT_PRECISION,
GLOBAL_TF_FLOAT_PRECISION,
default_tf_session_config,
op_module,
tf,
)
from deepmd.tf.utils.batch_size import (
AutoBatchSize,
)
from deepmd.tf.utils.data_system import (
DeepmdDataSystem,
)
from deepmd.tf.utils.parallel_op import (
ParallelOp,
from deepmd.tf.utils.nlist import (
extend_coord_with_ghosts,
)
from deepmd.tf.utils.sess import (
run_sess,
)
from deepmd.utils.neighbor_stat import NeighborStat as BaseNeighborStat

log = logging.getLogger(__name__)


class NeighborStatOP:
"""Class for getting neighbor statics data information.
Parameters
----------
ntypes
The num of atom types
rcut
The cut-off radius
distinguish_types : bool, optional
If False, treat all types as a single type.
"""

def __init__(
self,
ntypes: int,
rcut: float,
distinguish_types: bool,
) -> None:
super().__init__()
self.rcut = rcut
self.ntypes = ntypes
self.distinguish_types = distinguish_types

def build(
self,
coord: tf.Tensor,
atype: tf.Tensor,
cell: tf.Tensor,
pbc: tf.Tensor,
) -> Tuple[tf.Tensor, tf.Tensor]:
"""Calculate the nearest neighbor distance between atoms, maximum nbor size of
atoms and the output data range of the environment matrix.
Parameters
----------
coord
The coordinates of atoms.
atype
The atom types.
cell
The cell.
Returns
-------
tf.Tensor
The minimal squared distance between two atoms, in the shape of (nframes,)
tf.Tensor
The maximal number of neighbors
"""
# generated by GitHub Copilot, converted from PT codes
nframes = tf.shape(coord)[0]
coord = tf.reshape(coord, [nframes, -1, 3])
nloc = tf.shape(coord)[1]
coord = tf.reshape(coord, [nframes, nloc * 3])
extend_coord, extend_atype, _ = extend_coord_with_ghosts(
coord, atype, cell, self.rcut, pbc
)

coord1 = tf.reshape(extend_coord, [nframes, -1])
nall = tf.shape(coord1)[1] // 3
coord0 = coord1[:, : nloc * 3]
diff = (
tf.reshape(coord1, [nframes, -1, 3])[:, None, :, :]
- tf.reshape(coord0, [nframes, -1, 3])[:, :, None, :]
)
# shape of diff: nframes, nloc, nall, 3
# remove the diagonal elements
mask = tf.eye(nloc, nall, dtype=tf.bool)
# expand mask
mask = tf.tile(mask[None, :, :], [nframes, 1, 1])
# expand inf
inf_mask = tf.constant(
float("inf"), dtype=GLOBAL_TF_FLOAT_PRECISION, shape=[1, 1, 1]
)
inf_mask = tf.tile(inf_mask, [nframes, nloc, nall])
# virtual type (<0) are not counted
virtual_type_mask_i = tf.tile(tf.less(atype, 0)[:, :, None], [1, 1, nall])
virtual_type_mask_j = tf.tile(
tf.less(extend_atype, 0)[:, None, :], [1, nloc, 1]
)
mask = mask | virtual_type_mask_i | virtual_type_mask_j
rr2 = tf.reduce_sum(tf.square(diff), axis=-1)
rr2 = tf.where(mask, inf_mask, rr2)
min_rr2 = tf.reduce_min(rr2, axis=(1, 2))
# count the number of neighbors
if self.distinguish_types:
mask = rr2 < self.rcut**2
nnei = []
for ii in range(self.ntypes):
nnei.append(
tf.reduce_sum(
tf.cast(
mask & (tf.equal(extend_atype, ii))[:, None, :], tf.int32
),
axis=-1,
)
)
# shape: nframes, nloc, ntypes
nnei = tf.stack(nnei, axis=-1)
else:
mask = rr2 < self.rcut**2
# virtual types (<0) are not counted
nnei = tf.reshape(
tf.reduce_sum(
tf.cast(
mask & tf.greater_equal(extend_atype, 0)[:, None, :], tf.int32
),
axis=-1,
),
[nframes, nloc, 1],
)
# nnei: nframes, nloc, ntypes
# virtual type i (<0) are not counted
nnei = tf.where(
tf.tile(
tf.less(atype, 0)[:, :, None],
[1, 1, self.ntypes if self.distinguish_types else 1],
),
tf.zeros_like(nnei, dtype=tf.int32),
nnei,
)
max_nnei = tf.reduce_max(nnei, axis=1)
return min_rr2, max_nnei


class NeighborStat(BaseNeighborStat):
"""Class for getting training data information.
Expand All @@ -47,52 +180,48 @@ def __init__(
) -> None:
"""Constructor."""
super().__init__(ntypes, rcut, one_type)
sub_graph = tf.Graph()

def builder():
place_holders = {}
for ii in ["coord", "box"]:
place_holders[ii] = tf.placeholder(
GLOBAL_NP_FLOAT_PRECISION, [None, None], name="t_" + ii
)
place_holders["type"] = tf.placeholder(
tf.int32, [None, None], name="t_type"
)
place_holders["natoms_vec"] = tf.placeholder(
tf.int32, [self.ntypes + 2], name="t_natoms"
)
place_holders["default_mesh"] = tf.placeholder(
tf.int32, [None], name="t_mesh"
)
t_type = place_holders["type"]
t_natoms = place_holders["natoms_vec"]
if self.one_type:
# all types = 0, natoms_vec = [natoms, natoms, natoms]
t_type = tf.clip_by_value(t_type, -1, 0)
t_natoms = tf.tile(t_natoms[0:1], [3])

_max_nbor_size, _min_nbor_dist = op_module.neighbor_stat(
place_holders["coord"],
t_type,
t_natoms,
place_holders["box"],
place_holders["default_mesh"],
rcut=self.rcut,
)
place_holders["dir"] = tf.placeholder(tf.string)
_min_nbor_dist = tf.reduce_min(_min_nbor_dist)
_max_nbor_size = tf.reduce_max(_max_nbor_size, axis=0)
return place_holders, (_max_nbor_size, _min_nbor_dist, place_holders["dir"])
self.auto_batch_size = AutoBatchSize()
self.neighbor_stat = NeighborStatOP(ntypes, rcut, not one_type)
self.place_holders = {}
with tf.Graph().as_default() as sub_graph:
self.op = self.build()
self.sub_sess = tf.Session(graph=sub_graph, config=default_tf_session_config)

with sub_graph.as_default():
self.p = ParallelOp(builder, config=default_tf_session_config)
def build(self) -> Tuple[tf.Tensor, tf.Tensor]:
"""Build the graph.
self.sub_sess = tf.Session(graph=sub_graph, config=default_tf_session_config)
Returns
-------
tf.Tensor
The minimal squared distance between two atoms, in the shape of (nframes,)
tf.Tensor
The maximal number of neighbors
"""
for ii in ["coord", "box"]:
self.place_holders[ii] = tf.placeholder(
GLOBAL_NP_FLOAT_PRECISION, [None, None], name="t_" + ii
)
self.place_holders["type"] = tf.placeholder(
tf.int32, [None, None], name="t_type"
)
self.place_holders["pbc"] = tf.placeholder(tf.bool, [], name="t_pbc")
ret = self.neighbor_stat.build(
self.place_holders["coord"],
self.place_holders["type"],
self.place_holders["box"],
self.place_holders["pbc"],
)
return ret

def iterator(
self, data: DeepmdDataSystem
) -> Iterator[Tuple[np.ndarray, float, str]]:
"""Abstract method for producing data.
"""Produce data.
Parameters
----------
data
The data system
Yields
------
Expand All @@ -103,23 +232,46 @@ def iterator(
str
The directory of the data system
"""
for ii in range(len(data.system_dirs)):
for jj in data.data_systems[ii].dirs:
data_set = data.data_systems[ii]
data_set_data = data_set._load_set(jj)
minrr2, max_nnei = self.auto_batch_size.execute_all(
self._execute,
data_set_data["coord"].shape[0],
data_set.get_natoms(),
data_set_data["coord"],
data_set_data["type"],
data_set_data["box"],
data_set.pbc,
)
yield np.max(max_nnei, axis=0), np.min(minrr2), jj

def feed():
for ii in range(len(data.system_dirs)):
for jj in data.data_systems[ii].dirs:
data_set = data.data_systems[ii]._load_set(jj)
for kk in range(np.array(data_set["type"]).shape[0]):
yield {
"coord": np.array(data_set["coord"])[kk].reshape(
[-1, data.natoms[ii] * 3]
),
"type": np.array(data_set["type"])[kk].reshape(
[-1, data.natoms[ii]]
),
"natoms_vec": np.array(data.natoms_vec[ii]),
"box": np.array(data_set["box"])[kk].reshape([-1, 9]),
"default_mesh": np.array(data.default_mesh[ii]),
"dir": str(jj),
}

return self.p.generate(self.sub_sess, feed())
def _execute(
self,
coord: np.ndarray,
atype: np.ndarray,
box: Optional[np.ndarray],
pbc: bool,
):
"""Execute the operation.
Parameters
----------
coord
The coordinates of atoms.
atype
The atom types.
box
The box.
pbc
Whether the box is periodic.
"""
feed_dict = {
self.place_holders["coord"]: coord,
self.place_holders["type"]: atype,
self.place_holders["box"]: box,
self.place_holders["pbc"]: pbc,
}
minrr2, max_nnei = run_sess(self.sub_sess, self.op, feed_dict=feed_dict)
return minrr2, max_nnei
Loading

0 comments on commit 02080db

Please sign in to comment.