This repository is the official implementation of the paper "Multi-Domain Balanced Sampling Improves Out-of-Distribution Generalization of Chest X-ray Pathology Prediction Models", accepted at Medical Imaging Meets NeurIPS 2021 (Med-NeurIPS 2021).
The arXiv version can be found here.
We have added the script
to allow users to run the experiment on a CPU with a smaller model size. For now, only ResNet-50 is supported.
One can train a quantized model by running:
python --dataset_dir <data dir> --train_datas cx pc --val_data mc --seed 0
Test the model by running:
python --dataset_dir <data dir> --train_datas cx pc --val_data mc --seed 0 --test_only --test_data nih
There are 12 different training, validation, and test settings, generated by combining four chest X-ray datasets (NIH ChestX-ray8, PadChest, CheXpert, and MIMIC-CXR).
The dataset names are abbreviated as short strings: "nih" = NIH ChestX-ray8, "pc" = PadChest, "cx" = CheXpert, and "mc" = MIMIC-CXR.
The models used in these experiments are DenseNet-121 and ResNet-50.
For each experiment, we compute the ROC-AUC for each of the following chest X-ray pathologies (labels): Cardiomegaly, Effusion, Edema, and Consolidation.
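Because each pathology is scored separately, ROC-AUC here is a per-label binary metric: the probability that a randomly chosen positive image receives a higher score than a randomly chosen negative one (ties count as half). The sketch below computes it in pure Python via the Mann-Whitney U statistic and skips images whose label for a pathology is missing, in case a dataset leaves some labels unannotated. The helper names and the use of `None` for a missing label are assumptions for illustration, not the repository's code; in practice, `sklearn.metrics.roc_auc_score` from scikit-learn (one of the dependencies listed in this README) gives the same value for binary labels.

```python
from typing import Dict, List, Optional, Sequence

# The four pathologies evaluated in each experiment.
PATHOLOGIES = ["Cardiomegaly", "Effusion", "Edema", "Consolidation"]


def roc_auc(labels: Sequence[int], scores: Sequence[float]) -> float:
    """ROC-AUC as the Mann-Whitney U statistic normalized by n_pos * n_neg.

    Labels must be 0/1. Tied scores receive the average of their ranks,
    so a tie counts as half a correct ordering.
    """
    pairs = sorted(zip(scores, labels))
    n = len(pairs)
    ranks = [0.0] * n
    i = 0
    while i < n:
        j = i
        # Extend j to the end of the run of scores tied with pairs[i].
        while j + 1 < n and pairs[j + 1][0] == pairs[i][0]:
            j += 1
        for k in range(i, j + 1):
            ranks[k] = (i + j) / 2 + 1  # average 1-based rank of the tied run
        i = j + 1
    n_pos = sum(1 for _, y in pairs if y == 1)
    n_neg = n - n_pos
    if n_pos == 0 or n_neg == 0:
        raise ValueError("ROC-AUC needs at least one positive and one negative label")
    pos_rank_sum = sum(r for r, (_, y) in zip(ranks, pairs) if y == 1)
    u = pos_rank_sum - n_pos * (n_pos + 1) / 2
    return u / (n_pos * n_neg)


def per_pathology_auc(labels: Dict[str, List[Optional[int]]],
                      scores: Dict[str, List[float]]) -> Dict[str, float]:
    """Compute ROC-AUC per pathology, skipping samples whose label is missing (None)."""
    results = {}
    for name in PATHOLOGIES:
        kept = [(y, s) for y, s in zip(labels[name], scores[name]) if y is not None]
        results[name] = roc_auc([y for y, _ in kept], [s for _, s in kept])
    return results


if __name__ == "__main__":
    print(roc_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
```

In the example, three of the four (positive, negative) pairs are ordered correctly, giving 3/4 = 0.75.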
In each experiment, the model is trained on two datasets, validated on one held-out dataset, and tested on the remaining held-out dataset.
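The count of 12 settings follows from this protocol: there are 6 ways to choose two training datasets out of four, and each of the 2 remaining datasets can serve as the validation set while the other is the test set (6 x 2 = 12). The sketch below enumerates the settings and prints matching command-line arguments; the helper name `leave_out_settings` is hypothetical and not part of the repository.

```python
from itertools import combinations
from typing import List, Tuple

# Short dataset names used throughout this README.
DATASETS = ["nih", "pc", "cx", "mc"]

Setting = Tuple[Tuple[str, str], str, str]  # (training pair, validation, test)


def leave_out_settings(datasets: List[str]) -> List[Setting]:
    """Return every split with two training datasets, where the two held-out
    datasets are used once as validation/test and once as test/validation."""
    settings: List[Setting] = []
    for train in combinations(datasets, 2):
        held_out = [d for d in datasets if d not in train]
        for val in held_out:
            test = next(d for d in held_out if d != val)
            settings.append((train, val, test))
    return settings


if __name__ == "__main__":
    for train, val, test in leave_out_settings(DATASETS):
        print(f"--train_datas {' '.join(train)} --val_data {val} --test_data {test}")
```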
The file contains code to run the models in this study.
To train a DenseNet-121 model with the Baseline approach on CheXpert and PadChest, validating on MIMIC-CXR with seed=0, run the following command:
python --merge_train --arch densenet121 --train_datas cx pc --val_data mc --seed 0
To train a DenseNet-121 model with the Balanced Mini-Batch Sampling strategy on CheXpert and PadChest, validating on MIMIC-CXR with seed=0, run the following command:
python --arch densenet121 --train_datas cx pc --val_data mc --seed 0
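To make the difference between the two commands concrete, here is a minimal, dependency-free sketch contrasting the two sampling schemes. It assumes that the Baseline (`--merge_train`) pools the training datasets and samples from the pool uniformly, while Balanced Mini-Batch Sampling draws an equal number of samples from each training dataset in every mini-batch. The function names, toy sample IDs, and dataset sizes are hypothetical; the repository's sampler (for example, a PyTorch `Sampler`) may define epochs and handle dataset-size imbalance differently.

```python
import random
from collections import Counter
from typing import Dict, Iterator, List, Sequence, Tuple

Batch = List[Tuple[str, str]]  # (domain name, sample id) pairs


def merged_batches(domains: Dict[str, Sequence[str]], batch_size: int,
                   seed: int = 0) -> Iterator[Batch]:
    """Baseline-style sampling: pool every dataset, shuffle once, cut into batches.

    Each domain appears in a batch roughly in proportion to its size.
    """
    rng = random.Random(seed)
    pool = [(name, x) for name, items in domains.items() for x in items]
    rng.shuffle(pool)
    # Drop the last incomplete batch.
    for start in range(0, len(pool) - batch_size + 1, batch_size):
        yield pool[start:start + batch_size]


def balanced_batches(domains: Dict[str, Sequence[str]], batch_size: int,
                     num_batches: int, seed: int = 0) -> Iterator[Batch]:
    """Balanced sampling: every batch holds batch_size // len(domains) samples per domain."""
    if batch_size % len(domains) != 0:
        raise ValueError("batch_size must be divisible by the number of domains")
    per_domain = batch_size // len(domains)
    rng = random.Random(seed)
    for _ in range(num_batches):
        batch: Batch = []
        for name, items in domains.items():
            # No repeats within a batch; the same sample may recur in later batches.
            batch.extend((name, x) for x in rng.sample(list(items), per_domain))
        rng.shuffle(batch)  # mix domains within the batch
        yield batch


if __name__ == "__main__":
    # Toy, imbalanced domains: 900 "cx" samples vs 100 "pc" samples (made-up sizes).
    domains = {"cx": [f"cx_{i}" for i in range(900)], "pc": [f"pc_{i}" for i in range(100)]}
    merged = next(merged_batches(domains, batch_size=32))
    balanced = next(balanced_batches(domains, batch_size=32, num_batches=1))
    print("merged:  ", Counter(name for name, _ in merged))
    print("balanced:", Counter(name for name, _ in balanced))  # 16 of each domain
```

With these toy sizes, merged batches contain about nine `cx` samples for every `pc` sample on average, whereas every balanced batch contains exactly 16 samples from each domain.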
To run inference, add the arguments --test_only and --test_data <dataset> to the same command you used for training.
Example: running inference on the NIH dataset:
python --arch densenet121 --train_datas cx pc --val_data mc --seed 0 --test_only --test_data nih
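For readers adapting these commands, the sketch below shows one way the flags used in this README could be declared with Python's `argparse`. It is an illustration, not the repository's actual argument parser: `build_parser` and `parse_args_checked` are hypothetical, and the defaults, `choices`, help strings, and validation rules are assumptions.

```python
import argparse
from typing import List, Optional

# Short dataset names used throughout this README.
DATASETS = ["nih", "pc", "cx", "mc"]


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the flags shown in this README.

    Defaults, choices, and help strings are illustrative assumptions.
    """
    parser = argparse.ArgumentParser(description="Train or evaluate a chest X-ray model.")
    parser.add_argument("--dataset_dir", default=".", help="root directory of the datasets")
    parser.add_argument("--arch", choices=["densenet121", "resnet50"], default="densenet121")
    parser.add_argument("--train_datas", nargs=2, choices=DATASETS, help="two training datasets")
    parser.add_argument("--val_data", choices=DATASETS, help="held-out validation dataset")
    parser.add_argument("--test_data", choices=DATASETS, help="held-out test dataset")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--merge_train", action="store_true",
                        help="Baseline: merge the training datasets into one pool")
    parser.add_argument("--test_only", action="store_true",
                        help="skip training and run inference on --test_data")
    return parser


def parse_args_checked(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse argv and reject inference runs that lack a valid held-out test set."""
    parser = build_parser()
    args = parser.parse_args(argv)
    if args.test_only and args.test_data is None:
        parser.error("--test_only requires --test_data")
    if args.test_data is not None and args.train_datas and args.test_data in args.train_datas:
        parser.error("--test_data must not be one of the training datasets")
    return args


if __name__ == "__main__":
    # Same arguments as the NIH inference example above.
    args = parse_args_checked(["--arch", "densenet121", "--train_datas", "cx", "pc",
                               "--val_data", "mc", "--seed", "0",
                               "--test_only", "--test_data", "nih"])
    print(args.train_datas, args.val_data, args.test_data)  # ['cx', 'pc'] mc nih
```

`parser.error` prints the message and exits with status 2, so an invalid combination fails before any model is loaded.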
To run inference with the DenseNet model using pretrained weights from torchxrayvision, run the following command:
python --test_data pc --seed 0
Note that you can pass any of pc, mc, cx, or nih to --test_data to run inference on PadChest, MIMIC-CXR, CheXpert, and NIH ChestX-ray8, respectively.
In your terminal, run:
pip install scikit-learn wandb torch torchvision torchxrayvision
git clone
cd OoD_Gen-Chest_Xray
pip install -r requirements.txt
title={Multi-Domain Balanced Sampling Improves Out-of-Distribution Generalization of Chest X-ray Pathology Prediction Models},
author={Tetteh, Enoch and Viviano, Joseph and Bengio, Yoshua and Krueger, David and Cohen, Joseph Paul},
booktitle={Medical Imaging Meets NeurIPS},
publisher = {arXiv},