# Revisiting Hidden Representations in Transfer Learning for Medical Imaging

This repository contains the code and results for the paper *Revisiting Hidden Representations in Transfer Learning for Medical Imaging*.

Transfer learning has become an increasingly popular approach in medical imaging, as it offers a solution to the challenge of training models on limited dataset sizes. Despite its widespread use, the precise effects of transfer learning on medical image classification remain understudied. We investigate this with a series of systematic experiments comparing the representations learned from a natural (ImageNet) and a medical (RadImageNet) source dataset across seven medical target datasets.

## Project structure

This project consists of two parts:

- Fine-tuning natural and medical image sources on medical image classification targets, and
- Experiments on model similarity.

*(Figure: method overview.)*

## 1. Fine-tuning

We use publicly available pre-trained ImageNet (the Keras implementation of ResNet50) and RadImageNet (https://drive.google.com/drive/folders/1Es7cK1hv7zNHJoUW0tI0e6nLFVYTqPqK?usp=sharing) weights as source models in our transfer learning experiments.
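For illustration, a minimal Keras sketch of instantiating the two source backbones (the RadImageNet weight file name and the input size are assumptions; use whatever the download above provides):

```python
import tensorflow as tf

# ImageNet source: stock Keras ResNet50 backbone without the classifier head.
imagenet_base = tf.keras.applications.ResNet50(
    weights="imagenet", include_top=False, input_shape=(112, 112, 3))

# RadImageNet source: same architecture, weights loaded from the file
# downloaded via the Google Drive link above (file name is an assumption).
radimagenet_base = tf.keras.applications.ResNet50(
    weights=None, include_top=False, input_shape=(112, 112, 3))
radimagenet_base.load_weights("RadImageNet-ResNet50_notop.h5")
```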

We investigate transferability to seven medical target datasets.

A representative image from each dataset can be seen here: data

### Usage

Our fine-tuning experiments and models were logged to a private server using MLflow. Update the logging setup in fine-tuning.py to point to your own tracking server.
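A minimal sketch of such a change (the tracking URI and experiment name are placeholders, not values from this repository):

```python
import mlflow

# Point MLflow at your own tracking server instead of the private one.
mlflow.set_tracking_uri("http://localhost:5000")
mlflow.set_experiment("transfer-learning-medical")
```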

Specific paths to the different datasets are set in io_fun/data_paths.py, and the data is split into folds in make_dataframe.py. To fine-tune pre-trained RadImageNet weights on, e.g., chest X-rays, run:

```
python src/fine-tuning.py --base RadImageNet --t chest --image_h 112 --image_w 112 --freeze False --e 200 --k 1 --batch 128 --l 0.00001
```

To first freeze the pre-trained weights and fine-tune the full network only after training the classification layer, set `--freeze` to `True`, as sketched below.
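Conceptually, the two-stage schedule corresponds to something like the following, continuing from the backbone sketch above (a simplified illustration, not the repository's exact code; `num_classes`, `train_ds`, and `val_ds` are placeholders):

```python
# Build a new classification head on top of the pre-trained backbone.
num_classes = 2  # placeholder; depends on the target dataset
base = radimagenet_base  # from the earlier sketch
x = tf.keras.layers.GlobalAveragePooling2D()(base.output)
outputs = tf.keras.layers.Dense(num_classes, activation="softmax")(x)
model = tf.keras.Model(base.input, outputs)

# Stage 1 (--freeze True): freeze the backbone and train only the new head.
base.trainable = False
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
              loss="categorical_crossentropy", metrics=["accuracy"])
model.fit(train_ds, validation_data=val_ds, epochs=10)

# Stage 2: unfreeze the backbone and fine-tune the whole network.
base.trainable = True
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
              loss="categorical_crossentropy", metrics=["accuracy"])
model.fit(train_ds, validation_data=val_ds, epochs=200)
```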

## 2. Model similarity

Model similarity is assessed by comparing network activations over a sample of images from the target datasets using two similarity measures: Canonical Correlation Analysis (CCA) and prediction similarity. We use a publicly available CCA implementation, which should be placed at the same level locally as src/. Model similarity can then be evaluated with CCA.py and prediction_similarity.py after placing the fine-tuned models in a directory next to src/. The figures in the paper can be reproduced with CCA_plot.py, filter_plot.py, similarity_plot.py, and simVSauc_plot.py.
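For intuition, prediction similarity can be computed as the fraction of sampled images on which two fine-tuned models predict the same class; a minimal NumPy sketch under that assumption (prediction_similarity.py in this repository may define it differently):

```python
import numpy as np

def prediction_similarity(model_a, model_b, images):
    """Fraction of images on which two classifiers predict the same class."""
    preds_a = np.argmax(model_a.predict(images), axis=1)
    preds_b = np.argmax(model_b.predict(images), axis=1)
    return float(np.mean(preds_a == preds_b))
```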

## Prerequisites

The packages needed to run the fine-tuning experiments are listed in the conda.yaml file.
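A typical way to create the environment from that file:

```
conda env create -f conda.yaml
```

Then activate it with `conda activate`, using the environment name defined in conda.yaml.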