Skip to content

Commit

Permalink
chore(pt): Change the type of do_message_passing from int to `boo…
Browse files Browse the repository at this point in the history
…l` in `DeepPotPT` and `DeepSpinPT` classes (#4391)

Fix #4366.

* Update the type of `do_message_passing` to `bool` in the `DeepPotPT`
class and `init` method in `source/api_cc/include/DeepPotPT.h` and
`source/api_cc/src/DeepPotPT.cc`
* Update the type of `do_message_passing` to `bool` in the `DeepSpinPT`
class and `init` method in `source/api_cc/include/DeepSpinPT.h` and
`source/api_cc/src/DeepSpinPT.cc`

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **New Features**
- Enhanced error handling for exceptions from the PyTorch library in
both `DeepPotPT` and `DeepSpinPT` classes.
- Simplified boolean checks for message passing in the `compute` methods
of both classes.

- **Bug Fixes**
- Improved robustness in the `DeepPotPT` constructor to prevent resource
leaks during initialization.

- **Documentation**
- Updated method signatures to reflect changes in parameter types and
structures.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
njzjz and pre-commit-ci[bot] authored Nov 21, 2024
1 parent dbf450f commit ec6e903
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion source/api_cc/include/DeepPotPT.h
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ class DeepPotPT : public DeepPotBackend {
NeighborListData nlist_data;
int max_num_neighbors;
int gpu_id;
int do_message_passing; // 1:dpa2 model 0:others
bool do_message_passing; // 1:dpa2 model 0:others
bool gpu_enabled;
at::Tensor firstneigh_tensor;
c10::optional<torch::Tensor> mapping_tensor;
Expand Down
2 changes: 1 addition & 1 deletion source/api_cc/include/DeepSpinPT.h
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ class DeepSpinPT : public DeepSpinBackend {
NeighborListData nlist_data;
int max_num_neighbors;
int gpu_id;
int do_message_passing; // 1:dpa2 model 0:others
bool do_message_passing; // 1:dpa2 model 0:others
bool gpu_enabled;
at::Tensor firstneigh_tensor;
c10::optional<torch::Tensor> mapping_tensor;
Expand Down
4 changes: 2 additions & 2 deletions source/api_cc/src/DeepPotPT.cc
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
nlist_data.copy_from_nlist(lmp_list);
nlist_data.shuffle_exclude_empty(fwd_map);
nlist_data.padding();
if (do_message_passing == 1) {
if (do_message_passing) {
int nswap = lmp_list.nswap;
torch::Tensor sendproc_tensor =
torch::from_blob(lmp_list.sendproc, {nswap}, int32_option);
Expand Down Expand Up @@ -234,7 +234,7 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
.to(device);
}
c10::Dict<c10::IValue, c10::IValue> outputs =
(do_message_passing == 1)
(do_message_passing)
? module
.run_method("forward_lower", coord_wrapped_Tensor, atype_Tensor,
firstneigh_tensor, mapping_tensor, fparam_tensor,
Expand Down
4 changes: 2 additions & 2 deletions source/api_cc/src/DeepSpinPT.cc
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ void DeepSpinPT::compute(ENERGYVTYPE& ener,
nlist_data.copy_from_nlist(lmp_list);
nlist_data.shuffle_exclude_empty(fwd_map);
nlist_data.padding();
if (do_message_passing == 1) {
if (do_message_passing) {
int nswap = lmp_list.nswap;
torch::Tensor sendproc_tensor =
torch::from_blob(lmp_list.sendproc, {nswap}, int32_option);
Expand Down Expand Up @@ -234,7 +234,7 @@ void DeepSpinPT::compute(ENERGYVTYPE& ener,
.to(device);
}
c10::Dict<c10::IValue, c10::IValue> outputs =
(do_message_passing == 1)
(do_message_passing)
? module
.run_method("forward_lower", coord_wrapped_Tensor, atype_Tensor,
spin_wrapped_Tensor, firstneigh_tensor,
Expand Down

0 comments on commit ec6e903

Please sign in to comment.