diff --git a/deepmd/dpmodel/descriptor/se_t.py b/deepmd/dpmodel/descriptor/se_t.py index d9b2741b3e..eac6a9640e 100644 --- a/deepmd/dpmodel/descriptor/se_t.py +++ b/deepmd/dpmodel/descriptor/se_t.py @@ -9,6 +9,9 @@ from deepmd.env import ( GLOBAL_NP_FLOAT_PRECISION, ) +from deepmd.utils.data_system import ( + DeepmdDataSystem, +) from deepmd.utils.path import ( DPPath, ) @@ -348,15 +351,32 @@ def deserialize(cls, data: dict) -> "DescrptSeT": return obj @classmethod - def update_sel(cls, global_jdata: dict, local_jdata: dict): + def update_sel( + cls, + train_data: DeepmdDataSystem, + type_map: Optional[List[str]], + local_jdata: dict, + ) -> Tuple[dict, Optional[float]]: """Update the selection and perform neighbor statistics. Parameters ---------- - global_jdata : dict - The global data, containing the training section + train_data : DeepmdDataSystem + data used to do neighbor statictics + type_map : list[str], optional + The name of each type of atoms local_jdata : dict The local data refer to the current class + + Returns + ------- + dict + The updated local data + float + The minimum distance between two atoms """ local_jdata_cpy = local_jdata.copy() - return UpdateSel().update_one_sel(global_jdata, local_jdata_cpy, False) + min_nbor_dist, local_jdata_cpy["sel"] = UpdateSel().update_one_sel( + train_data, type_map, local_jdata_cpy["rcut"], local_jdata_cpy["sel"], False + ) + return local_jdata_cpy, min_nbor_dist diff --git a/deepmd/pt/model/descriptor/se_t.py b/deepmd/pt/model/descriptor/se_t.py index db6244000d..2c8f52709f 100644 --- a/deepmd/pt/model/descriptor/se_t.py +++ b/deepmd/pt/model/descriptor/se_t.py @@ -30,6 +30,9 @@ from deepmd.pt.utils.update_sel import ( UpdateSel, ) +from deepmd.utils.data_system import ( + DeepmdDataSystem, +) from deepmd.utils.env_mat_stat import ( StatItem, ) @@ -324,18 +327,35 @@ def t_cvt(xx): return obj @classmethod - def update_sel(cls, global_jdata: dict, local_jdata: dict): + def update_sel( + cls, + train_data: DeepmdDataSystem, + type_map: Optional[List[str]], + local_jdata: dict, + ) -> Tuple[dict, Optional[float]]: """Update the selection and perform neighbor statistics. Parameters ---------- - global_jdata : dict - The global data, containing the training section + train_data : DeepmdDataSystem + data used to do neighbor statictics + type_map : list[str], optional + The name of each type of atoms local_jdata : dict The local data refer to the current class + + Returns + ------- + dict + The updated local data + float + The minimum distance between two atoms """ local_jdata_cpy = local_jdata.copy() - return UpdateSel().update_one_sel(global_jdata, local_jdata_cpy, False) + min_nbor_dist, local_jdata_cpy["sel"] = UpdateSel().update_one_sel( + train_data, type_map, local_jdata_cpy["rcut"], local_jdata_cpy["sel"], False + ) + return local_jdata_cpy, min_nbor_dist @DescriptorBlock.register("se_e3")