Skip to content

Commit

Permalink
Fix GPU UTs
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Jan 30, 2024
1 parent 4a29c8c commit fdbccab
Showing 8 changed files with 82 additions and 76 deletions.
4 changes: 2 additions & 2 deletions deepmd/pt/utils/dataloader.py
Original file line number Diff line number Diff line change
@@ -276,13 +276,13 @@ def collate_batch(batch):
result[key] = torch.zeros(
(n_frames, natoms_extended, 3),
dtype=env.GLOBAL_PT_FLOAT_PRECISION,
device=env.PREPROCESS_DEVICE,
device=env.DEVICE,
)
else:
result[key] = torch.zeros(
(n_frames, natoms_extended),
dtype=torch.long,
device=env.PREPROCESS_DEVICE,
device=env.DEVICE,
)
for i in range(len(batch)):
natoms_tmp = list[i].shape[0]
26 changes: 13 additions & 13 deletions deepmd/pt/utils/dataset.py
Original file line number Diff line number Diff line change
@@ -480,7 +480,7 @@ def preprocess(self, batch):
batch[kk] = torch.tensor(
batch[kk],
dtype=env.GLOBAL_PT_FLOAT_PRECISION,
device=env.PREPROCESS_DEVICE,
device=env.DEVICE,
)
if self._data_dict[kk]["atomic"]:
batch[kk] = batch[kk].view(
@@ -490,7 +490,7 @@ def preprocess(self, batch):
for kk in ["type", "real_natoms_vec"]:
if kk in batch.keys():
batch[kk] = torch.tensor(
batch[kk], dtype=torch.long, device=env.PREPROCESS_DEVICE
batch[kk], dtype=torch.long, device=env.DEVICE
)
batch["atype"] = batch.pop("type")

@@ -526,10 +526,10 @@ def preprocess(self, batch):
batch["shift"] = torch.zeros(
(n_frames, natoms_extended, 3),
dtype=env.GLOBAL_PT_FLOAT_PRECISION,
device=env.PREPROCESS_DEVICE,
device=env.DEVICE,
)
batch["mapping"] = torch.zeros(
(n_frames, natoms_extended), dtype=torch.long, device=env.PREPROCESS_DEVICE
(n_frames, natoms_extended), dtype=torch.long, device=env.DEVICE
)
for i in range(len(shift)):
natoms_tmp = shift[i].shape[0]
@@ -568,14 +568,14 @@ def single_preprocess(self, batch, sid):
batch[kk] = torch.tensor(
batch[kk][sid],
dtype=env.GLOBAL_PT_FLOAT_PRECISION,
device=env.PREPROCESS_DEVICE,
device=env.DEVICE,
)
if self._data_dict[kk]["atomic"]:
batch[kk] = batch[kk].view(-1, self._data_dict[kk]["ndof"])
for kk in ["type", "real_natoms_vec"]:
if kk in batch.keys():
batch[kk] = torch.tensor(
batch[kk][sid], dtype=torch.long, device=env.PREPROCESS_DEVICE
batch[kk][sid], dtype=torch.long, device=env.DEVICE
)
clean_coord = batch.pop("coord")
clean_type = batch.pop("type")
@@ -671,29 +671,29 @@ def single_preprocess(self, batch, sid):
noised_coord = _clean_coord.clone().detach()
noised_coord[coord_mask] += noise_on_coord
batch["coord_mask"] = torch.tensor(
coord_mask, dtype=torch.bool, device=env.PREPROCESS_DEVICE
coord_mask, dtype=torch.bool, device=env.DEVICE
)
else:
noised_coord = _clean_coord
batch["coord_mask"] = torch.tensor(
np.zeros_like(coord_mask, dtype=bool),
dtype=torch.bool,
device=env.PREPROCESS_DEVICE,
device=env.DEVICE,
)

# add mask for type
if self.mask_type:
masked_type = clean_type.clone().detach()
masked_type[type_mask] = self.mask_type_idx
batch["type_mask"] = torch.tensor(
type_mask, dtype=torch.bool, device=env.PREPROCESS_DEVICE
type_mask, dtype=torch.bool, device=env.DEVICE
)
else:
masked_type = clean_type
batch["type_mask"] = torch.tensor(
np.zeros_like(type_mask, dtype=bool),
dtype=torch.bool,
device=env.PREPROCESS_DEVICE,
device=env.DEVICE,
)
if self.pbc:
_coord = normalize_coord(noised_coord, region, nloc)
@@ -803,7 +803,7 @@ def __len__(self):
def __getitem__(self, index):
"""Get a frame from the selected system."""
b_data = self._data_system._get_item(index)
b_data["natoms"] = torch.tensor(self._natoms_vec, device=env.PREPROCESS_DEVICE)
b_data["natoms"] = torch.tensor(self._natoms_vec, device=env.DEVICE)
return b_data


@@ -879,7 +879,7 @@ def __getitem__(self, index=None):
index = dp_random.choice(np.arange(self.nsystems), p=self.probs)
b_data = self._data_systems[index].get_batch(self._batch_size)
b_data["natoms"] = torch.tensor(
self._natoms_vec[index], device=env.PREPROCESS_DEVICE
self._natoms_vec[index], device=env.DEVICE
)
batch_size = b_data["coord"].shape[0]
b_data["natoms"] = b_data["natoms"].unsqueeze(0).expand(batch_size, -1)
@@ -892,7 +892,7 @@ def get_training_batch(self, index=None):
index = dp_random.choice(np.arange(self.nsystems), p=self.probs)
b_data = self._data_systems[index].get_batch_for_train(self._batch_size)
b_data["natoms"] = torch.tensor(
self._natoms_vec[index], device=env.PREPROCESS_DEVICE
self._natoms_vec[index], device=env.DEVICE
)
batch_size = b_data["coord"].shape[0]
b_data["natoms"] = b_data["natoms"].unsqueeze(0).expand(batch_size, -1)
30 changes: 15 additions & 15 deletions deepmd/pt/utils/preprocess.py
Original file line number Diff line number Diff line change
@@ -99,7 +99,7 @@ def build_inside_clist(coord, region: Region3D, ncell):
cell_offset[cell_offset < 0] = 0
delta = cell_offset - ncell
a2c = compute_serial_cid(cell_offset, ncell) # cell id of atoms
arange = torch.arange(0, loc_ncell, 1, device=env.PREPROCESS_DEVICE)
arange = torch.arange(0, loc_ncell, 1, device=env.DEVICE)
cellid = a2c == arange.unsqueeze(-1) # one hot cellid
c2a = cellid.nonzero()
lst = []
@@ -131,17 +131,17 @@ def append_neighbors(coord, region: Region3D, atype, rcut: float):

# add ghost atoms
a2c, c2a = build_inside_clist(coord, region, ncell)
xi = torch.arange(-ngcell[0], ncell[0] + ngcell[0], 1, device=env.PREPROCESS_DEVICE)
yi = torch.arange(-ngcell[1], ncell[1] + ngcell[1], 1, device=env.PREPROCESS_DEVICE)
zi = torch.arange(-ngcell[2], ncell[2] + ngcell[2], 1, device=env.PREPROCESS_DEVICE)
xi = torch.arange(-ngcell[0], ncell[0] + ngcell[0], 1, device=env.DEVICE)
yi = torch.arange(-ngcell[1], ncell[1] + ngcell[1], 1, device=env.DEVICE)
zi = torch.arange(-ngcell[2], ncell[2] + ngcell[2], 1, device=env.DEVICE)
xyz = xi.view(-1, 1, 1, 1) * torch.tensor(
[1, 0, 0], dtype=torch.long, device=env.PREPROCESS_DEVICE
[1, 0, 0], dtype=torch.long, device=env.DEVICE
)
xyz = xyz + yi.view(1, -1, 1, 1) * torch.tensor(
[0, 1, 0], dtype=torch.long, device=env.PREPROCESS_DEVICE
[0, 1, 0], dtype=torch.long, device=env.DEVICE
)
xyz = xyz + zi.view(1, 1, -1, 1) * torch.tensor(
[0, 0, 1], dtype=torch.long, device=env.PREPROCESS_DEVICE
[0, 0, 1], dtype=torch.long, device=env.DEVICE
)
xyz = xyz.view(-1, 3)
mask_a = (xyz >= 0).all(dim=-1)
@@ -166,7 +166,7 @@ def append_neighbors(coord, region: Region3D, atype, rcut: float):
merged_coord_shift = torch.cat([torch.zeros_like(coord), coord_shift[tmp]])
merged_atype = torch.cat([atype, tmp_atype])
merged_mapping = torch.cat(
[torch.arange(atype.numel(), device=env.PREPROCESS_DEVICE), aid]
[torch.arange(atype.numel(), device=env.DEVICE), aid]
)
return merged_coord_shift, merged_atype, merged_mapping

@@ -189,20 +189,20 @@ def build_neighbor_list(
distance = torch.linalg.norm(distance, dim=-1)
DISTANCE_INF = distance.max().detach() + rcut
distance[:nloc, :nloc] += (
torch.eye(nloc, dtype=torch.bool, device=env.PREPROCESS_DEVICE) * DISTANCE_INF
torch.eye(nloc, dtype=torch.bool, device=env.DEVICE) * DISTANCE_INF
)
if min_check:
if distance.min().abs() < 1e-6:
RuntimeError("Atom dist too close!")
if not type_split:
sec = sec[-1:]
lst = []
nlist = torch.zeros((nloc, sec[-1].item()), device=env.PREPROCESS_DEVICE).long() - 1
nlist = torch.zeros((nloc, sec[-1].item()), device=env.DEVICE).long() - 1
nlist_loc = (
torch.zeros((nloc, sec[-1].item()), device=env.PREPROCESS_DEVICE).long() - 1
torch.zeros((nloc, sec[-1].item()), device=env.DEVICE).long() - 1
)
nlist_type = (
torch.zeros((nloc, sec[-1].item()), device=env.PREPROCESS_DEVICE).long() - 1
torch.zeros((nloc, sec[-1].item()), device=env.DEVICE).long() - 1
)
for i, nnei in enumerate(sec):
if i > 0:
@@ -216,9 +216,9 @@ def build_neighbor_list(
_sorted, indices = torch.topk(tmp, nnei, dim=1, largest=False)
else:
# when nnei > nall
indices = torch.zeros((nloc, nnei), device=env.PREPROCESS_DEVICE).long() - 1
indices = torch.zeros((nloc, nnei), device=env.DEVICE).long() - 1
_sorted = (
torch.ones((nloc, nnei), device=env.PREPROCESS_DEVICE).long()
torch.ones((nloc, nnei), device=env.DEVICE).long()
* DISTANCE_INF
)
_sorted_nnei, indices_nnei = torch.topk(
@@ -284,7 +284,7 @@ def make_env_mat(
else:
merged_coord_shift = torch.zeros_like(coord)
merged_atype = atype.clone()
merged_mapping = torch.arange(atype.numel(), device=env.PREPROCESS_DEVICE)
merged_mapping = torch.arange(atype.numel(), device=env.DEVICE)
merged_coord = coord.clone()

# build nlist
4 changes: 2 additions & 2 deletions deepmd/pt/utils/stat.py
Original file line number Diff line number Diff line change
@@ -62,13 +62,13 @@ def make_stat_input(datasets, dataloaders, nbatches):
shape = torch.zeros(
(n_frames, extend, 3),
dtype=env.GLOBAL_PT_FLOAT_PRECISION,
device=env.PREPROCESS_DEVICE,
device=env.DEVICE,
)
else:
shape = torch.zeros(
(n_frames, extend),
dtype=torch.long,
device=env.PREPROCESS_DEVICE,
device=env.DEVICE,
)
for i in range(len(item)):
natoms_tmp = l[i].shape[0]
11 changes: 7 additions & 4 deletions source/tests/pt/test_descriptor.py
Original file line number Diff line number Diff line change
@@ -12,6 +12,9 @@
from pathlib import (
Path,
)
from deepmd.pt.utils import (
env,
)

from deepmd.pt.model.descriptor import (
prod_env_mat_se_a,
@@ -112,18 +115,18 @@ def setUp(self):

def test_consistency(self):
avg_zero = torch.zeros(
[self.ntypes, self.nnei * 4], dtype=GLOBAL_PT_FLOAT_PRECISION
[self.ntypes, self.nnei * 4], dtype=GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
)
std_ones = torch.ones(
[self.ntypes, self.nnei * 4], dtype=GLOBAL_PT_FLOAT_PRECISION
[self.ntypes, self.nnei * 4], dtype=GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
)
base_d, base_force, nlist = base_se_a(
rcut=self.rcut,
rcut_smth=self.rcut_smth,
sel=self.sel,
batch=self.np_batch,
mean=avg_zero,
stddev=std_ones,
mean=avg_zero.detach().cpu(),
stddev=std_ones.detach().cpu(),
)

pt_coord = self.pt_batch["coord"]
8 changes: 4 additions & 4 deletions source/tests/pt/test_descriptor_dpa1.py
Original file line number Diff line number Diff line change
@@ -243,7 +243,7 @@ def test_descriptor_block(self):
dparams["ntypes"] = ntypes
des = DescrptBlockSeAtten(
**dparams,
)
).to(env.DEVICE)
des.load_state_dict(torch.load(self.file_model_param))
rcut = dparams["rcut"]
nsel = dparams["sel"]
@@ -260,7 +260,7 @@ def test_descriptor_block(self):
extended_coord, extended_atype, nloc, rcut, nsel, distinguish_types=False
)
# handle type_embedding
type_embedding = TypeEmbedNet(ntypes, 8)
type_embedding = TypeEmbedNet(ntypes, 8).to(env.DEVICE)
type_embedding.load_state_dict(torch.load(self.file_type_embed))

## to save model parameters
@@ -293,7 +293,7 @@ def test_descriptor(self):
dparams["concat_output_tebd"] = False
des = DescrptDPA1(
**dparams,
)
).to(env.DEVICE)
target_dict = des.state_dict()
source_dict = torch.load(self.file_model_param)
type_embd_dict = torch.load(self.file_type_embed)
@@ -337,7 +337,7 @@ def test_descriptor(self):
dparams["concat_output_tebd"] = True
des = DescrptDPA1(
**dparams,
)
).to(env.DEVICE)
descriptor, env_mat, diff, rot_mat, sw = des(
extended_coord,
extended_atype,
8 changes: 4 additions & 4 deletions source/tests/pt/test_descriptor_dpa2.py
Original file line number Diff line number Diff line change
@@ -124,7 +124,7 @@ def test_descriptor_hyb(self):
dlist,
ntypes,
hybrid_mode=dparams["hybrid_mode"],
)
).to(env.DEVICE)
model_dict = torch.load(self.file_model_param)
# type_embd of repformer is removed
model_dict.pop("descriptor_list.1.type_embd.embedding.weight")
@@ -158,7 +158,7 @@ def test_descriptor_hyb(self):
)
nlist = torch.cat(nlist_list, -1)
# handle type_embedding
type_embedding = TypeEmbedNet(ntypes, 8)
type_embedding = TypeEmbedNet(ntypes, 8).to(env.DEVICE)
type_embedding.load_state_dict(torch.load(self.file_type_embed))

## to save model parameters
@@ -186,7 +186,7 @@ def test_descriptor(self):
dparams["concat_output_tebd"] = False
des = DescrptDPA2(
**dparams,
)
).to(env.DEVICE)
target_dict = des.state_dict()
source_dict = torch.load(self.file_model_param)
# type_embd of repformer is removed
@@ -232,7 +232,7 @@ def test_descriptor(self):
dparams["concat_output_tebd"] = True
des = DescrptDPA2(
**dparams,
)
).to(env.DEVICE)
descriptor, env_mat, diff, rot_mat, sw = des(
extended_coord,
extended_atype,
Loading

0 comments on commit fdbccab

Please sign in to comment.