Skip to content

Commit

Permalink
fix neighbor stat mixed_types input (#3304)
Browse files (browse the repository at this point in the history)
The neighbor stat input was broken by #3289, which replaced distinguished
types with mixed types but did not update the input to match.

Add tests for two types.

---------

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz authored Feb 20, 2024
1 parent 33495d8 commit fe14402
Show file tree
Hide file tree
Showing 11 changed files with 71 additions and 57 deletions.
8 changes: 4 additions & 4 deletions deepmd/dpmodel/utils/neighbor_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,18 +115,18 @@ class NeighborStat(BaseNeighborStat):
The num of atom types
rcut : float
The cut-off radius
one_type : bool, optional, default=False
mixed_type : bool, optional, default=False
Treat all types as a single type.
"""

def __init__(
self,
ntypes: int,
rcut: float,
one_type: bool = False,
mixed_type: bool = False,
) -> None:
super().__init__(ntypes, rcut, one_type)
self.op = NeighborStatOP(ntypes, rcut, not one_type)
super().__init__(ntypes, rcut, mixed_type)
self.op = NeighborStatOP(ntypes, rcut, mixed_type)

def iterator(
self, data: DeepmdDataSystem
Expand Down
6 changes: 3 additions & 3 deletions deepmd/entrypoints/neighbor_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def neighbor_stat(
system: str,
rcut: float,
type_map: List[str],
one_type: bool = False,
mixed_type: bool = False,
backend: str = "tensorflow",
**kwargs,
):
Expand All @@ -36,7 +36,7 @@ def neighbor_stat(
cutoff radius
type_map : list[str]
type map
one_type : bool, optional, default=False
mixed_type : bool, optional, default=False
treat all types as a single type
backend : str, optional, default="tensorflow"
backend to use
Expand Down Expand Up @@ -89,7 +89,7 @@ def neighbor_stat(
type_map=type_map,
)
data.get_batch()
nei = NeighborStat(data.get_ntypes(), rcut, one_type=one_type)
nei = NeighborStat(data.get_ntypes(), rcut, mixed_type=mixed_type)
min_nbor_dist, max_nbor_size = nei.get_stat(data)
log.info("min_nbor_dist: %f" % min_nbor_dist)
log.info("max_nbor_size: %s" % str(max_nbor_size))
Expand Down
1 change: 1 addition & 0 deletions deepmd/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,7 @@ def main_parser() -> argparse.ArgumentParser:
help="type map",
)
parser_neighbor_stat.add_argument(
"--mixed-type",
"--one-type",
action="store_true",
default=False,
Expand Down
14 changes: 8 additions & 6 deletions deepmd/pt/utils/neighbor_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,16 @@ def forward(
)
assert list(diff.shape) == [nframes, nloc, nall, 3]
# remove the diagonal elements
mask = torch.eye(nloc, nall, dtype=torch.bool)
mask = torch.eye(nloc, nall, dtype=torch.bool, device=diff.device)
diff[:, mask] = torch.inf
rr2 = torch.sum(torch.square(diff), dim=-1)
min_rr2, _ = torch.min(rr2, dim=-1)
# count the number of neighbors
if not self.mixed_types:
mask = rr2 < self.rcut**2
nnei = torch.zeros((nframes, nloc, self.ntypes), dtype=torch.int32)
nnei = torch.zeros(
(nframes, nloc, self.ntypes), dtype=torch.int32, device=mask.device
)
for ii in range(self.ntypes):
nnei[:, :, ii] = torch.sum(
mask & extend_atype.eq(ii)[:, None, :], dim=-1
Expand All @@ -120,18 +122,18 @@ class NeighborStat(BaseNeighborStat):
The num of atom types
rcut : float
The cut-off radius
one_type : bool, optional, default=False
mixed_type : bool, optional, default=False
Treat all types as a single type.
"""

def __init__(
self,
ntypes: int,
rcut: float,
one_type: bool = False,
mixed_type: bool = False,
) -> None:
super().__init__(ntypes, rcut, one_type)
op = NeighborStatOP(ntypes, rcut, not one_type)
super().__init__(ntypes, rcut, mixed_type)
op = NeighborStatOP(ntypes, rcut, mixed_type)
self.op = torch.jit.script(op)
self.auto_batch_size = AutoBatchSize()

Expand Down
14 changes: 7 additions & 7 deletions deepmd/tf/entrypoints/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ def get_type_map(jdata):
return jdata["model"].get("type_map", None)


def get_nbor_stat(jdata, rcut, one_type: bool = False):
def get_nbor_stat(jdata, rcut, mixed_type: bool = False):
# it seems that DeepmdDataSystem does not need rcut
# it's not clear why there is an argument...
# max_rcut = get_rcut(jdata)
Expand Down Expand Up @@ -414,7 +414,7 @@ def get_nbor_stat(jdata, rcut, one_type: bool = False):
map_ntypes = data_ntypes
ntypes = max([map_ntypes, data_ntypes])

neistat = NeighborStat(ntypes, rcut, one_type=one_type)
neistat = NeighborStat(ntypes, rcut, mixed_type=mixed_type)

min_nbor_dist, max_nbor_size = neistat.get_stat(train_data)

Expand All @@ -430,8 +430,8 @@ def get_nbor_stat(jdata, rcut, one_type: bool = False):
return min_nbor_dist, max_nbor_size


def get_sel(jdata, rcut, one_type: bool = False):
_, max_nbor_size = get_nbor_stat(jdata, rcut, one_type=one_type)
def get_sel(jdata, rcut, mixed_type: bool = False):
_, max_nbor_size = get_nbor_stat(jdata, rcut, mixed_type=mixed_type)
return max_nbor_size


Expand Down Expand Up @@ -468,12 +468,12 @@ def wrap_up_4(xx):
return 4 * ((int(xx) + 3) // 4)


def update_one_sel(jdata, descriptor, one_type: bool = False):
def update_one_sel(jdata, descriptor, mixed_type: bool = False):
rcut = descriptor["rcut"]
tmp_sel = get_sel(
jdata,
rcut,
one_type=one_type,
mixed_type=mixed_type,
)
sel = descriptor["sel"]
if isinstance(sel, int):
Expand All @@ -493,7 +493,7 @@ def update_one_sel(jdata, descriptor, one_type: bool = False):
"not less than %d, but you set it to %d. The accuracy"
" of your model may get worse." % (ii, tt, dd)
)
if one_type:
if mixed_type:
descriptor["sel"] = sel = sum(sel)
return descriptor

Expand Down
8 changes: 4 additions & 4 deletions deepmd/tf/utils/neighbor_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,20 +168,20 @@ class NeighborStat(BaseNeighborStat):
The num of atom types
rcut
The cut-off radius
one_type : bool, optional, default=False
mixed_type : bool, optional, default=False
Treat all types as a single type.
"""

def __init__(
self,
ntypes: int,
rcut: float,
one_type: bool = False,
mixed_type: bool = False,
) -> None:
"""Constructor."""
super().__init__(ntypes, rcut, one_type)
super().__init__(ntypes, rcut, mixed_type)
self.auto_batch_size = AutoBatchSize()
self.neighbor_stat = NeighborStatOP(ntypes, rcut, not one_type)
self.neighbor_stat = NeighborStatOP(ntypes, rcut, mixed_type)
self.place_holders = {}
with tf.Graph().as_default() as sub_graph:
self.op = self.build()
Expand Down
8 changes: 4 additions & 4 deletions deepmd/utils/neighbor_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,19 @@ class NeighborStat(ABC):
The num of atom types
rcut : float
The cut-off radius
one_type : bool, optional, default=False
mixed_type : bool, optional, default=False
Treat all types as a single type.
"""

def __init__(
self,
ntypes: int,
rcut: float,
one_type: bool = False,
mixed_type: bool = False,
) -> None:
self.rcut = rcut
self.ntypes = ntypes
self.one_type = one_type
self.mixed_type = mixed_type

def get_stat(self, data: DeepmdDataSystem) -> Tuple[float, np.ndarray]:
"""Get the data statistics of the training data, including nearest nbor distance between atoms, max nbor size of atoms.
Expand All @@ -62,7 +62,7 @@ def get_stat(self, data: DeepmdDataSystem) -> Tuple[float, np.ndarray]:
An array with ntypes integers, denotes the actual achieved max sel
"""
min_nbor_dist = 100.0
max_nbor_size = np.zeros(1 if self.one_type else self.ntypes, dtype=int)
max_nbor_size = np.zeros(1 if self.mixed_type else self.ntypes, dtype=int)

for mn, dt, jj in self.iterator(data):
if np.isinf(dt):
Expand Down
13 changes: 8 additions & 5 deletions source/tests/common/dpmodel/test_neighbor_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,14 @@ def tearDown(self):

def test_neighbor_stat(self):
for rcut in (0.0, 1.0, 2.0, 4.0):
for one_type in (True, False):
with self.subTest(rcut=rcut, one_type=one_type):
for mixed_type in (True, False):
with self.subTest(rcut=rcut, mixed_type=mixed_type):
rcut += 1e-3 # prevent numerical errors
min_nbor_dist, max_nbor_size = neighbor_stat(
system="system_0",
rcut=rcut,
type_map=["TYPE"],
one_type=one_type,
type_map=["TYPE", "NO_THIS_TYPE"],
mixed_type=mixed_type,
backend="numpy",
)
upper = np.ceil(rcut) + 1
Expand All @@ -58,4 +58,7 @@ def test_neighbor_stat(self):
np.logical_and(distance > 0, distance <= rcut)
)
self.assertAlmostEqual(min_nbor_dist, 1.0, 6)
self.assertEqual(max_nbor_size, [expected_neighbors])
ret = [expected_neighbors]
if not mixed_type:
ret.append(0)
np.testing.assert_array_equal(max_nbor_size, ret)
13 changes: 8 additions & 5 deletions source/tests/pt/test_neighbor_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,14 @@ def tearDown(self):

