Skip to content

Commit

Permalink
docs: add doc for multitask fine-tuning (#3717)
Browse files Browse the repository at this point in the history
Add docs for multitask fine-tuning.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit


- **Documentation**
- Updated the fine-tuning guide with new sections on TensorFlow and
PyTorch implementations.
- Added detailed instructions for fine-tuning methods in PyTorch,
including specific commands and configurations.
- Modified the multi-task training guide to redirect users to the
fine-tuning section for more comprehensive instructions.
- Corrected a typo in the multi-task training TensorFlow documentation
for improved clarity.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
  • Loading branch information
iProzd authored May 6, 2024
1 parent 4b319a0 commit 0ec6719
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 14 deletions.
153 changes: 148 additions & 5 deletions doc/train/finetuning.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Finetune the pretrained model {{ tensorflow_icon }} {{ pytorch_icon }}
# Finetune the pre-trained model {{ tensorflow_icon }} {{ pytorch_icon }}

:::{note}
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}
Expand All @@ -7,13 +7,16 @@
Pretraining-and-finetuning is a widely used approach in other fields such as Computer Vision (CV) or Natural Language Processing (NLP)
to vastly reduce the training cost, while it's not trivial in potential models.
Compositions and configurations of data samples or even computational parameters in upstream software (such as VASP)
may be different between the pretrained and target datasets, leading to energy shifts or other diversities of training data.
may be different between the pre-trained and target datasets, leading to energy shifts or other diversities of training data.

