Skip to content

Commit

Permalink
Remove data preprocess dependency from data stat
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Feb 11, 2024
1 parent dd1a7a8 commit 0bcd55c
Show file tree
Hide file tree
Showing 18 changed files with 272 additions and 499 deletions.
2 changes: 1 addition & 1 deletion deepmd/pt/model/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def distinguish_types(self) -> bool:
"""Returns if the descriptor requires a neighbor list that distinguish different
atomic types or not.
"""
return False
return self.se_atten.distinguish_types()

@property
def dim_out(self):
Expand Down
8 changes: 8 additions & 0 deletions deepmd/pt/model/descriptor/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,14 @@ def get_dim_in(self) -> int:
def get_dim_emb(self):
return self.dim_emb

def distinguish_types(self) -> bool:
    """Return whether any wrapped descriptor requires a neighbor list that
    distinguishes between different atomic types.

    Each entry of ``self.descriptor_list`` is asked via its own
    ``distinguish_types()`` method; the result is ``True`` as soon as one of
    them answers truthily, and ``False`` otherwise (including when the list
    is empty).
    """
    # `any` over a generator short-circuits on the first positive answer and
    # avoids materializing an intermediate list just for a membership test.
    return any(
        descriptor.distinguish_types() for descriptor in self.descriptor_list
    )