def test_neighbor_stat(self):
for rcut in (0.0, 1.0, 2.0, 4.0):
for one_type in (True, False):
with self.subTest(rcut=rcut, one_type=one_type):
for mixed_type in (True, False):
with self.subTest(rcut=rcut, mixed_type=mixed_type):
rcut += 1e-3 # prevent numerical errors
min_nbor_dist, max_nbor_size = neighbor_stat(
system="system_0",
rcut=rcut,
type_map=["TYPE"],
one_type=one_type,
type_map=["TYPE", "NO_THIS_TYPE"],
mixed_type=mixed_type,
backend="pytorch",
)
upper = np.ceil(rcut) + 1
Expand All @@ -58,4 +58,7 @@ def test_neighbor_stat(self):
np.logical_and(distance > 0, distance <= rcut)
)
self.assertAlmostEqual(min_nbor_dist, 1.0, 6)
self.assertEqual(max_nbor_size, [expected_neighbors])
ret = [expected_neighbors]
if not mixed_type:
ret.append(0)
np.testing.assert_array_equal(max_nbor_size, ret)
41 changes: 23 additions & 18 deletions source/tests/tf/test_neighbor_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,21 +38,26 @@ def tearDown(self):
shutil.rmtree("system_0")

def test_neighbor_stat(self):
# set rcut to 0. will cause a core dumped
# TODO: check what is wrong
for rcut in (1.0, 2.0, 4.0):
with self.subTest():
rcut += 1e-3 # prevent numerical errors
min_nbor_dist, max_nbor_size = neighbor_stat(
system="system_0", rcut=rcut, type_map=["TYPE"]
)
upper = np.ceil(rcut) + 1
X, Y, Z = np.mgrid[-upper:upper, -upper:upper, -upper:upper]
positions = np.vstack([X.ravel(), Y.ravel(), Z.ravel()]).T
# distance to (0,0,0)
distance = np.linalg.norm(positions, axis=1)
expected_neighbors = np.count_nonzero(
np.logical_and(distance > 0, distance <= rcut)
)
self.assertAlmostEqual(min_nbor_dist, 1.0, 6)
self.assertEqual(max_nbor_size, [expected_neighbors])
for rcut in (0.0, 1.0, 2.0, 4.0):
for mixed_type in (True, False):
with self.subTest(rcut=rcut, mixed_type=mixed_type):
rcut += 1e-3 # prevent numerical errors
min_nbor_dist, max_nbor_size = neighbor_stat(
system="system_0",
rcut=rcut,
type_map=["TYPE", "NO_THIS_TYPE"],
mixed_type=mixed_type,
)
upper = np.ceil(rcut) + 1
X, Y, Z = np.mgrid[-upper:upper, -upper:upper, -upper:upper]
positions = np.vstack([X.ravel(), Y.ravel(), Z.ravel()]).T
# distance to (0,0,0)
distance = np.linalg.norm(positions, axis=1)
expected_neighbors = np.count_nonzero(
np.logical_and(distance > 0, distance <= rcut)
)
self.assertAlmostEqual(min_nbor_dist, 1.0, 6)
ret = [expected_neighbors]
if not mixed_type:
ret.append(0)
np.testing.assert_array_equal(max_nbor_size, ret)
2 changes: 1 addition & 1 deletion source/tests/tf/test_virtual_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,5 +138,5 @@ def test_data_mixed_type(self):
data = DeepmdDataSystem(systems, batch_size, test_size, rcut, type_map=type_map)
data.get_batch()
# neighbor stat
nei_stat = NeighborStat(len(type_map), rcut, one_type=True)
nei_stat = NeighborStat(len(type_map), rcut, mixed_type=True)
min_nbor_dist, max_nbor_size = nei_stat.get_stat(data)

0 comments on commit fe14402

Please sign in to comment.