Skip to content

Commit

Permalink
feat(pt): support multitask argcheck (#3925)
Browse files Browse the repository at this point in the history
Note that:
1. Documentation for the multi-task arguments is not yet supported; help may be needed.
2. `trim_pattern="_*"` is not supported for a repeated dict `Argument`; `dargs` may
need to be updated.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Enhanced training configuration to support multi-task mode with
additional arguments for data configuration.
  - Updated example configurations to reflect multi-task mode changes.

- **Bug Fixes**
- Improved logic for updating and normalizing configuration during
training regardless of multi-task mode.

- **Dependencies**
  - Upgraded `dargs` package requirement to version `>= 0.4.7`.

- **Tests**
- Added new test cases for multi-task scenarios in `TestExamples` class.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
  • Loading branch information
iProzd authored Jul 2, 2024
1 parent e809e64 commit c98185c
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 22 deletions.
5 changes: 2 additions & 3 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,9 +245,8 @@ def train(FLAGS):
)

# argcheck
if not multi_task:
config = update_deepmd_input(config, warning=True, dump="input_v2_compat.json")
config = normalize(config)
config = update_deepmd_input(config, warning=True, dump="input_v2_compat.json")
config = normalize(config, multi_task=multi_task)

# do neighbor stat
min_nbor_dist = None
Expand Down
82 changes: 67 additions & 15 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -2325,7 +2325,9 @@ def mixed_precision_args(): # ! added by Denghui.
)


def training_args(): # ! modified by Ziyao: data configuration isolated.
def training_args(
multi_task=False,
): # ! modified by Ziyao: data configuration isolated.
doc_numb_steps = "Number of training batch. Each training uses one batch of data."
doc_seed = "The random seed for getting frames from the training data set."
doc_disp_file = "The file for printing learning curve."
Expand Down Expand Up @@ -2364,14 +2366,30 @@ def training_args(): # ! modified by Ziyao: data configuration isolated.
)
doc_opt_type = "The type of optimizer to use."
doc_kf_blocksize = "The blocksize for the Kalman filter."
doc_model_prob = "The visiting probability of each model for each training step in the multi-task mode."
doc_data_dict = "The multiple definition of the data, used in the multi-task mode."

arg_training_data = training_data_args()
arg_validation_data = validation_data_args()
mixed_precision_data = mixed_precision_args()

