This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

update doc (#48)
tchaton authored Feb 1, 2021
1 parent 606f0d6 commit d5409bd
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions docs/source/general/finetuning.rst
@@ -4,18 +4,18 @@ Finetuning

Finetuning (or transfer learning) is the process of tweaking a model trained on a large dataset to your particular (likely much smaller) dataset. All Flash tasks have a pre-trained backbone that was trained on large datasets such as ImageNet, which decreases training time significantly.

- Finetuning process can be splitted into 4 steps:
+ The finetuning process can be split into 4 steps:

- 1. Train a source neural network model on a source dataset. For computer vision, it is traditionally the [ImageNet dataset](http://www.image-net.org/search?q=cat). As training is costly, library such as [Torchvion](https://pytorch.org/docs/stable/torchvision/index.html) library supports popular pre-trainer model architectures.
+ 1. Train a particular neural network model on a particular dataset. For computer vision, the [ImageNet dataset](http://www.image-net.org/search?q=cat) is widely used for pre-training models. As training is costly, libraries such as [torchvision](https://pytorch.org/docs/stable/torchvision/index.html) provide popular pre-trained model architectures. These are called backbones.

- 2. Create a new neural network called the target model. Its architecture replicates the source model and parameters, expect the latest layer which is removed. This model without its latest layer is traditionally called a backbone
+ 2. Create a new neural network called the target model. Its architecture replicates the backbone (the model from the previous step) and its parameters, except for the last layer, which is usually replaced to fit your data.

- 3. Add new layers after the backbone where the latest output size is the number of target dataset categories. Those new layers, traditionally called head will be randomly initialized while backbone will conserve its pre-trained weights from ImageNet.
+ 3. This new layer (or layers) at the end of the backbone matches the backbone output to the number of target categories in your data. It is commonly referred to as the head. The head is randomly initialized, whereas the backbone keeps its pre-trained weights (for example, the weights from ImageNet).

- 4. Train the target model on a smaller target dataset. However, as new layers are randomly initialized, the first gradients will be random when training starts and will destabilize the backbone pre-trained parameters. Therefore, it is good pratice to freeze the backbone, which means the parameters of the backbone won't be trainable for some epochs. After some epochs, the backbone are being unfreezed, meaning the weights will be trainable.
+ 4. Train the target model on a smaller target dataset. However, as the head (the new layers) is untrained, the first gradients will be essentially random when training starts and could degrade the backbone's performance (by changing its pre-trained parameters). Therefore, it is good practice to "freeze" the backbone: its parameters won't be updated until they are "unfrozen" a few epochs later. A minimal sketch of this workflow follows below.
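To make the freeze/replace-head workflow concrete, here is a minimal PyTorch sketch of steps 2-4. The torchvision backbone, the 10-class head size, and the unfreeze helper are illustrative assumptions, not part of the Flash API:

.. code-block:: python

    import torch.nn as nn
    import torchvision.models as models

    # Steps 1-2: start from a backbone pre-trained on ImageNet.
    model = models.resnet18(pretrained=True)

    # Step 3: replace the last layer with a randomly initialized head
    # sized for the target dataset (10 classes is an arbitrary example).
    model.fc = nn.Linear(model.fc.in_features, 10)

    # Step 4: freeze the backbone so the head's initially random
    # gradients cannot disturb the pre-trained weights.
    for name, param in model.named_parameters():
        if not name.startswith("fc."):
            param.requires_grad = False

    # After a few epochs, unfreeze so the whole network trains.
    def unfreeze(model):
        for param in model.parameters():
            param.requires_grad = True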


- .. tip:: If you have a huge dataset and prefer to train from scratch, see the training guide.
+ .. tip:: If you have a large dataset and prefer to train from scratch, see the training guide.

You can finetune any Flash task on your own data in just 3 simple steps:

@@ -32,7 +32,8 @@ Finetune options
Flash provides a very simple interface for finetuning through `trainer.finetune` with its `strategy` parameter.

The Flash finetune `strategy` argument can either be a string or an instance of :class:`~pytorch_lightning.callbacks.finetuning.BaseFinetuning`.

- Furthermore, Flash supports 4 builts-in Finetuning Callback accessible via those strings:
+ Flash supports 4 built-in finetuning callbacks accessible via these strings:

* `no_freeze`: Don't freeze anything.
* `freeze`: The parameters of the backbone won't be trainable after training starts.
(diff truncated)
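As a usage illustration, here is a hedged sketch of passing one of these strings to `trainer.finetune`. The `ImageClassifier`/`ImageClassificationData` import paths, the folder layout, and the `max_epochs` value are assumptions based on typical Flash examples of this era and may differ in your version:

.. code-block:: python

    import flash
    from flash.vision import ImageClassificationData, ImageClassifier

    # Load a target dataset from image folders (paths are hypothetical).
    datamodule = ImageClassificationData.from_folders(
        train_folder="data/train/",
        valid_folder="data/val/",
    )

    # Build a task with a pre-trained backbone and a fresh head.
    model = ImageClassifier(num_classes=datamodule.num_classes)

    # Finetune with the `freeze` strategy: the backbone stays frozen,
    # so only the randomly initialized head is trained.
    trainer = flash.Trainer(max_epochs=3)
    trainer.finetune(model, datamodule=datamodule, strategy="freeze")

Passing `strategy="no_freeze"` instead would train the backbone and the head together from the start.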
