Skip to content

Commit

Permalink
Add compression API
Browse files Browse the repository at this point in the history
  • Loading branch information
cherryWangY committed Nov 1, 2024
1 parent 8355947 commit ae64a72
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 0 deletions.
25 changes: 25 additions & 0 deletions deepmd/dpmodel/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,31 @@ def change_type_map(
self.out_bias = self.out_bias[:, remap_index, :]
self.out_std = self.out_std[:, remap_index, :]

def enable_compression(
self,
min_nbor_dist: float,
table_extrapolate: float = 5,
table_stride_1: float = 0.01,
table_stride_2: float = 0.1,
check_frequency: int = -1,
) -> None:
"""Call descriptor enable_compression()
Parameters
----------
min_nbor_dist
The nearest distance between atoms
table_extrapolate
The scale of model extrapolation
table_stride_1
The uniform stride of the first table
table_stride_2
The uniform stride of the second table
check_frequency
The overflow check frequency
"""
raise NotImplementedError("This atomi model doesn't support compression!")

def forward_common_atomic(
self,
extended_coord: np.ndarray,
Expand Down
31 changes: 31 additions & 0 deletions deepmd/dpmodel/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,37 @@ def has_message_passing(self) -> bool:
def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the atomic model needs sorted nlist when using `forward_lower`."""
return self.descriptor.need_sorted_nlist_for_lower()

def enable_compression(
self,
min_nbor_dist: float,
table_extrapolate: float = 5,
table_stride_1: float = 0.01,
table_stride_2: float = 0.1,
check_frequency: int = -1,
) -> None:
"""Call descriptor enable_compression()
Parameters
----------
min_nbor_dist
The nearest distance between atoms
table_extrapolate
The scale of model extrapolation
table_stride_1
The uniform stride of the first table
table_stride_2
The uniform stride of the second table
check_frequency
The overflow check frequency
"""
self.descriptor.enable_compression(
min_nbor_dist,
table_extrapolate,
table_stride_1,
table_stride_2,
check_frequency,
)

def forward_atomic(
self,
Expand Down
30 changes: 30 additions & 0 deletions deepmd/dpmodel/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,36 @@ def model_output_type(self) -> list[str]:
if vv.category == OutputVariableCategory.OUT
]
return vars

def enable_compression(
self,
table_extrapolate: float = 5,
table_stride_1: float = 0.01,
table_stride_2: float = 0.1,
check_frequency: int = -1,
) -> None:
"""Call atomic_model enable_compression()
Parameters
----------
min_nbor_dist
The nearest distance between atoms
table_extrapolate
The scale of model extrapolation
table_stride_1
The uniform stride of the first table
table_stride_2
The uniform stride of the second table
check_frequency
The overflow check frequency
"""
self.atomic_model.enable_compression(
self.get_min_nbor_dist(),
table_extrapolate,
table_stride_1,
table_stride_2,
check_frequency,
)

def call(
self,
Expand Down

0 comments on commit ae64a72

Please sign in to comment.