Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pt: expand systems before training #3384

Merged
merged 1 commit into from
Mar 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@
from deepmd.utils.compat import (
update_deepmd_input,
)
from deepmd.utils.data_system import (
process_systems,
)
from deepmd.utils.path import (
DPPath,
)
Expand Down Expand Up @@ -108,6 +111,9 @@ def prepare_trainer_input_single(
validation_dataset_params["systems"] if validation_dataset_params else None
)
training_systems = training_dataset_params["systems"]
training_systems = process_systems(training_systems)
if validation_systems is not None:
validation_systems = process_systems(validation_systems)

# stat files
stat_file_path_single = data_dict_single.get("stat_file", None)
Expand Down
55 changes: 38 additions & 17 deletions deepmd/utils/data_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Dict,
List,
Optional,
Union,
)

import numpy as np
Expand Down Expand Up @@ -667,30 +668,22 @@ def prob_sys_size_ext(keywords, nsystems, nbatch):
return sys_probs


def get_data(
jdata: Dict[str, Any], rcut, type_map, modifier, multi_task_mode=False
) -> DeepmdDataSystem:
"""Get the data system.
def process_systems(systems: Union[str, List[str]]) -> List[str]:
"""Process the user-input systems.

If it is a single directory, search for all the systems in the directory.
Check if the systems are valid.

Parameters
----------
jdata
The json data
rcut
The cut-off radius, not used
type_map
The type map
modifier
The data modifier
multi_task_mode
If in multi task mode
systems : str or list of str
The user-input systems

Returns
-------
DeepmdDataSystem
The data system
list of str
The valid systems
"""
systems = j_must_have(jdata, "systems")
if isinstance(systems, str):
systems = expand_sys_str(systems)
elif isinstance(systems, list):
Expand All @@ -712,6 +705,34 @@ def get_data(
msg = f"dir {ii} is not a valid data system dir"
log.fatal(msg)
raise OSError(msg, help_msg)
return systems


def get_data(
jdata: Dict[str, Any], rcut, type_map, modifier, multi_task_mode=False
) -> DeepmdDataSystem:
"""Get the data system.

Parameters
----------
jdata
The json data
rcut
The cut-off radius, not used
type_map
The type map
modifier
The data modifier
multi_task_mode
If in multi task mode

Returns
-------
DeepmdDataSystem
The data system
"""
systems = j_must_have(jdata, "systems")
systems = process_systems(systems)

batch_size = j_must_have(jdata, "batch_size")
sys_probs = jdata.get("sys_probs", None)
Expand Down
2 changes: 2 additions & 0 deletions deepmd/utils/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,8 @@

def is_dir(self) -> bool:
"""Check if self is directory."""
if self._name == "/":
return True

Check warning on line 418 in deepmd/utils/path.py

View check run for this annotation

Codecov / codecov/patch

deepmd/utils/path.py#L418

Added line #L418 was not covered by tests
if self._name not in self._keys:
return False
return isinstance(self.root[self._name], h5py.Group)
Expand Down