args = [
data_args = [
arg_training_data,
arg_validation_data,
Argument(
"stat_file", str, optional=True, doc=doc_only_pt_supported + doc_stat_file
),
]
args = (
data_args
if not multi_task
else [
Argument("model_prob", dict, optional=True, default={}, doc=doc_model_prob),
Argument("data_dict", dict, data_args, repeat=True, doc=doc_data_dict),
]
)

args += [
mixed_precision_data,
Argument(
"numb_steps", int, optional=False, doc=doc_numb_steps, alias=["stop_batch"]
Expand Down Expand Up @@ -2438,9 +2456,6 @@ def training_args(): # ! modified by Ziyao: data configuration isolated.
optional=True,
doc=doc_only_pt_supported + doc_gradient_max_norm,
),
Argument(
"stat_file", str, optional=True, doc=doc_only_pt_supported + doc_stat_file
),
]
variants = [
Variant(
Expand Down Expand Up @@ -2472,6 +2487,34 @@ def training_args(): # ! modified by Ziyao: data configuration isolated.
return Argument("training", dict, args, variants, doc=doc_training)


def multi_model_args():
    """Build the top-level ``model`` Argument for multi-task input files.

    Reuses the single-task model schema (``model_args()``), renamed to
    ``model_dict`` and marked repeatable so every per-task entry is validated
    with the same rules, and adds a ``shared_dict`` section for parameters
    shared across the per-task models.
    """
    doc_shared_dict = "The definition of the shared parameters used in the `model_dict` within multi-task mode."
    # Rename and repeat the single-task schema so it applies to each task key.
    per_task_model = model_args()
    per_task_model.name = "model_dict"
    per_task_model.repeat = True
    per_task_model.doc = (
        "The multiple definition of the model, used in the multi-task mode."
    )
    shared = Argument(
        "shared_dict", dict, optional=True, default={}, doc=doc_shared_dict
    )
    return Argument("model", dict, [per_task_model, shared])


def multi_loss_args():
    """Build the repeated ``loss_dict`` Argument for multi-task mode.

    Reuses the single-task loss schema (``loss_args()``), renamed to
    ``loss_dict`` and marked repeatable so each task key gets the same
    validation rules.
    """
    per_task_loss = loss_args()
    per_task_loss.name = "loss_dict"
    per_task_loss.repeat = True
    per_task_loss.doc = "The multiple definition of the loss, used in the multi-task mode."
    return per_task_loss


def make_index(keys):
ret = []
for ii in keys:
Expand Down Expand Up @@ -2502,14 +2545,23 @@ def gen_json(**kwargs):
)


def gen_args(**kwargs) -> List[Argument]:
return [
model_args(),
learning_rate_args(),
loss_args(),
training_args(),
nvnmd_args(),
]
def gen_args(multi_task=False) -> List[Argument]:
    """Generate the list of top-level Arguments used to validate an input file.

    Parameters
    ----------
    multi_task : bool
        If True, use the multi-task variants of the model and loss schemas
        (``multi_model_args`` / ``multi_loss_args``) instead of the
        single-task ones.
    """
    # Only the model and loss schemas differ between the two modes;
    # learning rate, training, and nvnmd sections are shared.
    model = multi_model_args() if multi_task else model_args()
    loss = multi_loss_args() if multi_task else loss_args()
    return [
        model,
        learning_rate_args(),
        loss,
        training_args(multi_task=multi_task),
        nvnmd_args(),
    ]


def gen_json_schema() -> str:
Expand All @@ -2524,8 +2576,8 @@ def gen_json_schema() -> str:
return json.dumps(generate_json_schema(arg))


def normalize(data):
base = Argument("base", dict, gen_args())
def normalize(data, multi_task=False):
base = Argument("base", dict, gen_args(multi_task=multi_task))
data = base.normalize_value(data, trim_pattern="_*")
base.check_value(data, strict=True)

Expand Down
1 change: 0 additions & 1 deletion examples/water_multi_task/pytorch_example/input_torch.json
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@
"_comment": "that's all"
},
"loss_dict": {
"_comment": " that's all",
"water_1": {
"type": "ener",
"start_pref_e": 0.02,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ dependencies = [
'numpy',
'scipy',
'pyyaml',
'dargs >= 0.4.6',
'dargs >= 0.4.7',
'typing_extensions; python_version < "3.8"',
'importlib_metadata>=1.4; python_version < "3.8"',
'h5py',
Expand Down
15 changes: 13 additions & 2 deletions source/tests/common/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
normalize,
)

from ..pt.test_multitask import (
preprocess_shared_params,
)

p_examples = Path(__file__).parent.parent.parent.parent / "examples"

input_files = (
Expand Down Expand Up @@ -51,11 +55,18 @@
p_examples / "water" / "dpa2" / "input_torch.json",
)

input_files_multi = (
p_examples / "water_multi_task" / "pytorch_example" / "input_torch.json",
)


class TestExamples(unittest.TestCase):
    """Check that every example input file passes argument normalization."""

    def test_arguments(self):
        for path in input_files + input_files_multi:
            # Multi-task examples need their shared model parameters expanded
            # before normalization, and the multi-task argument schema.
            is_multi = path in input_files_multi
            name = str(path)
            with self.subTest(fn=name):
                jdata = j_loader(name)
                if is_multi:
                    jdata["model"], _ = preprocess_shared_params(jdata["model"])
                normalize(jdata, multi_task=is_multi)

0 comments on commit c98185c

Please sign in to comment.