A PyTorch-based Domain Generalization (DG) framework for medical image segmentation, covering data preprocessing, data augmentation, various network implementations, different training methods, and standard performance evaluation.
Why this framework? Unlike natural image tasks, domain generalization performance in medical image segmentation can vary greatly depending on the data preprocessing, data augmentation, and training method used. Many previous works lack fair comparison, and their results are hard for the community to reproduce. We aim to provide a fair evaluation for various domain generalization methods.
Multi-domain medical image segmentation datasets are limited. We collect as many high-quality multi-center datasets from previous research as possible and reorganize them into a standard format for preprocessing and training. Currently supported datasets are listed as follows:
- Multi-domain Prostate MRI Segmentation Dataset
- Multi-domain Fundus Optic Cup and Optic Disc Segmentation Dataset
- Multi-domain Fundus Vessel Segmentation Dataset
- Multi-domain Covid-19 Segmentation Dataset
- Multi-domain Cardiac Segmentation Dataset
We highly encourage anyone to contribute their published datasets to the framework!
Standard data preprocessing is the basis of any fair comparison. In this work, we adopt a widely used data preprocessing strategy that supports various data formats, including but not limited to 2D (PNG, JPG, ...) and 3D (NIfTI, DICOM, ...). We convert all raw data to NumPy arrays according to the setting: 2D network on 2D data, 2D network on 3D data, or 3D network on 3D data. More detailed tutorials are available in []
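As a minimal sketch of the preprocessing idea (not the framework's actual pipeline, and using a synthetic volume in place of a loaded NIfTI/DICOM file), the conversion for the "2D network on 3D data" setting amounts to normalizing the volume and splitting it into per-slice NumPy arrays:

```python
import numpy as np

def preprocess_volume(volume: np.ndarray) -> np.ndarray:
    """Z-score normalize a 3D volume and cast to float32."""
    volume = volume.astype(np.float32)
    mean, std = volume.mean(), volume.std()
    return (volume - mean) / (std + 1e-8)

def volume_to_slices(volume: np.ndarray) -> list:
    """Split a (D, H, W) volume into 2D (H, W) slices for a 2D network."""
    return [volume[i] for i in range(volume.shape[0])]

# Synthetic stand-in for a case loaded from NIfTI/DICOM
volume = np.random.rand(16, 64, 64)
norm = preprocess_volume(volume)
slices = volume_to_slices(norm)
```

For the "3D network on 3D data" setting, the normalized volume would instead be kept whole (typically cropped or resampled to a fixed patch size).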
Different data augmentation methods can dramatically influence the trained model's generalization ability, as shown in BigAug. We standardize data augmentation based on batchgenerators. For every baseline, we adopt the default data augmentation settings of nnU-Net. More detailed tutorials are available in []
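To illustrate the kind of paired image/mask transform such a pipeline applies (a hand-rolled NumPy sketch, not the batchgenerators API), random mirroring must be applied to the image and its segmentation mask together:

```python
import numpy as np

rng = np.random.default_rng(0)

def random_flip(image: np.ndarray, mask: np.ndarray):
    """Randomly mirror image and mask along each spatial axis together,
    so the segmentation stays aligned with the image."""
    for axis in range(image.ndim):
        if rng.random() < 0.5:
            image = np.flip(image, axis=axis)
            mask = np.flip(mask, axis=axis)
    return np.ascontiguousarray(image), np.ascontiguousarray(mask)

img = rng.random((16, 64, 64))
msk = (img > 0.5).astype(np.uint8)
aug_img, aug_msk = random_flip(img, msk)
```

Real pipelines add elastic deformation, rotation, scaling, and intensity transforms on top of this, as in the nnU-Net defaults.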
We support various network structures, and you can plug in any segmentation backbone you like. Many network implementations in this work are based on MONAI networks. Note that some domain generalization methods require the network to return intermediate features for training; make sure your network structure supports this. More detailed tutorials are available in []
Various optimizers are supported, including SGD, Adam, AdamW, AMSGrad, and RMSProp.
Different learning rate schedulers are supported, including cosine, multi-step, and single-step.
Segmentation losses including Dice, DiceCE, and DiceFocal are also supported.
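A minimal sketch of how these pieces fit together in plain PyTorch (the tiny convolution stands in for a segmentation network, and the hand-written soft Dice loss is an illustration, not the framework's loss implementation):

```python
import torch
import torch.nn as nn

def soft_dice_loss(logits, target, eps=1e-6):
    """A minimal soft Dice loss for binary segmentation."""
    probs = torch.sigmoid(logits)
    inter = (probs * target).sum()
    return 1.0 - (2.0 * inter + eps) / (probs.sum() + target.sum() + eps)

model = nn.Conv2d(1, 1, kernel_size=3, padding=1)   # stand-in segmentation net
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

images = torch.randn(2, 1, 8, 8)
masks = (torch.rand(2, 1, 8, 8) > 0.5).float()
for step in range(3):
    optimizer.zero_grad()
    loss = soft_dice_loss(model(images), masks)
    loss.backward()
    optimizer.step()
    scheduler.step()   # cosine-decay the learning rate each step
```

Swapping in Adam/AdamW or a multi-step scheduler only changes the two constructor calls; the training loop stays the same.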
Standard evaluation methods are essential for measuring domain generalization ability. We provide automatic evaluation with various metrics, including Dice, Jaccard, Precision, Recall, Accuracy, Hausdorff Distance 95, and Average Symmetric Surface Distance. Note that for fair comparison, evaluation is performed at the case level: one complete 3D case, rather than its individual slices, counts as one sample.
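The case-level convention can be made concrete with the Dice metric (a NumPy sketch for binary masks, not the framework's evaluation code): the score is computed once over the whole 3D volume, never averaged over slices.

```python
import numpy as np

def dice_score(pred: np.ndarray, target: np.ndarray, eps: float = 1e-8) -> float:
    """Dice overlap computed over an entire 3D case (binary masks)."""
    pred, target = pred.astype(bool), target.astype(bool)
    inter = np.logical_and(pred, target).sum()
    return float(2.0 * inter / (pred.sum() + target.sum() + eps))

# Two identical (4, 8, 8) masks -> Dice close to 1
pred = np.zeros((4, 8, 8)); pred[1:3, 2:6, 2:6] = 1
target = np.zeros((4, 8, 8)); target[1:3, 2:6, 2:6] = 1
score = dice_score(pred, target)
```

Averaging per-slice Dice instead would over-weight slices with little foreground, which is exactly what the case-level rule avoids.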
Clone the repo, create a conda environment, and install the package:
```shell
# Clone the repo
git clone https://github.com/freshman97/MedSeg_Generalization_Framework.git
cd MedSeg_Generalization_Framework

# Create the environment
conda create -y -n medsegdg python=3.10
conda activate medsegdg
conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch

# Install the required packages
pip install -r requirements.txt
python setup.py install
```
A basic training example using the Prostate MRI dataset under the domain generalization setting; more detailed tutorials for config settings are available in [].
```shell
#!/bin/bash
DATA=path_to_dir
DATASET=ProstateMRI
D1=BIDMC
D2=BMC
D3=HK
D4=I2CVB
D5=UCL
D6=RUNMC
SEED=0
method=baseline
cuda_device=0

# Train on 5 domains and test on the remaining one
CUDA_VISIBLE_DEVICES=${cuda_device} python MedSegDGSSL/tools/train.py \
    --root ${DATA} \
    --trainer Vanilla \
    --source-domains ${D1} ${D2} ${D3} ${D4} ${D5} \
    --target-domains ${D6} \
    --seed ${SEED} \
    --config-file configs/trainers/${DATASET}.yaml \
    --output-dir output/dg/${DATASET}/${method}/${D6}
```
Here we implement several existing domain generalization methods, listed below. Detailed implementations can be found under /engine/dg:
- Data Augmentation
- Style Mixing
- Feature Alignment
- Meta-Learning
- Self-Challenging
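To give a flavor of one of these families, here is a sketch of the style-mixing idea (in the spirit of MixStyle; a hypothetical illustration, not the repo's implementation): per-instance feature statistics are interpolated across the batch so the network sees novel "styles" during training.

```python
import torch

def mixstyle(x: torch.Tensor, alpha: float = 0.1) -> torch.Tensor:
    """Mix per-instance channel statistics (mean/std) across the batch.

    x: feature map of shape (B, C, H, W) from an intermediate layer.
    """
    B = x.size(0)
    mu = x.mean(dim=[2, 3], keepdim=True)            # (B, C, 1, 1)
    sig = x.std(dim=[2, 3], keepdim=True) + 1e-6
    x_norm = (x - mu) / sig                          # strip instance style
    perm = torch.randperm(B)                         # partner for each sample
    lam = torch.distributions.Beta(alpha, alpha).sample((B, 1, 1, 1))
    mu_mix = lam * mu + (1 - lam) * mu[perm]
    sig_mix = lam * sig + (1 - lam) * sig[perm]
    return x_norm * sig_mix + mu_mix                 # re-apply mixed style

feat = torch.randn(4, 8, 16, 16)
mixed = mixstyle(feat)
```

This is also why some methods need intermediate features returned by the network, as noted above: the mixing operates on feature maps, not on inputs or logits.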
A new segmentation network can be added easily (using the MONAI UNet implementation as an example):
```python
from monai.networks.nets import UNet

from MedSegDGSSL.network.segnet.build import NETWORK_REGISTRY


@NETWORK_REGISTRY.register()
def monaiunet(model_cfg):
    unet = UNet(spatial_dims=model_cfg.SPATIAL_DIMS,
                in_channels=model_cfg.IN_CHANNELS,
                out_channels=model_cfg.OUT_CHANNELS,
                channels=model_cfg.FEATURES,
                strides=model_cfg.STRIDES,
                num_res_units=2,
                norm=model_cfg.NORM,
                dropout=model_cfg.DROPOUT)
    return unet
```
A new dataset can be added as follows; the dataset description should include the name, domains, labels, and data type information:
```python
from MedSegDGSSL.dataset.build import DATASET_REGISTRY
from MedSegDGSSL.dataset.data_base import Datum, DatasetBase


@DATASET_REGISTRY.register()
class ProstateMRI(DatasetBase):
    """Prostate Segmentation

    Statistics:
        - 6 domains: "BMC", "HK", "I2CVB", "UCL", "RUNMC", "BIDMC"
        - Prostate Segmentation
    """
    dataset_name = 'ProstateMRI'
    domains = ["BMC", "HK", "I2CVB", "UCL", "RUNMC", "BIDMC"]
    labels = {"0": "Background", "1": "Prostate"}
    data_shape = "3D"

    def __init__(self, cfg):
        super().__init__(data_dir=cfg.DATASET.ROOT,
                         train_domains=cfg.DATASET.SOURCE_DOMAINS,
                         test_domains=cfg.DATASET.TARGET_DOMAINS)
        self._lab2cname = self.labels
        self.num_classes = len(self.labels)
```
A new trainer can be added easily as follows. Make sure you read the implementation of SimpleTrainer before adding a new trainer, to understand what to change.
```python
from MedSegDGSSL.engine import TRAINER_REGISTRY, TrainerX


@TRAINER_REGISTRY.register()
class NewTrainer(TrainerX):
    """Trainer description."""

    def forward_backward(self, batch):
        input, label = self.parse_batch_train(batch)
        output = self.model(input)
        loss = self.loss_func(output, label)
        self.model_backward_and_update(loss)
        loss_summary = {'loss': loss.item()}
        return loss_summary
```
If you find this framework useful, please cite:

```bibtex
@article{zhang2023domain,
  title={Domain generalization with adversarial intensity attack for medical image segmentation},
  author={Zhang, Zheyuan and Wang, Bin and Yao, Lanhong and Demir, Ugur and Jha, Debesh and Turkbey, Ismail Baris and Gong, Boqing and Bagci, Ulas},
  journal={arXiv preprint arXiv:2304.02720},
  year={2023}
}
```
We acknowledge DASSL, MONAI, and batchgenerators for their great contributions.