
[IEEE TPAMI22] MobileSal: Extremely Efficient RGB-D Salient Object Detection [PyTorch & Jittor]


yuhuan-wu/MobileSal


MobileSal

IEEE TPAMI 2021: MobileSal: Extremely Efficient RGB-D Salient Object Detection

This repository contains full training and testing code as well as precomputed saliency maps. MobileSal achieves competitive performance on the RGB-D salient object detection task at a speed of 450 FPS (PyTorch, batch size 20, single RTX 2080Ti).

If you run into any problems or have difficulty running this code, do not hesitate to open an issue in this repository.

My e-mail is: wuyuhuan @ mail.nankai (dot) edu.cn

[PDF]

This repository contains:

  • Full code, data, pretrained models for training and testing
  • TensorRT deployment of MobileSal, achieving 420 FPS (fp32) and 800 FPS (fp16) with batch size 1 on a single RTX 2080Ti.

Requirements

PyTorch

  • Python 3.6+
  • PyTorch >=0.4.1, OpenCV-Python
  • Tested on PyTorch 1.7.1

Jittor

  • Python 3.7+
  • Jittor, OpenCV-Python
  • Tested on Jittor 1.3.1

Installation

For Jittor users, we provide a separate branch named jittor. Please check it out first:

git checkout jittor

To install MobileSal, please run:

pip install -r envs/requirements.txt

Data Preparation

Before training/testing our network, please download the training data:

Note: if you cannot access Google or Baidu services, you can contact me via e-mail and I will send you a copy of the data and model weights.

We have already processed the data, so you can use it without any preprocessing steps. After downloading, extract the data and put it in the ./data/ folder. The ./data/ folder should then contain six folders: NJU2K/, NLPR/, STERE/, SSD/, SIP/, DUT-RGBD/, representing the NJU2K, NLPR, STEREO, SSD, SIP, and DUTLF-D datasets, respectively.
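Before training, it can help to confirm the extracted layout. Below is a minimal sketch of that check. It assumes the data was extracted to ./data/; pass another root if your scripts expect a different location. The helper missing_datasets is hypothetical and not part of the MobileSal code. It only checks that each top-level folder exists.

```python
from pathlib import Path

# Folder names listed in this README for the processed RGB-D datasets.
EXPECTED_DATASETS = ["NJU2K", "NLPR", "STERE", "SSD", "SIP", "DUT-RGBD"]


def missing_datasets(data_root="./data"):
    """Return the expected dataset folders that are absent under data_root.

    Hypothetical helper (not part of MobileSal): it checks only that each
    top-level folder exists and does not inspect the files inside it.
    """
    root = Path(data_root)
    return [name for name in EXPECTED_DATASETS if not (root / name).is_dir()]


if __name__ == "__main__":
    missing = missing_datasets("./data")
    if missing:
        print("Missing dataset folders:", ", ".join(missing))
    else:
        print("All six dataset folders found.")
```

An empty list means all six folders are in place. Any names returned point to an archive that was not extracted, or was extracted under a different root.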

Train

It is very simple to train our network. We have prepared a script to run the training step:

bash ./tools/train.sh

Test

Pretrained Models

As in our paper, we train our model on the NJU2K_NLPR training set and test it on the NJU2K_test, NLPR_test, STEREO, SIP, and SSD datasets. For DUTLF-D, we train our model on the DUTLF-D training set and evaluate it on its testing set.
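The protocol above determines which set of weights applies to which benchmark. Below is a minimal sketch that writes it as a lookup table. The key and label strings are illustrative and are assumed not to match any identifiers inside the MobileSal scripts.

```python
# Evaluation protocol from this README: training set -> test sets evaluated
# with that model. Strings are illustrative labels, not MobileSal identifiers.
PROTOCOL = {
    "NJU2K_NLPR": ["NJU2K_test", "NLPR_test", "STEREO", "SIP", "SSD"],
    "DUTLF-D": ["DUTLF-D_test"],
}


def training_set_for(test_set):
    """Return the training setting whose weights should be used for test_set."""
    for train_set, test_sets in PROTOCOL.items():
        if test_set in test_sets:
            return train_set
    raise ValueError(f"test set not covered by the protocol: {test_set!r}")


for name in ["NJU2K_test", "SSD", "DUTLF-D_test"]:
    print(name, "->", training_set_for(name))
```

Under this mapping, the "(Default)" weights serve the NJU2K_NLPR row and the "(Custom)" weights serve DUTLF-D.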

(Default) Trained on NJU2K_NLPR training set:

(Custom) Trained on DUTLF-D training set:

Download them and put them into the pretrained/ folder.

Generate Saliency Maps

After preparing the pretrained models, it is also very simple to generate saliency maps via MobileSal:

bash ./tools/test.sh

The script automatically saves the generated saliency maps to the maps/ directory.

Deployment

TensorRT deployment substantially speeds up MobileSal at batch size 1. An example script is located at tools/test_trt.sh. Run:

bash ./tools/test_trt.sh

This script automatically converts PyTorch MobileSal to TensorRT-based MobileSal and then generates saliency maps with the TensorRT-based model. For real-world applications, you can load the converted TensorRT MobileSal for inference:

import torch
from torch2trt import TRTModule

model = TRTModule()  # empty TensorRT module; the engine is loaded below
trt_model_path = "pretrained/mobilesal_trt.pth"
model.load_state_dict(torch.load(trt_model_path))
result = model(image, depth)  # image, depth: preprocessed torch.Tensor inputs on the GPU

Speed Test

We provide a speed test script on MobileSal:

python speed_test.py

Speed results on a single RTX 2080Ti are listed below:

| Type | Input Size | Batch Size | FP16 | FPS |
|---|---|---|---|---|
| PyTorch | 320 x 320 | 20 | No | 450 |
| TensorRT | 320 x 320 | 1 | No | 420 |
| TensorRT | 320 x 320 | 1 | Yes | 800 |
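To relate these throughput numbers to per-frame latency, or to reproduce a similar measurement on other hardware, here is a minimal timing sketch. It is not the contents of speed_test.py. It assumes FPS means images processed per second and that run_batch is any callable that processes one batch. For GPU models, that callable should wait for the work to finish, for example with torch.cuda.synchronize().

```python
import time


def measure_fps(run_batch, batch_size, n_iters=100, warmup=10):
    """Estimate throughput in images per second for a callable run_batch().

    run_batch() should process one batch of batch_size images. On a GPU it
    should also block until the work is done (e.g. torch.cuda.synchronize()),
    otherwise asynchronous kernel launches make the timing look too fast.
    """
    for _ in range(warmup):  # warm-up iterations are excluded from timing
        run_batch()
    start = time.perf_counter()
    for _ in range(n_iters):
        run_batch()
    elapsed = time.perf_counter() - start
    return n_iters * batch_size / elapsed


def latency_ms(fps, batch_size=1):
    """Convert images-per-second throughput into milliseconds per batch."""
    return 1000.0 * batch_size / fps


print(latency_ms(800))                 # TensorRT fp16, batch size 1
print(latency_ms(450, batch_size=20))  # PyTorch, batch size 20
```

Under that reading, 800 FPS at batch size 1 corresponds to 1.25 ms per image. 450 FPS at batch size 20 corresponds to roughly 44.4 ms per batch of 20 images.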

Pretrained Saliency maps

For convenience, we provide precomputed saliency maps on several datasets:

Others

TODO

  1. Release the pretrained models and saliency maps on COME15K dataset.
  2. Add results with the P2T transformer backbone.

Contact

  • I encourage everyone to contact me via my e-mail. My e-mail is: wuyuhuan @ mail.nankai (dot) edu.cn

License

The code is released under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License for NonCommercial use only.

Citation

If you are using the code/model/data provided here in a publication, please consider citing our work:

@ARTICLE{wu2021mobilesal,
  author={Wu, Yu-Huan and Liu, Yun and Xu, Jun and Bian, Jia-Wang and Gu, Yu-Chao and Cheng, Ming-Ming},
  journal={IEEE Transactions on Pattern Analysis and Machine Intelligence}, 
  title={MobileSal: Extremely Efficient RGB-D Salient Object Detection}, 
  year={2022},
  volume={44},
  number={12},
  pages={10261--10269},
  doi={10.1109/TPAMI.2021.3134684}
}

Acknowledgement

This repository is built with the help of the following five projects, for academic use only:
