Skip to content

Commit

Permalink
fix(pt): fix not used sys_probs
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Nov 13, 2024
1 parent 7416c9f commit 302cdf2
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 16 deletions.
22 changes: 6 additions & 16 deletions deepmd/pt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
)
from deepmd.pt.utils.dataloader import (
BufferedIterator,
get_weighted_sampler,
get_sampler_from_params,
)
from deepmd.pt.utils.env import (
DEVICE,
Expand Down Expand Up @@ -160,19 +160,7 @@ def get_opt_param(params):

def get_data_loader(_training_data, _validation_data, _training_params):
def get_dataloader_and_buffer(_data, _params):
if "auto_prob" in _training_params["training_data"]:
_sampler = get_weighted_sampler(
_data, _params["training_data"]["auto_prob"]
)
elif "sys_probs" in _training_params["training_data"]:
_sampler = get_weighted_sampler(
_data,
_params["training_data"]["sys_probs"],
sys_prob=True,
)
else:
_sampler = get_weighted_sampler(_data, "prob_sys_size")

_sampler = get_sampler_from_params(_data, _params)
if _sampler is None:
log.warning(
"Sampler not specified!"
Expand All @@ -193,14 +181,16 @@ def get_dataloader_and_buffer(_data, _params):
return _dataloader, _data_buffered

training_dataloader, training_data_buffered = get_dataloader_and_buffer(
_training_data, _training_params
_training_data, _training_params["training_data"]
)

if _validation_data is not None:
(
validation_dataloader,
validation_data_buffered,
) = get_dataloader_and_buffer(_validation_data, _training_params)
) = get_dataloader_and_buffer(
_validation_data, _training_params["validation_data"]
)
valid_numb_batch = _training_params["validation_data"].get(
"numb_btch", 1
)
Expand Down
16 changes: 16 additions & 0 deletions deepmd/pt/utils/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,3 +306,19 @@ def get_weighted_sampler(training_data, prob_style, sys_prob=False):
with torch.device("cpu"):
sampler = WeightedRandomSampler(probs, len_sampler, replacement=True)
return sampler


def get_sampler_from_params(_data, _params):
if (
"sys_probs" in _params and _params["sys_probs"] is not None
): # use sys_probs first
_sampler = get_weighted_sampler(
_data,
_params["sys_probs"],
sys_prob=True,
)
elif "auto_prob" in _params:
_sampler = get_weighted_sampler(_data, _params["auto_prob"])
else:
_sampler = get_weighted_sampler(_data, "prob_sys_size")
return _sampler
22 changes: 22 additions & 0 deletions source/tests/pt/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from deepmd.pt.utils.dataloader import (
DpLoaderSet,
get_sampler_from_params,
get_weighted_sampler,
)
from deepmd.tf.common import (
Expand Down Expand Up @@ -105,6 +106,27 @@ def test_sys_probs(self):
dp_probs = np.array(self.dp_dataset.sys_probs)
self.assertTrue(np.allclose(my_probs, dp_probs))

def test_sys_probs_end2end(self):
sys_probs = [0.1, 0.4, 0.5]
_params = {
"sys_probs": sys_probs,
"auto_prob": "prob_sys_size",
} # use sys_probs first
sampler = get_sampler_from_params(self.my_dataset, _params)
my_probs = np.array(sampler.weights)
self.dp_dataset.set_sys_probs(sys_probs=sys_probs)
dp_probs = np.array(self.dp_dataset.sys_probs)
self.assertTrue(np.allclose(my_probs, dp_probs))

def test_auto_prob_sys_size_ext_end2end(self):
auto_prob_style = "prob_sys_size;0:1:0.2;1:3:0.8"
_params = {"sys_probs": None, "auto_prob": auto_prob_style} # use auto_prob
sampler = get_sampler_from_params(self.my_dataset, _params)
my_probs = np.array(sampler.weights)
self.dp_dataset.set_sys_probs(auto_prob_style=auto_prob_style)
dp_probs = np.array(self.dp_dataset.sys_probs)
self.assertTrue(np.allclose(my_probs, dp_probs))


if __name__ == "__main__":
unittest.main()

0 comments on commit 302cdf2

Please sign in to comment.