Multi-Granularity Cross-modal Alignment for Generalized Medical Visual Representation Learning, NeurIPS 2022.
To clone this repository:
git clone https://github.com/fuying-wang/MGCA.git
To install Python dependencies:
pip install -r requirements.txt
To install the mgca package:
pip install -e .
We used the following datasets:
- MIMIC-CXR: We downloaded the MIMIC-CXR-JPG dataset for the radiographs. The paired medical reports can be downloaded from MIMIC-CXR.
- CheXpert: We downloaded the CheXpert dataset, which consists of 224,316 chest radiographs of 65,240 patients.
- RSNA: We used stage 2 of the RSNA dataset on Kaggle.
- COVIDx: We used version 6 of the COVIDx dataset on Kaggle.
- SIIM: We downloaded stage 1 of the SIIM dataset on Kaggle.
- Object-CXR: We downloaded the object-CXR dataset from its official website.
After downloading the datasets, please check that the paths in mgca/constants.py are correct.
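A quick way to catch a wrong path before launching a long run is to test each directory for existence. The snippet below is a minimal sketch, not part of the repository: the helper `find_missing_paths` and the example entries are hypothetical, and the actual variable names in mgca/constants.py may differ.

```python
from pathlib import Path
from typing import Dict, List


def find_missing_paths(paths: Dict[str, Path]) -> List[str]:
    """Return the names whose paths do not exist on disk."""
    return [name for name, path in paths.items() if not Path(path).exists()]


if __name__ == "__main__":
    # Hypothetical entries: replace them with the paths you set in mgca/constants.py.
    dataset_dirs = {
        "mimic_cxr": Path("/data/MIMIC-CXR"),
        "chexpert": Path("/data/CheXpert"),
    }
    for name in find_missing_paths(dataset_dirs):
        print(f"missing dataset directory: {name} -> {dataset_dirs[name]}")
```

An empty result means every listed directory exists; it does not check that the files inside are complete.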
We preprocessed these datasets and split each of them into train/val/test sets using the code in mgca/preprocess.
Reminder: Please check Line 47 of mgca/datasets/pretrain_dataset.py and Line 503 of mgca_module.py, and make sure the paths are correct.
We pre-trained MGCA on MIMIC-CXR using this command:
cd mgca/models/mgca
CUDA_VISIBLE_DEVICES=0,1 python mgca_module.py --gpus 2 --strategy ddp
We trained our framework for 50 epochs on 2 RTX 3090 GPUs with a batch size of 144. Pre-training takes about 1 day.
Note that other pre-training models can be developed under this framework: create a folder in mgca/models and implement the {MODEL_NAME}_module.py file.
Pre-trained models can be found here.
We evaluate the MGCA framework on three downstream tasks: classification, object detection and semantic segmentation. Before finetuning, set the path (or ckpt_path) argument to the path of the pre-trained MGCA model.
We evaluate linear classification performance of our model using this command:
cd mgca/models/mgca
CUDA_VISIBLE_DEVICES=1 python mgca_finetuner.py --gpus 1 --dataset chexpert --data_pct 0.01
Use --dataset to select the dataset for finetuning. Three datasets are available: chexpert, rsna and covidx.
Use --data_pct to set the fraction of training data used for finetuning.
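To make the effect of a fraction such as 0.01 concrete, the sketch below draws a reproducible random subset from a list of training samples. It only illustrates the idea and is not the repository's implementation: the function `subsample`, its rounding rule, and the default seed are assumptions.

```python
import random
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def subsample(items: Sequence[T], data_pct: float, seed: int = 42) -> List[T]:
    """Return a reproducible random subset holding a data_pct fraction of items."""
    if not 0 < data_pct <= 1:
        raise ValueError("data_pct must be in (0, 1]")
    if len(items) == 0:
        return []
    # Keep at least one sample so tiny fractions never yield an empty set.
    n_keep = max(1, round(len(items) * data_pct))
    # A local generator keeps the draw independent of the global random state.
    rng = random.Random(seed)
    return rng.sample(list(items), n_keep)
```

With 1,000 training samples, `subsample(samples, 0.01)` keeps 10 of them, and the same seed always selects the same 10.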
We evaluate object detection performance of our model using this command:
cd mgca/models/mgca
CUDA_VISIBLE_DEVICES=0 python mgca_detector.py --devices 1 --dataset rsna --data_pct 1 --learning_rate 5e-4
Here, two datasets are available: rsna and object_cxr.
To run all experiments for this detection task:
sh run_det_funetune.sh
We evaluate semantic segmentation performance of our model using this command:
cd mgca/models/mgca
CUDA_VISIBLE_DEVICES=0 python mgca_segmenter.py --gpus 1 --data_pct 1 --dataset rsna --batch_size 16 --learning_rate 5e-4
Here, two datasets are available: rsna and siim.
To run all experiments for this segmentation task:
sh run_seg_funetune.sh
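If you prefer to assemble your own sweep rather than use the provided script, the commands can be generated from the flags shown above. This is a minimal sketch and does not reproduce the contents of the shell script: the helper `build_commands` is hypothetical, and the list of fractions in the usage example is an assumption.

```python
import shlex
from itertools import product
from typing import Iterable, List


def build_commands(
    script: str,
    datasets: Iterable[str],
    data_pcts: Iterable[float],
    extra_args: Iterable[str] = (),
) -> List[str]:
    """Build one command line per (dataset, data_pct) combination."""
    extra = list(extra_args)  # materialize once so it can be reused in every command
    commands = []
    for dataset, pct in product(datasets, data_pcts):
        args = ["python", script, "--dataset", dataset, "--data_pct", str(pct), *extra]
        commands.append(shlex.join(args))  # shlex.join quotes arguments safely
    return commands


if __name__ == "__main__":
    # The fractions below are an assumed example, not taken from the repository script.
    for cmd in build_commands(
        "mgca_segmenter.py",
        ["rsna", "siim"],
        [0.01, 0.1, 1],
        ["--gpus", "1", "--batch_size", "16", "--learning_rate", "5e-4"],
    ):
        print(cmd)
```

The printed lines can be reviewed first and then run one at a time from mgca/models/mgca.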
- Refactor the train/valid/test lists.
If you find our work useful in your research, please consider citing it:
@article{wang2022multi,
title={Multi-granularity cross-modal alignment for generalized medical visual representation learning},
author={Wang, Fuying and Zhou, Yuyin and Wang, Shujun and Vardhanabhuti, Varut and Yu, Lequan},
journal={Advances in Neural Information Processing Systems},
volume={35},
pages={33536--33549},
year={2022}
}