Skip to content

Commit

Permalink
Merge branch 'fix_gpu_ut' of https://github.com/iProzd/deepmd-kit int…
Browse files Browse the repository at this point in the history
…o fix_gpu_ut
  • Loading branch information
iProzd committed Jan 30, 2024
2 parents a4892b7 + 913efa0 commit 06d2579
Showing 5 changed files with 30 additions and 44 deletions.
33 changes: 9 additions & 24 deletions deepmd/pt/utils/dataset.py
Original file line number Diff line number Diff line change
@@ -477,10 +477,7 @@ def preprocess(self, batch):
if "find_" in kk:
pass
else:
batch[kk] = torch.tensor(
batch[kk],
dtype=env.GLOBAL_PT_FLOAT_PRECISION
)
batch[kk] = torch.tensor(batch[kk], dtype=env.GLOBAL_PT_FLOAT_PRECISION)
if self._data_dict[kk]["atomic"]:
batch[kk] = batch[kk].view(
n_frames, -1, self._data_dict[kk]["ndof"]
@@ -521,12 +518,9 @@ def preprocess(self, batch):
batch["nlist_type"] = nlist_type
natoms_extended = max([item.shape[0] for item in shift])
batch["shift"] = torch.zeros(
(n_frames, natoms_extended, 3),
dtype=env.GLOBAL_PT_FLOAT_PRECISION
)
batch["mapping"] = torch.zeros(
(n_frames, natoms_extended), dtype=torch.long
(n_frames, natoms_extended, 3), dtype=env.GLOBAL_PT_FLOAT_PRECISION
)
batch["mapping"] = torch.zeros((n_frames, natoms_extended), dtype=torch.long)
for i in range(len(shift)):
natoms_tmp = shift[i].shape[0]
batch["shift"][i, :natoms_tmp] = shift[i]
@@ -562,16 +556,13 @@ def single_preprocess(self, batch, sid):
pass
else:
batch[kk] = torch.tensor(
batch[kk][sid],
dtype=env.GLOBAL_PT_FLOAT_PRECISION
batch[kk][sid], dtype=env.GLOBAL_PT_FLOAT_PRECISION
)
if self._data_dict[kk]["atomic"]:
batch[kk] = batch[kk].view(-1, self._data_dict[kk]["ndof"])
for kk in ["type", "real_natoms_vec"]:
if kk in batch.keys():
batch[kk] = torch.tensor(
batch[kk][sid], dtype=torch.long
)
batch[kk] = torch.tensor(batch[kk][sid], dtype=torch.long)
clean_coord = batch.pop("coord")
clean_type = batch.pop("type")
nloc = clean_type.shape[0]
@@ -665,28 +656,22 @@ def single_preprocess(self, batch, sid):
NotImplementedError(f"Unknown noise type {self.noise_type}!")
noised_coord = _clean_coord.clone().detach()
noised_coord[coord_mask] += noise_on_coord
batch["coord_mask"] = torch.tensor(
coord_mask, dtype=torch.bool
)
batch["coord_mask"] = torch.tensor(coord_mask, dtype=torch.bool)
else:
noised_coord = _clean_coord
batch["coord_mask"] = torch.tensor(
np.zeros_like(coord_mask, dtype=bool),
dtype=torch.bool
np.zeros_like(coord_mask, dtype=bool), dtype=torch.bool
)

# add mask for type
if self.mask_type:
masked_type = clean_type.clone().detach()
masked_type[type_mask] = self.mask_type_idx
batch["type_mask"] = torch.tensor(
type_mask, dtype=torch.bool
)
batch["type_mask"] = torch.tensor(type_mask, dtype=torch.bool)
else:
masked_type = clean_type
batch["type_mask"] = torch.tensor(
np.zeros_like(type_mask, dtype=bool),
dtype=torch.bool
np.zeros_like(type_mask, dtype=bool), dtype=torch.bool
)
if self.pbc:
_coord = normalize_coord(noised_coord, region, nloc)
16 changes: 4 additions & 12 deletions deepmd/pt/utils/preprocess.py
Original file line number Diff line number Diff line change
@@ -134,15 +134,9 @@ def append_neighbors(coord, region: Region3D, atype, rcut: float):
xi = torch.arange(-ngcell[0], ncell[0] + ngcell[0], 1)
yi = torch.arange(-ngcell[1], ncell[1] + ngcell[1], 1)
zi = torch.arange(-ngcell[2], ncell[2] + ngcell[2], 1)
xyz = xi.view(-1, 1, 1, 1) * torch.tensor(
[1, 0, 0], dtype=torch.long
)
xyz = xyz + yi.view(1, -1, 1, 1) * torch.tensor(
[0, 1, 0], dtype=torch.long
)
xyz = xyz + zi.view(1, 1, -1, 1) * torch.tensor(
[0, 0, 1], dtype=torch.long
)
xyz = xi.view(-1, 1, 1, 1) * torch.tensor([1, 0, 0], dtype=torch.long)
xyz = xyz + yi.view(1, -1, 1, 1) * torch.tensor([0, 1, 0], dtype=torch.long)
xyz = xyz + zi.view(1, 1, -1, 1) * torch.tensor([0, 0, 1], dtype=torch.long)
xyz = xyz.view(-1, 3)
mask_a = (xyz >= 0).all(dim=-1)
mask_b = (xyz < ncell).all(dim=-1)
@@ -186,9 +180,7 @@ def build_neighbor_list(
distance = coord_l - coord_r
distance = torch.linalg.norm(distance, dim=-1)
DISTANCE_INF = distance.max().detach() + rcut
distance[:nloc, :nloc] += (
torch.eye(nloc, dtype=torch.bool) * DISTANCE_INF
)
distance[:nloc, :nloc] += torch.eye(nloc, dtype=torch.bool) * DISTANCE_INF
if min_check:
if distance.min().abs() < 1e-6:
RuntimeError("Atom dist too close!")
5 changes: 1 addition & 4 deletions deepmd/pt/utils/stat.py
Original file line number Diff line number Diff line change
@@ -64,10 +64,7 @@ def make_stat_input(datasets, dataloaders, nbatches):
dtype=env.GLOBAL_PT_FLOAT_PRECISION,
)
else:
shape = torch.zeros(
(n_frames, extend),
dtype=torch.long
)
shape = torch.zeros((n_frames, extend), dtype=torch.long)
for i in range(len(item)):
natoms_tmp = l[i].shape[0]
shape[i, :natoms_tmp] = l[i]
9 changes: 7 additions & 2 deletions source/tests/pt/test_embedding_net.py
Original file line number Diff line number Diff line change
@@ -7,6 +7,7 @@
import numpy as np
import tensorflow.compat.v1 as tf
import torch

from deepmd.pt.utils import (
env,
)
@@ -153,11 +154,15 @@ def test_consistency(self):

pt_coord = self.torch_batch["coord"].to(env.DEVICE)
pt_coord.requires_grad_(True)
index = self.torch_batch["mapping"].unsqueeze(-1).expand(-1, -1, 3).to(env.DEVICE)
index = (
self.torch_batch["mapping"].unsqueeze(-1).expand(-1, -1, 3).to(env.DEVICE)
)
extended_coord = torch.gather(pt_coord, dim=1, index=index)
extended_coord = extended_coord - self.torch_batch["shift"].to(env.DEVICE)
extended_atype = torch.gather(
self.torch_batch["atype"].to(env.DEVICE), dim=1, index=self.torch_batch["mapping"].to(env.DEVICE)
self.torch_batch["atype"].to(env.DEVICE),
dim=1,
index=self.torch_batch["mapping"].to(env.DEVICE),
)
descriptor_out, _, _, _, _ = descriptor(
extended_coord,
11 changes: 9 additions & 2 deletions source/tests/pt/test_model.py
Original file line number Diff line number Diff line change
@@ -6,6 +6,7 @@
import numpy as np
import tensorflow.compat.v1 as tf
import torch

from deepmd.pt.utils import (
env,
)
@@ -343,10 +344,16 @@ def test_consistency(self):
batch["natoms_vec"], device=batch["coord"].device
).unsqueeze(0)
model_predict = my_model(
batch["coord"].to(env.DEVICE), batch["atype"].to(env.DEVICE), batch["box"].to(env.DEVICE), do_atomic_virial=True
batch["coord"].to(env.DEVICE),
batch["atype"].to(env.DEVICE),
batch["box"].to(env.DEVICE),
do_atomic_virial=True,
)
model_predict_1 = my_model(
batch["coord"].to(env.DEVICE), batch["atype"].to(env.DEVICE), batch["box"].to(env.DEVICE), do_atomic_virial=False
batch["coord"].to(env.DEVICE),
batch["atype"].to(env.DEVICE),
batch["box"].to(env.DEVICE),
do_atomic_virial=False,
)
p_energy, p_force, p_virial, p_atomic_virial = (
model_predict["energy"],

0 comments on commit 06d2579

Please sign in to comment.