Skip to content

Commit

Permalink
update doc and example
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Nov 25, 2024
1 parent ae95c3c commit 412aea9
Show file tree
Hide file tree
Showing 3 changed files with 163 additions and 0 deletions.
7 changes: 7 additions & 0 deletions doc/train/multi-task-training.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,13 @@ Specifically, there are several parts that need to be modified:
and use the user-defined integer `shared_level` in the code to share the corresponding module to varying degrees
(default is to share all parameters, i.e., `shared_level`=0).
The parts that are exclusive to each model can be written following the previous definition.
- - To conduct multitask training, there are two typical approaches:
1. **Descriptor sharing only**: Share the descriptor with `shared_level`=0. See [here](../../examples/water_multi_task/pytorch_example/input_torch.json) for an example.
2. **Descriptor and fitting network sharing with data identification**:
- Share the descriptor with `shared_level`=0.
- Share the fitting network with `shared_level`=1.
- {ref}`numb_dataid <model[standard]/fitting_net[ener]/numb_dataid>` must be set to the number of model branches, which will distinguish different data tasks using a one-hot embedding.
- See [here](../../examples/water_multi_task/pytorch_example/input_torch_sharefit.json) for an example.

- {ref}`loss_dict <loss_dict>`: The loss settings corresponding to each task model, specified by the `model_key`.
Each {ref}`loss_dict/model_key <loss_dict/model_key>` contains the corresponding loss settings,
Expand Down
155 changes: 155 additions & 0 deletions examples/water_multi_task/pytorch_example/input_torch_sharefit.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
{
"_comment": "that's all",
"model": {
"shared_dict": {
"type_map_all": [
"O",
"H"
],
"dpa2_descriptor": {
"type": "dpa2",
"repinit": {
"tebd_dim": 8,
"rcut": 6.0,
"rcut_smth": 0.5,
"nsel": 120,
"neuron": [
25,
50,
100
],
"axis_neuron": 12,
"activation_function": "tanh",
"three_body_sel": 48,
"three_body_rcut": 4.0,
"three_body_rcut_smth": 3.5,
"use_three_body": true
},
"repformer": {
"rcut": 4.0,
"rcut_smth": 3.5,
"nsel": 48,
"nlayers": 6,
"g1_dim": 128,
"g2_dim": 32,
"attn2_hidden": 32,
"attn2_nhead": 4,
"attn1_hidden": 128,
"attn1_nhead": 4,
"axis_neuron": 4,
"update_h2": false,
"update_g1_has_conv": true,
"update_g1_has_grrg": true,
"update_g1_has_drrd": true,
"update_g1_has_attn": false,
"update_g2_has_g1g1": false,
"update_g2_has_attn": true,
"update_style": "res_residual",
"update_residual": 0.01,
"update_residual_init": "norm",
"attn2_has_gate": true,
"use_sqrt_nnei": true,
"g1_out_conv": true,
"g1_out_mlp": true
},
"precision": "float64",
"add_tebd_to_repinit_out": false,
"_comment": " that's all"
},
"shared_fit_with_id": {
"neuron": [
240,
240,
240
],
"resnet_dt": true,
"seed": 1,
"numb_dataid": 2,
"_comment": " that's all"
},
"_comment": "that's all"
},
"model_dict": {
"water_1": {
"type_map": "type_map_all",
"descriptor": "dpa2_descriptor",
"fitting_net": "shared_fit_with_id:1"
},
"water_2": {
"type_map": "type_map_all",
"descriptor": "dpa2_descriptor",
"fitting_net": "shared_fit_with_id:1"
}
}
},
"learning_rate": {
"type": "exp",
"decay_steps": 5000,
"start_lr": 0.001,
"stop_lr": 3.51e-08,
"_comment": "that's all"
},
"loss_dict": {
"water_1": {
"type": "ener",
"start_pref_e": 0.02,
"limit_pref_e": 1,
"start_pref_f": 1000,
"limit_pref_f": 1,
"start_pref_v": 0,
"limit_pref_v": 0
},
"water_2": {
"type": "ener",
"start_pref_e": 0.02,
"limit_pref_e": 1,
"start_pref_f": 1000,
"limit_pref_f": 1,
"start_pref_v": 0,
"limit_pref_v": 0
}
},
"training": {
"model_prob": {
"water_1": 0.5,
"water_2": 0.5
},
"data_dict": {
"water_1": {
"training_data": {
"systems": [
"../../water/data/data_0/",
"../../water/data/data_1/",
"../../water/data/data_2/"
],
"batch_size": 1,
"_comment": "that's all"
},
"validation_data": {
"systems": [
"../../water/data/data_3/"
],
"batch_size": 1,
"_comment": "that's all"
}
},
"water_2": {
"training_data": {
"systems": [
"../../water/data/data_0/",
"../../water/data/data_1/",
"../../water/data/data_2/"
],
"batch_size": 1,
"_comment": "that's all"
}
}
},
"numb_steps": 100000,
"seed": 10,
"disp_file": "lcurve.out",
"disp_freq": 100,
"save_freq": 100,
"_comment": "that's all"
}
}
1 change: 1 addition & 0 deletions source/tests/common/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@

input_files_multi = (
p_examples / "water_multi_task" / "pytorch_example" / "input_torch.json",
p_examples / "water_multi_task" / "pytorch_example" / "input_torch_sharefit.json",
)


Expand Down

0 comments on commit 412aea9

Please sign in to comment.