@property
def dim_out(self):
"""Returns the output dimension of this descriptor."""
Expand Down
45 changes: 24 additions & 21 deletions deepmd/pt/model/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
env,
)
from deepmd.pt.utils.nlist import (
build_neighbor_list,
process_input,
)
from deepmd.pt.utils.utils import (
get_activation_fn,
Expand Down Expand Up @@ -178,6 +178,12 @@ def get_dim_emb(self) -> int:
"""Returns the embedding dimension g2."""
return self.g2_dim

def distinguish_types(self) -> bool:
    """Return whether the descriptor requires a neighbor list that
    distinguishes between different atomic types.

    This implementation unconditionally returns ``False``.
    """
    return False

@property
def dim_out(self):
"""Returns the output dimension of this descriptor."""
Expand Down Expand Up @@ -272,44 +278,41 @@ def compute_input_stats(self, merged):
suma2 = []
mixed_type = "real_natoms_vec" in merged[0]
for system in merged:
index = system["mapping"].unsqueeze(-1).expand(-1, -1, 3)
extended_coord = torch.gather(system["coord"], dim=1, index=index)
extended_coord = extended_coord - system["shift"]
index = system["mapping"]
extended_atype = torch.gather(system["atype"], dim=1, index=index)
nloc = system["atype"].shape[-1]
#######################################################
# dirty hack here! the interface of dataload should be
# redesigned to support descriptors like dpa2
#######################################################
nlist = build_neighbor_list(
extended_coord,
extended_atype,
nloc,
self.rcut,
coord, atype, box, natoms = (
system["coord"],
system["atype"],
system["box"],
system["natoms"],
)
extended_coord, extended_atype, mapping, nlist = process_input(
coord,
atype,
self.get_rcut(),
self.get_sel(),
distinguish_types=False,
distinguish_types=self.distinguish_types(),
box=box,
)
env_mat, _, _ = prod_env_mat_se_a(
extended_coord,
nlist,
system["atype"],
atype,
self.mean,
self.stddev,
self.rcut,
self.rcut_smth,
)
if not mixed_type:
sysr, sysr2, sysa, sysa2, sysn = analyze_descrpt(
env_mat.detach().cpu().numpy(), ndescrpt, system["natoms"]
env_mat.detach().cpu().numpy(), ndescrpt, natoms
)
else:
real_natoms_vec = system["real_natoms_vec"]
sysr, sysr2, sysa, sysa2, sysn = analyze_descrpt(
env_mat.detach().cpu().numpy(),
ndescrpt,
system["real_natoms_vec"],
real_natoms_vec,
mixed_type=mixed_type,
real_atype=system["atype"].detach().cpu().numpy(),
real_atype=atype.detach().cpu().numpy(),
)
sumr.append(sysr)
suma.append(sysa)
Expand Down
34 changes: 27 additions & 7 deletions deepmd/pt/model/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@
from deepmd.pt.model.network.network import (
TypeFilter,
)
from deepmd.pt.utils.nlist import (
process_input,
)

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -100,7 +103,7 @@ def distinguish_types(self):
"""Returns if the descriptor requires a neighbor list that distinguish different
atomic types or not.
"""
return True
return self.sea.distinguish_types()

@property
def dim_out(self):
Expand Down Expand Up @@ -347,6 +350,12 @@ def get_dim_in(self) -> int:
"""Returns the input dimension."""
return self.dim_in

def distinguish_types(self) -> bool:
    """Return whether the descriptor requires a neighbor list that
    distinguishes between different atomic types.

    This implementation unconditionally returns ``True``.
    """
    return True

@property
def dim_out(self):
"""Returns the output dimension of this descriptor."""
Expand Down Expand Up @@ -381,20 +390,31 @@ def compute_input_stats(self, merged):
sumr2 = []
suma2 = []
for system in merged:
index = system["mapping"].unsqueeze(-1).expand(-1, -1, 3)
extended_coord = torch.gather(system["coord"], dim=1, index=index)
extended_coord = extended_coord - system["shift"]
coord, atype, box, natoms = (
system["coord"],
system["atype"],
system["box"],
system["natoms"],
)
extended_coord, extended_atype, mapping, nlist = process_input(
coord,
atype,
self.get_rcut(),
self.get_sel(),
distinguish_types=self.distinguish_types(),
box=box,
)
env_mat, _, _ = prod_env_mat_se_a(
extended_coord,
system["nlist"],
system["atype"],
nlist,
atype,
self.mean,
self.stddev,
self.rcut,
self.rcut_smth,
)
sysr, sysr2, sysa, sysa2, sysn = analyze_descrpt(
env_mat.detach().cpu().numpy(), self.ndescrpt, system["natoms"]
env_mat.detach().cpu().numpy(), self.ndescrpt, natoms
)
sumr.append(sysr)
suma.append(sysa)
Expand Down
37 changes: 29 additions & 8 deletions deepmd/pt/model/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
from deepmd.pt.utils import (
env,
)
from deepmd.pt.utils.nlist import (
process_input,
)


@DescriptorBlock.register("se_atten")
Expand Down Expand Up @@ -161,6 +164,12 @@ def get_dim_emb(self) -> int:
"""Returns the output dimension of embedding."""
return self.filter_neuron[-1]

def distinguish_types(self) -> bool:
    """Return whether the descriptor requires a neighbor list that
    distinguishes between different atomic types.

    This implementation unconditionally returns ``False``.
    """
    return False

@property
def dim_out(self):
"""Returns the output dimension of this descriptor."""
Expand All @@ -185,29 +194,41 @@ def compute_input_stats(self, merged):
suma2 = []
mixed_type = "real_natoms_vec" in merged[0]
for system in merged:
index = system["mapping"].unsqueeze(-1).expand(-1, -1, 3)
extended_coord = torch.gather(system["coord"], dim=1, index=index)
extended_coord = extended_coord - system["shift"]
coord, atype, box, natoms = (
system["coord"],
system["atype"],
system["box"],
system["natoms"],
)
extended_coord, extended_atype, mapping, nlist = process_input(
coord,
atype,
self.get_rcut(),
self.get_sel(),
distinguish_types=self.distinguish_types(),
box=box,
)
env_mat, _, _ = prod_env_mat_se_a(
extended_coord,
system["nlist"],
system["atype"],
nlist,
atype,
self.mean,
self.stddev,
self.rcut,
self.rcut_smth,
)
if not mixed_type:
sysr, sysr2, sysa, sysa2, sysn = analyze_descrpt(
env_mat.detach().cpu().numpy(), self.ndescrpt, system["natoms"]
env_mat.detach().cpu().numpy(), self.ndescrpt, natoms
)
else:
real_natoms_vec = system["real_natoms_vec"]
sysr, sysr2, sysa, sysa2, sysn = analyze_descrpt(
env_mat.detach().cpu().numpy(),
self.ndescrpt,
system["real_natoms_vec"],
real_natoms_vec,
mixed_type=mixed_type,
real_atype=system["atype"].detach().cpu().numpy(),
real_atype=atype.detach().cpu().numpy(),
)
sumr.append(sysr)
suma.append(sysa)
Expand Down
26 changes: 5 additions & 21 deletions deepmd/pt/model/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,8 @@
fit_output_to_model_output,
)
from deepmd.pt.utils.nlist import (
build_neighbor_list,
extend_coord_with_ghosts,
nlist_distinguish_types,
)
from deepmd.pt.utils.region import (
normalize_coord,
process_input,
)


Expand Down Expand Up @@ -92,26 +88,14 @@ def forward_common(
The keys are defined by the `ModelOutputDef`.
"""
nframes, nloc = atype.shape[:2]
if box is not None:
coord_normalized = normalize_coord(
coord.view(nframes, nloc, 3),
box.reshape(nframes, 3, 3),
)
else:
coord_normalized = coord.clone()
extended_coord, extended_atype, mapping = extend_coord_with_ghosts(
coord_normalized, atype, box, self.get_rcut()
)
nlist = build_neighbor_list(
extended_coord,
extended_atype,
nloc,
extended_coord, extended_atype, mapping, nlist = process_input(
coord,
atype,
self.get_rcut(),
self.get_sel(),
distinguish_types=self.distinguish_types(),
box=box,
)
extended_coord = extended_coord.view(nframes, -1, 3)
model_predict_lower = self.forward_common_lower(
extended_coord,
extended_atype,
Expand Down
21 changes: 1 addition & 20 deletions deepmd/pt/utils/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,26 +267,7 @@ def collate_batch(batch):
example = batch[0]
result = example.copy()
for key in example.keys():
if key == "shift" or key == "mapping":
natoms_extended = max([d[key].shape[0] for d in batch])
n_frames = len(batch)
list = []
for x in range(n_frames):
list.append(batch[x][key])
if key == "shift":
result[key] = torch.zeros(
(n_frames, natoms_extended, 3),
dtype=env.GLOBAL_PT_FLOAT_PRECISION,
)
else:
result[key] = torch.zeros(
(n_frames, natoms_extended),
dtype=torch.long,
)
for i in range(len(batch)):
natoms_tmp = list[i].shape[0]
result[key][i, :natoms_tmp] = list[i]
elif "find_" in key:
if "find_" in key:
result[key] = batch[0][key]
else:
if batch[0][key] is None:
Expand Down
Loading

0 comments on commit 0bcd55c

Please sign in to comment.