Skip to content

Commit

Permalink
fix(pt): use user seed in DpLoaderSet (deepmodeling#4015)
Browse files Browse the repository at this point in the history
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
  - Introduced a `seed` parameter for enhanced control over randomness in
    the training process.
  - Improved randomness handling for distributed training setups to ensure
    unique batch sequences.

- **Bug Fixes**
  - Enhanced robustness in seed setup during initialization to prevent
    unnecessary function calls with a `None` value.

- **Chores**
  - Streamlined the random seed management by removing redundant
    seed-setting logic in training processes.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
  • Loading branch information
iProzd authored and Mathieu Taillefumier committed Sep 18, 2024
1 parent 22cae4f commit e77bdfa
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 8 deletions.
11 changes: 10 additions & 1 deletion deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,9 @@ def get_trainer(
assert dist.is_nccl_available()
dist.init_process_group(backend="nccl")

def prepare_trainer_input_single(model_params_single, data_dict_single, rank=0):
def prepare_trainer_input_single(
model_params_single, data_dict_single, rank=0, seed=None
):
training_dataset_params = data_dict_single["training_data"]
validation_dataset_params = data_dict_single.get("validation_data", None)
validation_systems = (
Expand All @@ -134,11 +136,14 @@ def prepare_trainer_input_single(model_params_single, data_dict_single, rank=0):
stat_file_path_single = DPPath(stat_file_path_single, "a")

# validation and training data
# avoid the same batch sequence among devices
rank_seed = (seed + rank) % (2**32) if seed is not None else None
validation_data_single = (
DpLoaderSet(
validation_systems,
validation_dataset_params["batch_size"],
model_params_single["type_map"],
seed=rank_seed,
)
if validation_systems
else None
Expand All @@ -147,6 +152,7 @@ def prepare_trainer_input_single(model_params_single, data_dict_single, rank=0):
training_systems,
training_dataset_params["batch_size"],
model_params_single["type_map"],
seed=rank_seed,
)
return (
train_data_single,
Expand All @@ -155,6 +161,7 @@ def prepare_trainer_input_single(model_params_single, data_dict_single, rank=0):
)

rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
data_seed = config["training"].get("seed", None)
if not multi_task:
(
train_data,
Expand All @@ -164,6 +171,7 @@ def prepare_trainer_input_single(model_params_single, data_dict_single, rank=0):
config["model"],
config["training"],
rank=rank,
seed=data_seed,
)
else:
train_data, validation_data, stat_file_path = {}, {}, {}
Expand All @@ -176,6 +184,7 @@ def prepare_trainer_input_single(model_params_single, data_dict_single, rank=0):
config["model"]["model_dict"][model_key],
config["training"]["data_dict"][model_key],
rank=rank,
seed=data_seed,
)

trainer = training.Trainer(
Expand Down
5 changes: 0 additions & 5 deletions deepmd/pt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,10 +274,6 @@ def get_lr(lr_params):
self.opt_type, self.opt_param = get_opt_param(training_params)

# Model
dp_random.seed(training_params["seed"])
if training_params["seed"] is not None:
torch.manual_seed(training_params["seed"])

self.model = get_model_for_wrapper(model_params)

# Loss
Expand All @@ -302,7 +298,6 @@ def get_lr(lr_params):
)

# Data
dp_random.seed(training_params["seed"])
if not self.multi_task:
self.get_sample_func = single_model_stat(
self.model,
Expand Down
5 changes: 3 additions & 2 deletions deepmd/pt/utils/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,11 @@ def __init__(
systems,
batch_size,
type_map,
seed=10,
seed=None,
shuffle=True,
):
setup_seed(seed)
if seed is not None:
setup_seed(seed)
if isinstance(systems, str):
with h5py.File(systems) as file:
systems = [os.path.join(systems, item) for item in file.keys()]
Expand Down

0 comments on commit e77bdfa

Please sign in to comment.