Medical image segmentation involves identifying and separating object instances in a medical image to delineate various tissues and structures, a task complicated by the significant variations in size, shape, and density of these features. Convolutional neural networks (CNNs) have traditionally been used for this task but are limited in capturing long-range dependencies. Transformers, equipped with self-attention mechanisms, aim to address this problem. However, in medical image segmentation it is beneficial to merge both local and global features so that feature maps are integrated effectively across scales, capturing both fine detail and broader semantic context to cope with variations in structures. In this paper, we introduce MSA2Net, a new deep segmentation framework featuring an expedient design of skip-connections. These connections facilitate feature fusion by dynamically weighting and combining coarse-grained encoder features with fine-grained decoder feature maps. Specifically, we propose a Multi-Scale Adaptive Spatial Attention Gate (MASAG), which dynamically adjusts the receptive field (local and global contextual information) to ensure that spatially relevant features are selectively highlighted while minimizing background distractions. Extensive evaluations on dermatology and radiology datasets demonstrate that MSA2Net outperforms or matches state-of-the-art (SOTA) methods.
14.10.2024 | Accepted as Oral Presentation 🎉
20.07.2024 | Accepted in BMVC 2024! 🥳
@article{kolahi2024msa2net,
title={MSA2Net: Multi-scale Adaptive Attention-guided Network for Medical Image Segmentation},
author={Kolahi, Sina Ghorbani and Chaharsooghi, Seyed Kamal and Khatibi, Toktam and Bozorgpour, Afshin and Azad, Reza and Heidari, Moein and Hacihaliloglu, Ilker and Merhof, Dorit},
journal={arXiv preprint arXiv:2407.21640},
year={2024}
}
- Ubuntu 16.04 or higher
- CUDA 11.1 or higher
- Python v3.7 or higher
- PyTorch v1.7 or higher
- Hardware Spec
  - A single GPU with 12GB memory or larger capacity (we used RTX 3090)
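The minimum versions listed above can be checked before anything else is installed. The following is a minimal sketch and not part of the repository: `parse_version`, `meets_minimum`, and `check_environment` are hypothetical helper names, and the thresholds are simply the ones listed above. PyTorch is imported inside a `try` block so the script still runs on a machine where it is not installed yet.

```python
import re
import sys


def parse_version(text):
    """Turn a version string such as '1.13.1+cu117' into a tuple of ints."""
    # Keep only the leading dotted-number part; local tags like '+cu117' are dropped.
    match = re.match(r"\d+(?:\.\d+)*", text)
    if match is None:
        raise ValueError(f"cannot parse version: {text!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def meets_minimum(found, required):
    """Return True if version string `found` is at least `required`."""
    return parse_version(found) >= parse_version(required)


def check_environment():
    """Collect human-readable problems; an empty list means all checks passed."""
    problems = []
    python_version = ".".join(str(n) for n in sys.version_info[:3])
    if not meets_minimum(python_version, "3.7"):
        problems.append(f"Python {python_version} is older than 3.7")
    try:
        import torch  # may not be installed yet at this point
    except ImportError:
        problems.append("PyTorch is not installed")
        return problems
    if not meets_minimum(torch.__version__, "1.7"):
        problems.append(f"PyTorch {torch.__version__} is older than 1.7")
    if not torch.cuda.is_available():
        problems.append("no CUDA device is visible to PyTorch")
    elif torch.version.cuda and not meets_minimum(torch.version.cuda, "11.1"):
        problems.append(f"CUDA {torch.version.cuda} is older than 11.1")
    return problems


if __name__ == "__main__":
    for line in check_environment() or ["environment looks fine"]:
        print(line)
```

The check reports the CUDA version PyTorch was built against, which may differ from the system-wide toolkit version.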
einops
h5py
imgaug
fvcore
MedPy
numpy
opencv_python
pandas
PyWavelets
scipy
SimpleITK
tensorboardX
timm
torch
torchvision
tqdm
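After installation, it can help to confirm that every package in the list above is importable. The sketch below is an illustration rather than a script shipped with the repository: `find_missing`, `PACKAGES`, and `IMPORT_NAMES` are hypothetical names, and the mapping only covers the packages whose import name differs from the name in the list.

```python
import importlib.util

# Package name as listed above -> module name used in `import`.
# Only names that differ are listed; the rest import under their own name.
IMPORT_NAMES = {
    "opencv_python": "cv2",
    "MedPy": "medpy",
    "PyWavelets": "pywt",
}

PACKAGES = [
    "einops", "h5py", "imgaug", "fvcore", "MedPy", "numpy", "opencv_python",
    "pandas", "PyWavelets", "scipy", "SimpleITK", "tensorboardX", "timm",
    "torch", "torchvision", "tqdm",
]


def find_missing(packages, import_names=IMPORT_NAMES):
    """Return the packages whose import module cannot be located."""
    missing = []
    for package in packages:
        module = import_names.get(package, package)
        # find_spec looks the module up without importing it.
        if importlib.util.find_spec(module) is None:
            missing.append(package)
    return missing


if __name__ == "__main__":
    missing = find_missing(PACKAGES)
    print("missing:", ", ".join(missing) if missing else "none")
```

Because `find_spec` only locates modules, this does not catch packages that are present but fail at import time, for example because of a mismatched CUDA build.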
You can download the pretrained and trained weights from the table below.
| Dataset | Model | Download link |
|---|---|---|
| ImageNet | MaxViT small 224 | Download |
| Synapse | MSA2Net | Download |
- Download the Synapse dataset from here.
- Download the MaxViT small 224x224 pretrained weights here and then set its path in the 'networks/merit_lib/networks.py' file for initialization.
- Run the following command to install the requirements.
    pip install -r requirements.txt
- Run the following command to train MSA2Net on the Synapse dataset.
    python train.py --root_path ./data/Synapse/train_npz --test_path ./data/Synapse/test_vol_h5 --batch_size 20 --eval_interval 20 --max_epochs 700
    --root_path [Train data path]
    --test_path [Test data path]
    --eval_interval [Evaluation interval in epochs]
- Run the following command to test MSA2Net on the Synapse dataset.
    python test.py --volume_path ./data/Synapse/ --output_dir ./model_out
    --volume_path [Root dir of the test data]
    --output_dir [Directory of your trained weights]
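Before launching a long training run, it may be worth confirming that the data directories used in the commands above exist and are not empty. The following is a minimal sketch under stated assumptions: `summarize_dir` and `build_train_command` are hypothetical helpers, not part of the repository, and the paths and flag values are copied from the training command above. The script only prints the command and does not start training.

```python
import shlex
from pathlib import Path


def summarize_dir(path):
    """Return (exists, file_count) for a directory, counting regular files only."""
    directory = Path(path)
    if not directory.is_dir():
        return False, 0
    return True, sum(1 for entry in directory.iterdir() if entry.is_file())


def build_train_command(root_path, test_path, batch_size=20,
                        eval_interval=20, max_epochs=700):
    """Assemble the training command shown above as an argument list."""
    return [
        "python", "train.py",
        "--root_path", str(root_path),
        "--test_path", str(test_path),
        "--batch_size", str(batch_size),
        "--eval_interval", str(eval_interval),
        "--max_epochs", str(max_epochs),
    ]


if __name__ == "__main__":
    train_dir = "./data/Synapse/train_npz"
    test_dir = "./data/Synapse/test_vol_h5"
    for label, path in (("train", train_dir), ("test", test_dir)):
        exists, count = summarize_dir(path)
        status = f"{count} files" if exists else "missing"
        print(f"{label:5s} {path}: {status}")
    # Print the command instead of running it, so nothing starts by accident.
    command = build_train_command(train_dir, test_dir)
    print(" ".join(shlex.quote(arg) for arg in command))
```

Returning the command as a list keeps it ready to pass to `subprocess.run` without shell quoting issues, should you choose to launch training from a script.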