Recently the emerging of methods such as [DPA-1](https://arxiv.org/abs/2208.08236) has brought us to a new stage where we can
perform similar pretraining-finetuning approaches.
DPA-1 can hopefully learn the common knowledge in the pretrained dataset (especially the `force` information)
They can hopefully learn the common knowledge in the pre-trained dataset (especially the `force` information)
and thus reduce the computational cost in downstream training tasks.
If you have a pretrained model `pretrained.pb`

## TensorFlow Implementation {{ tensorflow_icon }}

If you have a pre-trained model `pretrained.pb`
(here we support models using [`se_atten`](../model/train-se-atten.md) descriptor and [`ener`](../model/train-energy.md) fitting net)
on a large dataset (for example, [OC2M](https://github.com/Open-Catalyst-Project/ocp/blob/main/DATASET.md) in
DPA-1 [paper](https://arxiv.org/abs/2208.08236)), a finetuning strategy can be performed by simply running:
Expand All @@ -26,7 +29,7 @@ The command above will change the energy bias in the last layer of the fitting n
according to the training dataset in input.json.

:::{warning}
Note that the elements in the training dataset must be contained in the pretrained dataset.
Note that the elements in the training dataset must be contained in the pre-trained dataset.
:::

The finetune procedure will inherit the model structures in `pretrained.pb`,
Expand All @@ -45,3 +48,143 @@ To obtain a more simplified script, for example, you can change the {ref}`model
"fitting_net" : {}
}
```

## PyTorch Implementation {{ pytorch_icon }}

In PyTorch version, we have introduced an updated, more adaptable approach to fine-tuning. This methodology encompasses two primary variations:

### Single-task fine-tuning

#### Fine-tuning from a single-task pre-trained model

By saying "single-task pre-trained", we refer to a model pre-trained on one single dataset.
This fine-tuning method is similar to the fine-tune approach supported by TensorFlow.
It utilizes a single-task pre-trained model (`pretrained.pt`) and modifies the energy bias within its fitting net before continuing with training.
The command for this operation is:

```bash
$ dp --pt train input.json --finetune pretrained.pt
```

:::{note}
We do not support fine-tuning from a randomly initialized fitting net in this case, which is the same as implementations in TensorFlow.
:::

The model section in input.json can be simplified as follows:

```json
"model": {
"type_map": ["O", "H"],
"descriptor" : {},
"fitting_net" : {}
}
```

:::{warning}
The `type_map` will be overwritten based on that in the pre-trained model. Please ensure you are familiar with the `type_map` configuration in the pre-trained model before starting the fine-tuning process.
This issue will be addressed in the future version.
:::

#### Fine-tuning from a multi-task pre-trained model

Additionally, within the PyTorch implementation and leveraging the flexibility offered by the framework and the multi-task training capabilities provided by DPA2,
we also support more general multitask pre-trained models, which includes multiple datasets for pre-training. These pre-training datasets share a common descriptor while maintaining their individual fitting nets,
as detailed in the DPA2 [paper](https://arxiv.org/abs/2312.15492).

For fine-tuning using this multitask pre-trained model (`multitask_pretrained.pt`),
one can select a specific branch (e.g., `CHOOSEN_BRANCH`) included in `multitask_pretrained.pt` for fine-tuning with the following command:

```bash
$ dp --pt train input.json --finetune multitask_pretrained.pt --model-branch CHOOSEN_BRANCH
```

:::{note}
To check the available model branches, you can typically refer to the documentation of the pre-trained model.
If you're still unsure about the available branches, you can try inputting an arbitrary branch name.
This will prompt an error message that displays a list of all the available model branches.

Please note that this feature will be improved in the upcoming version to provide a more user-friendly experience.
:::

This command will start fine-tuning based on the pre-trained model's descriptor and the selected branch's fitting net.
If --model-branch is not set, a randomly initialized fitting net will be used.

### Multi-task fine-tuning

In typical scenarios, relying solely on single-task fine-tuning might gradually lead to the forgetting of information from the pre-trained datasets.
In more advanced scenarios, it is desirable for the model to explicitly retain information from the pre-trained data during fine-tuning to prevent forgetting,
which could be more beneficial for fine-tuning.

To achieve this, it is first necessary to clearly identify the datasets from which the pre-trained model originates and to download the corresponding datasets
that need to be retained for subsequent multitask fine-tuning.
Then, prepare a suitable input script for multitask fine-tuning `multi_input.json` as the following steps.

- Suppose the new dataset for fine-tuning is named `DOWNSTREAM_DATA`, and the datasets to be retained from multitask pre-trained model are `PRE_DATA1` and `PRE_DATA2`. One can:

1. Refer to the [`multi-task-training`](./multi-task-training-pt.md) document to prepare a multitask training script for two systems,
ideally extracting parts (i.e. {ref}`model_dict <model/model_dict>`, {ref}`loss_dict <loss_dict>`, {ref}`data_dict <training/data_dict>` and {ref}`model_prob <training/model_prob>` parts) corresponding to `PRE_DATA1` and `PRE_DATA2` directly from the training script of the pre-trained model.
2. For `DOWNSTREAM_DATA`, select a desired branch to fine-tune from (e.g., `PRE_DATA1`), copy the configurations of `PRE_DATA1` as the configuration for `DOWNSTREAM_DATA` and insert the corresponding data path into the {ref}`data_dict <training/data_dict>`,
thereby generating a three-system multitask training script.
3. In the {ref}`model_dict <model/model_dict>` for `DOWNSTREAM_DATA`, specify the branch from which `DOWNSTREAM_DATA` is to fine-tune using:
`"finetune_head": "PRE_DATA1"`.

The complete `multi_input.json` should appear as follows ("..." means copied from input script of pre-trained model):

```json
"model": {
"shared_dict": {
...
},
"model_dict": {
"PRE_DATA1": {
"type_map": ...,
"descriptor": ...,
"fitting_net": ...
},
"PRE_DATA2": {
"type_map": ...,
"descriptor": ...,
"fitting_net": ...
},
"DOWNSTREAM_DATA": {
"finetune_head": "PRE_DATA1",
"type_map": ...,
"descriptor": ...,
"fitting_net": ...
},
}
},
"learning_rate": ...,
"loss_dict": {
"PRE_DATA1": ...,
"PRE_DATA2": ...,
"DOWNSTREAM_DATA": ...
},
"training": {
"model_prob": {
"PRE_DATA1": 0.5,
"PRE_DATA2": 0.5,
"DOWNSTREAM_DATA": 1.0
},
"data_dict": {
"PRE_DATA1": ...,
"PRE_DATA2": ...,
"DOWNSTREAM_DATA": {
"training_data": "training_data_config_for_DOWNSTREAM_DATA",
"validation_data": "validation_data_config_for_DOWNSTREAM_DATA"
}
},
...
}
```

Subsequently, run the command:

```bash
dp --pt train multi_input.json --finetune multitask_pretrained.pt
```

This will initiate multitask fine-tuning, where for branches `PRE_DATA1` and `PRE_DATA2`,
it is akin to continuing training in `init-model` mode, whereas for `DOWNSTREAM_DATA`,
fine-tuning will be based on the fitting net from `PRE_DATA1`.
You can set `model_prob` for each dataset just the same as that in normal multitask training.
10 changes: 2 additions & 8 deletions doc/train/multi-task-training-pt.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,7 @@ An example input for multi-task training two models in water system is shown as
:linenos:
```

## Finetune from the pretrained multi-task model
## Finetune from the pre-trained multi-task model

To finetune based on the checkpoint `model.pt` after the multi-task pre-training is completed,
users only need to prepare the normal input for single-task training `input_single.json`,
and then select one of the trained model's task names `model_key`.
Run the following command:

```bash
$ dp --pt train input_single.json --finetune model.pt --model-branch model_key
```
users can refer to [this section](./finetuning.md#fine-tuning-from-a-multi-task-pre-trained-model).
2 changes: 1 addition & 1 deletion doc/train/multi-task-training-tf.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ The supported fitting nets for multi-task mode are listed:

The output of `dp freeze` command in multi-task mode can be seen in [freeze command](../freeze/freeze.md).

## Initialization from pretrained multi-task model
## Initialization from pre-trained multi-task model

For advance training in multi-task mode, one can first train the descriptor on several upstream datasets and then transfer it on new downstream ones with newly added fitting nets.
At the second step, you can also inherit some fitting nets trained on upstream datasets, by merely adding fitting net keys in {ref}`fitting_net_dict <model/fitting_net_dict>` and
Expand Down

0 comments on commit 0ec6719

Please sign in to comment.