Support no validation
iProzd committed Feb 28, 2024
1 parent f1585b2 commit 463f9fb
Showing 2 changed files with 56 additions and 35 deletions.
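The change makes the `validation_data` section of the input optional end to end: the entry point tolerates a missing key, the trainer builds no validation loader, and the training loop skips validation logging. A minimal sketch of the lookup pattern the commit adopts (the dict contents are illustrative placeholders, not from the repository):

```python
# Sketch: a missing "validation_data" key now degrades to None
# instead of raising KeyError (placeholder systems/batch sizes).
data_dict_single = {
    "training_data": {"systems": ["sys_a", "sys_b"], "batch_size": 1},
    # no "validation_data" key at all
}

validation_dataset_params = data_dict_single.get("validation_data", None)
validation_systems = (
    validation_dataset_params["systems"] if validation_dataset_params else None
)
print(validation_systems)  # None -> downstream loaders are skipped
```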
19 changes: 13 additions & 6 deletions deepmd/pt/entrypoints/main.py
@@ -91,9 +91,12 @@ def prepare_trainer_input_single(
         type_split = False
         if model_params_single["descriptor"]["type"] in ["se_e2_a"]:
             type_split = True
-        validation_dataset_params = data_dict_single["validation_data"]
+        validation_dataset_params = data_dict_single.get("validation_data", None)
+        validation_systems = (
+            validation_dataset_params["systems"] if validation_dataset_params else None
+        )
         training_systems = training_dataset_params["systems"]
-        validation_systems = validation_dataset_params["systems"]
+
         # stat files
         stat_file_path_single = data_dict_single.get("stat_file", None)
         if stat_file_path_single is not None:
@@ -107,10 +110,14 @@ def prepare_trainer_input_single(
             stat_file_path_single = DPPath(stat_file_path_single, "a")
 
         # validation and training data
-        validation_data_single = DpLoaderSet(
-            validation_systems,
-            validation_dataset_params["batch_size"],
-            model_params_single,
+        validation_data_single = (
+            DpLoaderSet(
+                validation_systems,
+                validation_dataset_params["batch_size"],
+                model_params_single,
+            )
+            if validation_systems
+            else None
         )
         if ckpt or finetune_model:
             train_data_single = DpLoaderSet(
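The guarded constructor returns `None` rather than instantiating `DpLoaderSet` on `None` systems. A distilled, self-contained sketch of the same pattern (`FakeLoader` is a hypothetical stand-in for `DpLoaderSet`):

```python
# Hypothetical stand-in for DpLoaderSet, used only to show the guard.
class FakeLoader:
    def __init__(self, systems, batch_size, params):
        assert systems is not None  # would blow up without the guard
        self.systems = systems

for systems in (["sys_val"], None):
    loader = FakeLoader(systems, 1, {}) if systems else None
    print(type(loader).__name__ if loader else None)  # FakeLoader, then None
```

Note that the guard tests truthiness, so an empty `systems` list is treated the same as an absent `validation_data` section.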
72 changes: 43 additions & 29 deletions deepmd/pt/train/training.py
@@ -57,7 +57,6 @@
 if torch.__version__.startswith("2"):
     import torch._dynamo
 
-
 import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data import (
@@ -142,20 +141,7 @@ def get_data_loader(_training_data, _validation_data, _training_params):
             else:
                 train_sampler = get_weighted_sampler(_training_data, "prob_sys_size")
 
-            if "auto_prob" in _training_params["validation_data"]:
-                valid_sampler = get_weighted_sampler(
-                    _validation_data, _training_params["validation_data"]["auto_prob"]
-                )
-            elif "sys_probs" in _training_params["validation_data"]:
-                valid_sampler = get_weighted_sampler(
-                    _validation_data,
-                    _training_params["validation_data"]["sys_probs"],
-                    sys_prob=True,
-                )
-            else:
-                valid_sampler = get_weighted_sampler(_validation_data, "prob_sys_size")
-
-            if train_sampler is None or valid_sampler is None:
+            if train_sampler is None:
                 log.warning(
                     "Sampler not specified!"
                 )  # None sampler will lead to a premature stop iteration. Replacement should be True in attribute of the sampler to produce expected number of items in one iteration.
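The validation-sampler dispatch removed here reappears in the next hunk, nested under an `_validation_data is not None` guard; its precedence is unchanged, as this toy restatement shows (placeholder params, not the real config):

```python
# Distilled dispatch (relocated by the commit, not altered):
params = {"sys_probs": [0.7, 0.3]}  # placeholder validation_data params
if "auto_prob" in params:
    choice = ("auto_prob", params["auto_prob"])
elif "sys_probs" in params:
    choice = ("sys_probs", params["sys_probs"])  # passed with sys_prob=True
else:
    choice = ("default", "prob_sys_size")
print(choice)  # ('sys_probs', [0.7, 0.3])
```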
@@ -169,22 +155,43 @@ def get_data_loader(_training_data, _validation_data, _training_params):
             )
             with torch.device("cpu"):
                 training_data_buffered = BufferedIterator(iter(training_dataloader))
-            validation_dataloader = DataLoader(
-                _validation_data,
-                sampler=valid_sampler,
-                batch_size=None,
-                num_workers=min(NUM_WORKERS, 1),
-                drop_last=False,
-                pin_memory=True,
-            )
-
-            with torch.device("cpu"):
-                validation_data_buffered = BufferedIterator(iter(validation_dataloader))
-            if _training_params.get("validation_data", None) is not None:
-                valid_numb_batch = _training_params["validation_data"].get(
-                    "numb_btch", 1
-                )
-            else:
-                valid_numb_batch = 1
+            if _validation_data is not None:
+                if "auto_prob" in _training_params["validation_data"]:
+                    valid_sampler = get_weighted_sampler(
+                        _validation_data,
+                        _training_params["validation_data"]["auto_prob"],
+                    )
+                elif "sys_probs" in _training_params["validation_data"]:
+                    valid_sampler = get_weighted_sampler(
+                        _validation_data,
+                        _training_params["validation_data"]["sys_probs"],
+                        sys_prob=True,
+                    )
+                else:
+                    valid_sampler = get_weighted_sampler(
+                        _validation_data, "prob_sys_size"
+                    )
+                validation_dataloader = DataLoader(
+                    _validation_data,
+                    sampler=valid_sampler,
+                    batch_size=None,
+                    num_workers=min(NUM_WORKERS, 1),
+                    drop_last=False,
+                    pin_memory=True,
+                )
+                with torch.device("cpu"):
+                    validation_data_buffered = BufferedIterator(
+                        iter(validation_dataloader)
+                    )
+                if _training_params.get("validation_data", None) is not None:
+                    valid_numb_batch = _training_params["validation_data"].get(
+                        "numb_btch", 1
+                    )
+                else:
+                    valid_numb_batch = 1
+            else:
+                validation_dataloader = None
+                validation_data_buffered = None
+                valid_numb_batch = 1
             return (
                 training_dataloader,
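Two details of the loader construction are easy to miss: `batch_size=None` disables PyTorch's automatic batching (the dataset already yields whole batches), and the sampler should draw with replacement to deliver the expected number of items per epoch, per the warning in the previous hunk. A runnable toy illustrating both (the dataset is invented for the example):

```python
import torch
from torch.utils.data import DataLoader, Dataset, RandomSampler

class PreBatchedDataset(Dataset):
    """Toy dataset whose items are already whole batches."""
    def __len__(self):
        return 4
    def __getitem__(self, idx):
        return torch.full((3, 2), float(idx))

ds = PreBatchedDataset()
sampler = RandomSampler(ds, replacement=True, num_samples=len(ds))
dl = DataLoader(ds, sampler=sampler, batch_size=None)
for batch in dl:
    print(batch.shape)  # torch.Size([3, 2]) -- no extra batch dim added
```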
@@ -645,6 +652,9 @@ def log_loss_valid(_task_key="Default"):
                 input_dict, label_dict, _ = self.get_data(
                     is_train=False, task_key=_task_key
                 )
+                if input_dict == {}:
+                    # no validation data
+                    return "", None
                 _, loss, more_loss = self.wrapper(
                     **input_dict,
                     cur_lr=pref_lr,
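The empty-dict sentinel pairs with the `get_data` change below: when there is no validation set, `get_data` hands back empty dicts and `log_loss_valid` bails out before touching the model. A toy version of the handshake (free functions here, not the actual `Trainer` methods):

```python
# Toy version of the sentinel handshake; the real methods live on Trainer.
def get_data(validation_data, is_train=False):
    if not is_train and validation_data is None:
        return {}, {}, {}  # sentinel: no validation data available
    return {"coord": "..."}, {"energy": "..."}, {}

def log_loss_valid(validation_data):
    input_dict, label_dict, _ = get_data(validation_data)
    if input_dict == {}:
        return "", None  # nothing to evaluate, nothing to log
    return "rmse_val=...", 0.0

print(log_loss_valid(None))         # ('', None)
print(log_loss_valid(["batch_0"]))  # ('rmse_val=...', 0.0)
```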
@@ -806,6 +816,8 @@ def get_data(self, is_train=True, task_key="Default"):
                     )
                     batch_data = next(iter(self.training_data))
             else:
+                if self.validation_data is None:
+                    return {}, {}, {}
                 try:
                     batch_data = next(iter(self.validation_data))
                 except StopIteration:
@@ -824,6 +836,8 @@ def get_data(self, is_train=True, task_key="Default"):
                     )
                     batch_data = next(iter(self.training_data[task_key]))
             else:
+                if self.validation_data[task_key] is None:
+                    return {}, {}, {}
                 try:
                     batch_data = next(iter(self.validation_data[task_key]))
                 except StopIteration:
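In the multi-task branch, `self.validation_data` is a dict keyed by task, so the guard is applied per key: one task may validate while another has no validation set at all. A sketch of the shape being guarded (task names invented):

```python
# Invented task names; per-task None mirrors the guard above.
validation_data = {"water": ["vbatch_0"], "alloy": None}
for task_key, vdata in validation_data.items():
    if vdata is None:
        print(task_key, "-> skip validation")  # like `return {}, {}, {}`
    else:
        print(task_key, "-> validate on", len(vdata), "batch(es)")
```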
